"""Source code for xpmir.letor.trainers.pairwise"""

from dataclasses import InitVar
import math
import sys
from typing import Iterator, Any
import torch
from torch import nn
from torch.functional import Tensor
import torch.nn.functional as F
from experimaestro import Config, Param
from xpmir.learning.context import Loss
from xpmir.learning.metrics import ScalarMetric
from xpmir.letor.records import (
    PairwiseRecord,
    PairwiseRecordWithTarget,
    PairwiseRecords,
    PairwiseRecordsWithTarget,
)
from xpmir.letor.samplers import PairwiseSampler, SerializableIterator
from xpmir.letor.trainers import TrainerContext, LossTrainer
import numpy as np
from xpmir.rankers import LearnableScorer, ScorerOutputType
from xpmir.utils.utils import foreach
from xpmir.utils.iter import MultiprocessSerializableIterator


class PairwiseLoss(Config, nn.Module):
    """Base class for any pairwise loss"""

    NAME = "?"

    weight: Param[float] = 1.0
    """The weight :math:`w` with which the loss is multiplied (useful when
    combining with other ones)"""

    def initialize(self, ranker: LearnableScorer):
        # Subclasses may adapt themselves to the ranker (e.g. its output type)
        pass

    def process(self, scores: Tensor, context: TrainerContext):
        # Compute the loss value and register it on the training context
        loss_value = self.compute(scores, context)
        context.add_loss(Loss(f"pair-{self.NAME}", loss_value, self.weight))

    def compute(self, scores: Tensor, info: TrainerContext) -> Tensor:
        """Computes the loss

        :param scores: A (batch x 2) tensor (positive/negative)
        :param info: the trainer context
        :return: a torch scalar
        """
        raise NotImplementedError()
class CrossEntropyLoss(PairwiseLoss):
    r"""Cross-Entropy Loss

    Computes the cross-entropy loss

    Classification loss (relevant vs non-relevant) where the logit is equal
    to the difference between the relevant and the non relevant document
    (or equivalently, softmax then mean log probability of relevant documents)

    Reference: C. Burges et al., "Learning to rank using gradient descent," 2005.

    *warning*: this loss assumes the score returned by the scorer is a logit

    .. math::

        \frac{w}{N} \sum_{(s^+,s-)} \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)}
    """

    NAME = "cross-entropy"

    def compute(self, rel_scores_by_record, info: TrainerContext):
        # The positive document is in column 0, so the target class is
        # always 0 for every pair
        target = torch.zeros(
            rel_scores_by_record.shape[0],
            dtype=torch.long,
            device=rel_scores_by_record.device,
        )
        return F.cross_entropy(rel_scores_by_record, target, reduction="mean")
class HingeLoss(PairwiseLoss):
    r"""Hinge (or max-margin) loss

    .. math::

        \frac{w}{N} \sum_{(s^+,s-)} \max(0, m - (s^+ - s^-))
    """

    NAME = "hinge"

    margin: Param[float] = 1.0
    """The margin for the Hinge loss"""

    def compute(self, rel_scores_by_record, info: TrainerContext):
        # margin - (s+ - s-) == margin + (s- - s+); clip below at 0 and average
        deltas = rel_scores_by_record[:, 1] - rel_scores_by_record[:, 0]
        return F.relu(self.margin + deltas).mean()
class BCEWithLogLoss(nn.Module):
    """Custom cross-entropy loss when outputs are log probabilities"""

    def __call__(self, log_probs: torch.Tensor, info: "TrainerContext"):
        # Column 0 holds log P(relevant) for the positive documents,
        # column 1 for the negatives; a log-probability must be negative
        assert torch.all(log_probs < 0.0)

        # BCE with targets (1, 0): -log p for positives, -log(1-p) for negatives
        positive_term = log_probs[:, 0].sum()
        negative_term = (1.0 - log_probs[:, 1].exp()).log().sum()
        return -(positive_term + negative_term) / log_probs.numel()
