Source code for xpmir.letor.distillation.pairwise

import sys
from typing import List
import torch
from torch import nn
from torch.functional import Tensor
from experimaestro import Config, Param
from xpmir.letor.records import (
    DocumentRecord,
    PairwiseRecord,
    PairwiseRecords,
)
from xpmir.learning.context import Loss
from xpmir.letor.trainers import TrainerContext, LossTrainer
from xpmir.utils.utils import foreach
from .samplers import DistillationPairwiseSampler, PairwiseDistillationSample
from xpmir.utils.iter import MultiprocessSerializableIterator
import numpy as np
from xpmir.rankers import LearnableScorer


[docs]class DistillationPairwiseLoss(Config, nn.Module): """The abstract loss for pairwise distillation""" weight: Param[float] = 1.0 NAME = "?" def initialize(self, ranker: LearnableScorer): pass def process( self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext ): loss = self.compute(student_scores, teacher_scores, info) info.add_loss(Loss(f"pairwise-{self.NAME}", loss, self.weight))
[docs] def compute( self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext ) -> torch.Tensor: """ Compute the loss Arguments: student_scores: A (batch x 2) tensor teacher_scores: A (batch x 2) tensor """ raise NotImplementedError()
[docs]class MSEDifferenceLoss(DistillationPairwiseLoss): """Computes the MSE between the score differences Compute ((student 1 - student 2) - (teacher 1 - teacher 2))**2 """ NAME = "delta-MSE" def initialize(self, ranker): super().initialize(ranker) self.loss = nn.MSELoss() def compute( self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext ) -> torch.Tensor: return self.loss( student_scores[:, 1] - student_scores[:, 0], teacher_scores[:, 1] - teacher_scores[:, 0], )
[docs]class DistillationKLLoss(DistillationPairwiseLoss): """ Distillation loss from: Distilling Dense Representations for Ranking using Tightly-Coupled Teachers https://arxiv.org/abs/2010.11386 """ NAME = "Distil-KL" def initialize(self, ranker): super().initialize(ranker) self.loss = nn.KLDivLoss(reduction="none") def compute( self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext ) -> torch.Tensor: pos_student = student_scores[:, 0].unsqueeze(0) neg_student = student_scores[:, 1].unsqueeze(0) pos_teacher = teacher_scores[:, 0].unsqueeze(0) neg_teacher = teacher_scores[:, 1].unsqueeze(0) scores = torch.cat([pos_student, neg_student], dim=1) local_scores = torch.log_softmax(scores, dim=1) teacher_scores = torch.cat( [pos_teacher.unsqueeze(-1), neg_teacher.unsqueeze(-1)], dim=1 ) teacher_scores = torch.softmax(teacher_scores, dim=1) return self.loss(local_scores, teacher_scores).sum(dim=1).mean(dim=0)
[docs]class DistillationPairwiseTrainer(LossTrainer): """Pairwise trainer for distillation""" sampler: Param[DistillationPairwiseSampler] """The sampler""" lossfn: Param[DistillationPairwiseLoss] """The distillation pairwise batch function""" def initialize( self, random: np.random.RandomState, context: TrainerContext, ): super().initialize(random, context) self.lossfn.initialize(self.ranker) foreach( context.hooks(DistillationPairwiseLoss), lambda loss: loss.initialize(self.ranker), ) self.sampler.initialize(random) self.sampler_iter = self.sampler.pairwise_iter() self.sampler_iter = MultiprocessSerializableIterator( self.sampler.pairwise_batch_iter(self.batch_size) ) def train_batch(self, samples: List[PairwiseDistillationSample]): # Builds records and teacher score matrix teacher_scores = torch.empty(len(samples), 2) records = PairwiseRecords() for ix, sample in enumerate(samples): records.add( PairwiseRecord( sample.query.as_record(), DocumentRecord(sample.documents[0].document), DocumentRecord(sample.documents[1].document), ) ) teacher_scores[ix, 0] = sample.documents[0].score teacher_scores[ix, 1] = sample.documents[1].score # Get the next batch and compute the scores for each query/document scores = self.ranker(records, self.context).reshape(2, len(records)).T if torch.isnan(scores).any() or torch.isinf(scores).any(): self.logger.error( "nan or inf relevance score detected. Aborting (pairwise distillation)." ) sys.exit(1) # Call the losses (distillation, pairwise and pointwise) teacher_scores = teacher_scores.to(scores.device) self.lossfn.process(scores, teacher_scores, self.context)