Source code for xpmir.letor.distillation.pairwise

import sys
from typing import List, Optional, TypedDict

import numpy as np
import torch
from experimaestro import Config, Param, field
from torch import nn
from torch.functional import Tensor
from typing_extensions import ReadOnly

from xpmir.letor.records import (
    PairwiseItem,
    PairwiseItems,
)
from xpm_torch.trainers import TrainerContext, LossTrainer
from xpm_torch.losses import Loss

from .samplers import PairwiseDistillationSample

from xpmir.text import TokenizedTexts
from xpmir.rankers import AbstractModuleScorer


class DistillationPairwiseLoss(Config, nn.Module):
    """Base class of the pairwise distillation losses

    Subclasses implement :meth:`compute`; :meth:`process` evaluates it and
    registers the result in the trainer context under the name
    ``pairwise-<NAME>``.
    """

    # Weight passed to ``Loss`` together with the computed value
    weight: Param[float] = field(default=1.0, ignore_default=True)

    # Short identifier used in the registered loss name
    NAME = "?"

    def initialize(self, ranker: AbstractModuleScorer):
        """Prepare the loss for a given ranker (nothing to do by default)"""
        pass

    def process(
        self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext
    ):
        """Compute the loss and add it to the trainer context ``info``"""
        value = self.compute(student_scores, teacher_scores, info)
        loss_name = f"pairwise-{self.NAME}"
        info.add_loss(Loss(loss_name, value, self.weight))

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """Compute the loss (to be implemented by subclasses)

        Arguments:
            student_scores: A (batch x 2) tensor
            teacher_scores: A (batch x 2) tensor
        """
        raise NotImplementedError()
class MSEDifferenceLoss(DistillationPairwiseLoss):
    """Mean squared error between the student and teacher score differences

    For each pair, the term is
    ((student 2 - student 1) - (teacher 2 - teacher 1))**2, which equals
    ((student 1 - student 2) - (teacher 1 - teacher 2))**2 once squared.
    """

    NAME = "delta-MSE"

    def initialize(self, ranker):
        super().initialize(ranker)
        self.loss = nn.MSELoss()

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext
    ) -> torch.Tensor:
        # Score gap between the second and the first document of each pair
        student_delta = student_scores[:, 1] - student_scores[:, 0]
        teacher_delta = teacher_scores[:, 1] - teacher_scores[:, 0]
        return self.loss(student_delta, teacher_delta)
class DistillationKLLoss(DistillationPairwiseLoss):
    """
    Distillation loss from: Distilling Dense Representations for Ranking using
    Tightly-Coupled Teachers https://arxiv.org/abs/2010.11386

    For each pair, the student and teacher scores are normalized with a
    softmax over the documents of that pair, and the loss is
    KL(teacher || student) summed over the documents and averaged over the
    batch.
    """

    NAME = "Distil-KL"

    def initialize(self, ranker):
        super().initialize(ranker)
        self.loss = nn.KLDivLoss(reduction="none")

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext
    ) -> torch.Tensor:
        """Compute the pairwise KL distillation loss

        Arguments:
            student_scores: A (batch x 2) tensor (column 0 is the positive
                document, column 1 the negative one)
            teacher_scores: A (batch x 2) tensor with the same layout

        Returns:
            A scalar tensor
        """
        # Normalize over the documents of each pair (dim=1), keeping one row
        # per pair. Concatenating along the batch axis would mix pairs
        # together, and a shape mismatch would broadcast to a non-scalar loss.
        student_log_probs = torch.log_softmax(student_scores, dim=1)
        teacher_probs = torch.softmax(teacher_scores, dim=1)

        # KLDivLoss expects log-probabilities as input and probabilities as
        # target; sum over documents, then average over pairs.
        return self.loss(student_log_probs, teacher_probs).sum(dim=1).mean(dim=0)
class DistillationPairwiseInputs(TypedDict):
    """A collated batch for the pairwise distillation trainer"""

    records: ReadOnly[PairwiseItems]
    # Tokenized version of ``records``: ``None`` when produced by
    # ``distillation_pairwise_collate``, and possibly assigned afterwards by a
    # tokenizing collate function (hence not read-only).
    tokenized_records: Optional[TokenizedTexts]
    teacher_scores: ReadOnly[Tensor]


def distillation_pairwise_collate(
    samples: List[PairwiseDistillationSample],
) -> DistillationPairwiseInputs:
    """Collate function for Distillation Pairwise trainer

    Args:
        samples: List of pairwise distillation samples; in each sample, the
            first document is used as the positive and the second one as the
            negative.

    Returns:
        The pairwise records, ``tokenized_records`` set to ``None``, and a
        (len(samples) x 2) tensor of teacher scores (column 0 for the
        positive document, column 1 for the negative one).
    """
    teacher_scores = torch.empty(len(samples), 2)
    records = PairwiseItems()

    for ix, sample in enumerate(samples):
        positive, negative = sample.documents[0], sample.documents[1]
        records.add(PairwiseItem(sample.query, positive.document, negative.document))
        teacher_scores[ix, 0] = positive.score
        teacher_scores[ix, 1] = negative.score

    return DistillationPairwiseInputs(
        records=records, tokenized_records=None, teacher_scores=teacher_scores
    )
class DistillationPairwiseTrainer(LossTrainer):
    """Pairwise trainer for distillation

    Each batch holds (query, positive, negative) triplets together with the
    teacher scores. The student scores are arranged as a (batch x 2) matrix
    and given, with the teacher scores, to ``lossfn``.
    """

    lossfn: Param[DistillationPairwiseLoss]
    """The distillation pairwise batch function"""

    def initialize(
        self,
        random: np.random.RandomState,
        context: TrainerContext,
    ):
        super().initialize(random, context)

        # Both the main loss and the losses registered as hooks get the model
        self.lossfn.initialize(self.model)
        for hook_loss in context.hooks(DistillationPairwiseLoss):
            hook_loss.initialize(self.model)

        self.sampler.initialize(random)
        dataset = self.sampler.as_dataset()

        if not hasattr(self.model, "get_tokenizer_fn"):
            collate_fn = distillation_pairwise_collate
        else:
            # The model exposes its tokenizer: tokenize inside the collate
            # function so the batch arrives already tokenized
            tokenize = self.model.get_tokenizer_fn()

            def collate_fn_with_tokenization(
                samples: List[PairwiseDistillationSample],
            ) -> DistillationPairwiseInputs:
                batch = distillation_pairwise_collate(samples)
                batch["tokenized_records"] = tokenize(batch["records"])
                return batch

            collate_fn = collate_fn_with_tokenization

        self._create_dataloader(dataset, collate_fn=collate_fn)

    def train_batch(self, inputs: DistillationPairwiseInputs):
        records = inputs["records"]
        teacher_scores = inputs["teacher_scores"]
        tokenized_records = inputs["tokenized_records"]

        # Student scores, reshaped to (2, batch) and transposed to (batch, 2)
        student_scores = self.model(records, tokenized=tokenized_records)
        student_scores = student_scores.reshape(2, len(records)).T

        if not torch.isfinite(student_scores).all():
            self.logger.error(
                "nan or inf relevance score detected. Aborting (pairwise distillation)."
            )
            sys.exit(1)

        # Make sure teacher and student scores share the same device
        # (a no-op when they already do)
        teacher_scores = teacher_scores.to(student_scores.device)
        self.lossfn.process(student_scores, teacher_scores, self.context)