class PointwiseCrossEntropyLoss(PairwiseLoss):
    r"""Point-wise cross-entropy loss

    This is a point-wise loss adapted as a pairwise one.

    This loss adapts to the ranker output type:

    - If real, uses a BCELossWithLogits (sigmoid transformation)
    - If probability, uses the BCELoss
    - If log probability, uses a BCEWithLogLoss

    .. math::

        \frac{w}{2N} \sum_{(s^+,s-)} \log \frac{\exp(s^+)}{\exp(s^+)+\exp(s^-)}
        + \log \frac{\exp(s^-)}{\exp(s^+)+\exp(s^-)}
    """

    NAME = "pointwise-cross-entropy"

    def initialize(self, ranker: LearnableScorer):
        super().initialize(ranker)
        self.rankerOutputType = ranker.outputType

        # Pick the loss matching the space the scores live in
        factories = {
            ScorerOutputType.REAL: nn.BCEWithLogitsLoss,
            ScorerOutputType.PROBABILITY: nn.BCELoss,
            ScorerOutputType.LOG_PROBABILITY: BCEWithLogLoss,
        }
        factory = factories.get(ranker.outputType)
        if factory is None:
            raise Exception("Not implemented")
        self.loss = factory()

    def compute(self, rel_scores_by_record, info: TrainerContext):
        # BCEWithLogLoss consumes the (batch x 2) log-probabilities directly
        if self.rankerOutputType == ScorerOutputType.LOG_PROBABILITY:
            return self.loss(rel_scores_by_record, info)

        device = rel_scores_by_record.device
        n_pairs = rel_scores_by_record.shape[0]

        # Column-major flatten: all positive scores first (target 1),
        # then all negative ones (target 0)
        flat_scores = rel_scores_by_record.T.flatten()
        target = torch.cat(
            (torch.ones(n_pairs, device=device), torch.zeros(n_pairs, device=device))
        )
        return self.loss(flat_scores, target)
class PairwiseTrainer(LossTrainer):
    """Pairwise trainer uses samples of the form (query, positive, negative)"""

    lossfn: Param[PairwiseLoss]
    """The loss function"""

    sampler: Param[PairwiseSampler]
    """The pairwise sampler"""

    sampler_iter: InitVar[SerializableIterator[PairwiseRecord, Any]]

    def initialize(
        self,
        random: np.random.RandomState,
        context: TrainerContext,
    ):
        super().initialize(random, context)

        # Let the loss (and any loss hooks) adapt to the ranker
        self.lossfn.initialize(self.ranker)
        foreach(
            context.hooks(PairwiseLoss),
            lambda loss: loss.initialize(self.ranker),
        )

        self.sampler.initialize(random)
        self.sampler_iter = MultiprocessSerializableIterator(
            self.sampler.pairwise_batch_iter(self.batch_size)
        )

    def train_batch(self, records: PairwiseRecords):
        # Score every query/document in the batch
        rel_scores = self.ranker(records, self.context)

        if torch.isnan(rel_scores).any() or torch.isinf(rel_scores).any():
            self.logger.error("nan or inf relevance score detected. Aborting.")
            sys.exit(1)

        # Rows become pairs: column 0 = positive, column 1 = negative
        pairwise_scores = rel_scores.reshape(2, len(records)).T
        self.lossfn.process(pairwise_scores, self.context)

        self.context.add_metric(
            ScalarMetric(
                "accuracy",
                float(self.acc(pairwise_scores).item()),
                len(rel_scores),
            )
        )

    def acc(self, scores_by_record) -> Tensor:
        # Fraction of pairs where the positive outscores the negative
        with torch.no_grad():
            n_rows, n_cols = scores_by_record.shape
            total = n_rows * (n_cols - 1)
            correct = (scores_by_record[:, :1] > scores_by_record[:, 1:]).sum()
            return correct.float() / total
class PairwiseLossWithTarget(Config):
    """Base class for pairwise losses with an explicit 0/1 target"""

    NAME = "?"

    weight: Param[float] = 1.0

    def initialize(self, ranker: LearnableScorer):
        # Subclasses may adapt themselves to the ranker (e.g. its output type)
        pass

    def process(self, scores: Tensor, targets: Tensor, context: TrainerContext):
        # Compute the loss value and register it on the training context
        loss_value = self.compute(scores, targets, context)
        context.add_loss(Loss(f"duo-{self.NAME}", loss_value, self.weight))
class PairwiseLossWithTarget(PairwiseLossWithTarget):
    # NOTE(review): this subclass reuses (and thus shadows) its base class
    # name `PairwiseLossWithTarget` — after this definition the base is no
    # longer reachable under that name at module level. Renaming it would
    # change the public interface, so the name is kept, but it deserves a
    # distinct one (e.g. LogProbaPairwiseLossWithTarget).
    NAME = "logproba"

    def initialize(self, ranker: LearnableScorer):
        """Select the binary cross-entropy variant matching the ranker output.

        :param ranker: the scorer whose `outputType` decides the loss
        :raises NotImplementedError: for LOG_PROBABILITY rankers, which are
            not supported. (The previous code mapped that case to ``None``
            and then crashed with an opaque
            ``TypeError: 'NoneType' object is not callable``.)
        """
        if ranker.outputType == ScorerOutputType.LOG_PROBABILITY:
            raise NotImplementedError(
                "PairwiseLossWithTarget does not support log-probability outputs"
            )
        self.loss = {
            ScorerOutputType.REAL: nn.BCEWithLogitsLoss,
            ScorerOutputType.PROBABILITY: nn.BCELoss,
        }[ranker.outputType]()

    def compute(self, scores: Tensor, targets: Tensor, context: TrainerContext):
        """Binary cross-entropy between the scores and the 0/1 targets"""
        return self.loss(scores, targets)
class DuoPairwiseTrainer(LossTrainer):
    """The pairwise trainer for duobert. The iter_batch method can be the
    same as the pairwiseTrainer"""

    lossfn: Param[PairwiseLossWithTarget]
    """The loss function"""

    sampler: Param[PairwiseSampler]
    """The pairwise sampler"""

    sampler_iter: InitVar[SerializableIterator[PairwiseRecord, Any]]

    def initialize(self, random: np.random.RandomState, context: TrainerContext):
        super().initialize(random, context)
        self.lossfn.initialize(self.ranker)

        # Decision threshold (p = 0.5 in each output space) above which the
        # first document of the pair is predicted as the relevant one
        self.score_threshold = {
            ScorerOutputType.LOG_PROBABILITY: math.log(0.5),
            ScorerOutputType.PROBABILITY: 0.5,
            ScorerOutputType.REAL: 0,
        }[self.ranker.outputType]

        foreach(context.hooks(PairwiseLoss), lambda loss: loss.initialize(self.ranker))
        self.sampler.initialize(random)
        self.sampler_iter = self.sampler.pairwise_iter()

    def iter_batches(self) -> Iterator[PairwiseRecordsWithTarget]:
        while True:
            batch = PairwiseRecordsWithTarget()
            for _, record in zip(range(self.batch_size), self.sampler_iter):
                # Randomly swap the two documents; the target records whether
                # the first one is the positive (1) or the negative (0)
                if self.random.random() < 0.5:
                    batch.add(
                        PairwiseRecordWithTarget(
                            record.query, record.positive, record.negative, 1
                        )
                    )
                else:
                    batch.add(
                        PairwiseRecordWithTarget(
                            record.query, record.negative, record.positive, 0
                        )
                    )
            yield batch

    def train_batch(self, records: PairwiseRecords):
        # Forward pass: one relevance score per record, shape (batch,)
        rel_scores = self.ranker(records, self.context)
        targets = torch.tensor(
            records.get_target(), dtype=torch.float, device=rel_scores.device
        )

        if torch.isnan(rel_scores).any() or torch.isinf(rel_scores).any():
            self.logger.error("nan or inf relevance score detected. Aborting.")
            sys.exit(1)

        # Bug fix: `targets` is already a float tensor on the right device;
        # the previous code re-wrapped it with torch.Tensor(...), which
        # allocated a fresh CPU copy and broke device placement on GPU
        self.lossfn.process(rel_scores, targets, self.context)

        self.context.add_metric(
            ScalarMetric(
                "accuracy",
                self.acc(rel_scores, targets).item(),
                len(rel_scores),
            )
        )

    def acc(self, scores_by_record, target) -> Tensor:
        # Fraction of records where thresholding the score recovers the target
        with torch.no_grad():
            positives = scores_by_record > self.score_threshold
            return (positives == target).sum() / len(positives)