Source code for xpmir.letor.distillation.listwise

import sys
from typing import List
from typing_extensions import ReadOnly, TypedDict
import torch
from torch import nn, Tensor
from experimaestro import Config, Param, field
import torch.nn.functional as F
from xpmir.text import TokenizedTexts
from xpmir.letor.records import (
    PointwiseItem,
    PointwiseItems,
)
from xpm_torch.trainers import TrainerContext, LossTrainer
from xpm_torch.losses import Loss, ModuleOutputType, bce_with_logits_loss

from .samplers import ListwiseDistillationSample
import numpy as np
from xpmir.rankers import AbstractModuleScorer

### Losses


class DistillationListwiseLoss(Config, nn.Module):
    """Abstract base class for listwise distillation losses.

    Subclasses implement :meth:`compute`. :meth:`process` calls it and
    registers the result in the trainer context under the name
    ``listwise-<NAME>`` with the configured ``weight``.
    """

    # Weight passed to ``Loss`` when the value is registered in the context
    weight: Param[float] = field(default=1.0, ignore_default=True)

    # Short identifier used to build the name of the reported loss
    NAME = "?"

    def initialize(self, ranker: AbstractModuleScorer):
        """Hook called with the ranker before training; no-op by default."""

    def process(
        self, student_scores: Tensor, teacher_scores: Tensor, info: TrainerContext
    ):
        """Compute the loss and add it to the trainer context.

        Arguments:
            student_scores: Scores produced by the model being trained
            teacher_scores: Reference scores, same shape as the student scores
            info: Trainer context that receives the loss and is forwarded
                to :meth:`compute`
        """
        value = self.compute(student_scores, teacher_scores, info)
        record = Loss(f"listwise-{self.NAME}", value, self.weight)
        info.add_loss(record)

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """
        Compute the loss (must be overridden by subclasses)

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class DistillRankNetLoss(DistillationListwiseLoss):
    """Adaptation of the pairwise RankNET loss to lists of passages ranked
    by a LLM.

    Follows Rank-DistiLLM: Closing the Effectiveness Gap Between
    Cross-Encoders and LLMs for Passage Re-Ranking, 2025
    """

    NAME = "DistillRankNET"

    def initialize(self, ranker):
        """Set up the pairwise criterion (binary cross-entropy on logits)."""
        super().initialize(ranker)
        self.loss = torch.nn.functional.binary_cross_entropy_with_logits

    @staticmethod
    def get_pairwise_idcs(targets: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Get pairwise indices for positive and negative samples based on
        targets.

        For a (batch x num_docs) input, the result is a tuple
        ``(query_idcs, pos_idcs, neg_idcs)`` with one entry for every pair
        such that ``targets[q, pos] > targets[q, neg]``.

        Function copied from the official implementation of Rank-DistiLLM:
        https://github.com/webis-de/lightning-ir/blob/main/lightning_ir/loss/base.py#L131
        """
        # positive items are items where label is greater than other label in
        # sample
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """
        Compute the DistillRankNet loss

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor

        Returns:
            The mean binary cross-entropy of the student score margins over
            all teacher-ordered pairs, or a zero (still attached to the
            autograd graph) when the teacher induces no ordered pair.
        """
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(teacher_scores)
        pos = student_scores[query_idcs, pos_idcs]
        neg = student_scores[query_idcs, neg_idcs]
        margin = pos - neg

        if margin.numel() == 0:
            # No preference pair (e.g. all teacher scores tied): the mean
            # reduction over an empty tensor would be NaN. Summing the empty
            # margin gives 0 while keeping the result differentiable.
            return margin.sum()

        # Every selected pair should be ordered pos > neg, hence target 1
        return self.loss(margin, torch.ones_like(margin))
class ADR_MSE(DistillationListwiseLoss):
    """Approximate Discounted Rank MSE: loss to distill from lists of
    passages ranked by a LLM, proposed by

    Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and
    LLMs for Passage Re-Ranking, 2025
    """

    NAME = "ADR-MSE"

    def initialize(self, ranker):
        """Set the element-wise MSE criterion and the fixed hyper-parameters."""
        super().initialize(ranker)
        # Unreduced so that a per-document discount can be applied afterwards
        self.loss = nn.MSELoss(reduction="none")
        self.discount = "log2"
        self.temperature = 1

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        """Compute approximate (differentiable) ranks from scores.

        The rank of a document is 1 plus the sum, over all *other* documents
        of the same list, of ``sigmoid((s_other - s_doc) / temperature)``.

        Function copied from the official implementation of Rank-DistiLLM:
        https://github.com/webis-de/lightning-ir/blob/main/lightning_ir/loss/approximate.py#L34
        """
        # pairwise_gap[b, i, j] = scores[b, j] - scores[b, i]
        pairwise_gap = scores[:, None] - scores[..., None]
        soft_indicator = torch.sigmoid(pairwise_gap / temperature)

        # A document is not compared with itself: zero out the diagonal
        off_diagonal = 1 - torch.eye(scores.shape[1], device=scores.device)
        soft_indicator = soft_indicator * off_diagonal

        return soft_indicator.sum(-1) + 1

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """
        Compute the ADR-MSE loss

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor
        """
        student_ranks = self.get_approx_ranks(student_scores, self.temperature)

        # Double argsort turns teacher scores into 1-based integer ranks
        # (highest score -> rank 1); they are then cast to the student's
        # dtype/device so that both rank tensors are comparable.
        descending_order = torch.argsort(teacher_scores, descending=True)
        teacher_ranks = (torch.argsort(descending_order) + 1).to(
            dtype=student_ranks.dtype, device=student_ranks.device
        )

        squared_error = self.loss(student_ranks, teacher_ranks)

        # Errors on documents ranked high by the teacher weigh more
        weight = 1 / torch.log2(teacher_ranks + 1) if self.discount == "log2" else 1

        return (squared_error * weight).mean()
class ListwiseSoftmaxCrossEntropy(DistillationListwiseLoss):
    """Reproduces the original `SoftmaxCrossEntropy` behavior used in
    batchwise losses, adapted to listwise distillation.

    The original formula is:
    `-logsumexp(normalize(scores) + (1 - 1.0 / relevances), dim=-1).mean()`
    where `normalize` depends on the model output type.
    """

    NAME = "infonce"

    def initialize(self, ranker: AbstractModuleScorer):
        """Select how student scores are mapped to log-probabilities.

        The choice depends on ``ranker.outputType``; an output type that is
        not listed below raises ``KeyError``.
        """
        super().initialize(ranker)

        def from_real(x):
            return F.log_softmax(x, -1)

        def from_log_probability(x):
            return x

        def from_probability(x):
            return x.log()

        normalizers = {
            ModuleOutputType.REAL: from_real,
            ModuleOutputType.LOG_PROBABILITY: from_log_probability,
            ModuleOutputType.PROBABILITY: from_probability,
        }
        self.normalize = normalizers[ranker.outputType]

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """Compute the softmax cross-entropy loss.

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor, used as the
                "relevances" of the original formula
        """
        # Zero relevances are replaced by a tiny value so that 1 / relevance
        # does not divide by zero.
        eps = 1e-8
        zero_substitute = torch.tensor(
            eps, device=teacher_scores.device, dtype=teacher_scores.dtype
        )
        relevances = torch.where(teacher_scores == 0, zero_substitute, teacher_scores)

        shifted = self.normalize(student_scores) + (1.0 - 1.0 / relevances)

        # logsumexp aggregates over documents; the division by the number of
        # rows averages over queries
        per_query = torch.logsumexp(shifted, dim=-1)
        return -per_query.sum() / student_scores.shape[0]
class ListwiseInfoNCE(DistillationListwiseLoss):
    """Standard InfoNCE loss for listwise supervised training.

    This loss expects binary relevance labels (1 for positive, 0 for
    negative). If multiple positives are present, it averages the
    cross-entropy loss over them.

    TODO: check this loss against a reference implementation
    """

    NAME = "listwise-infonce"

    def initialize(self, ranker: AbstractModuleScorer):
        """Select how student scores are mapped to log-probabilities.

        The choice depends on ``ranker.outputType``; an output type that is
        not listed below raises ``KeyError``.
        """
        super().initialize(ranker)
        self.normalize = {
            ModuleOutputType.REAL: lambda x: F.log_softmax(x, -1),
            ModuleOutputType.LOG_PROBABILITY: lambda x: x,
            ModuleOutputType.PROBABILITY: lambda x: x.log(),
        }[ranker.outputType]

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """Compute the InfoNCE loss.

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor; entries > 0 mark
                positive documents

        Returns:
            The negative mean log-probability of the positive documents,
            averaged over the queries that have at least one positive
            (0 when no query has a positive).
        """
        log_probs = self.normalize(student_scores)

        # Boolean mask of positives and its float counterpart
        positive_mask = teacher_scores > 0
        is_positive = positive_mask.float()

        # Number of positives per query
        num_positives = is_positive.sum(dim=-1, keepdim=True)

        # Target distribution: uniform over positives
        targets = is_positive / torch.clamp(num_positives, min=1.0)

        # Non-positive entries have a zero target, but ``0 * -inf`` is NaN:
        # a -inf log-probability on a negative document (e.g. a probability
        # output equal to 0) would poison the whole loss. Zero those entries
        # out before the product; finite inputs give the same result.
        positive_log_probs = log_probs.masked_fill(~positive_mask, 0.0)

        # Cross entropy: -sum(targets * log_probs), i.e. the average
        # log-probability of the positive documents of each query
        loss = -(targets * positive_log_probs).sum(dim=-1)

        # Queries without any positive do not contribute to the mean
        mask = (num_positives > 0).float().squeeze(-1)
        loss = (loss * mask).sum() / torch.clamp(mask.sum(), min=1.0)
        return loss
class ListwiseBCE(DistillationListwiseLoss):
    """Point-wise cross-entropy loss for listwise samples.

    Computes BCE for each document in the list, adapting to the student's
    output type.
    """

    NAME = "bce"

    def initialize(self, ranker: AbstractModuleScorer):
        """Pick the BCE variant matching ``ranker.outputType``.

        Raises:
            NotImplementedError: if the output type has no matching variant
        """
        super().initialize(ranker)
        self.output_type = ranker.outputType

        # (output type, factory building the criterion) pairs; the
        # log-probability case relies on the ``bce_with_logits_loss`` helper
        # imported from ``xpm_torch.losses``
        candidates = (
            (ModuleOutputType.REAL, nn.BCEWithLogitsLoss),
            (ModuleOutputType.PROBABILITY, nn.BCELoss),
            (ModuleOutputType.LOG_PROBABILITY, lambda: bce_with_logits_loss),
        )
        for output_type, make_loss in candidates:
            if self.output_type == output_type:
                self.loss_fn = make_loss()
                break
        else:
            raise NotImplementedError(f"Output type {self.output_type} not supported.")

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """Compute the binary cross-entropy over all documents.

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor of binary labels
        """
        # Targets must share the dtype of the student scores
        targets = teacher_scores.to(student_scores.dtype)

        # The log-probability criterion is fed 1-D tensors (presumably a
        # requirement of the helper -- confirm against xpm_torch); the
        # nn.Module criteria take the 2-D tensors as they are.
        flatten = self.output_type == ModuleOutputType.LOG_PROBABILITY
        predictions = student_scores.flatten() if flatten else student_scores
        labels = targets.flatten() if flatten else targets

        return self.loss_fn(predictions, labels)
class ListwiseHingeLoss(DistillationListwiseLoss):
    """Pairwise Hinge loss for listwise samples.

    Computes max(0, margin - (s_pos - s_neg)) for each (positive, negative)
    pair of a query. This implementation requires the same number of
    positives for every query of the batch (typically 1), as provided by the
    DistillationNegativesSampler.
    """

    NAME = "hinge"

    # Minimum score gap required between a positive and a negative
    margin: Param[float] = field(default=1.0, ignore_default=True)

    def compute(
        self, student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext
    ) -> torch.Tensor:
        """Compute the hinge loss.

        Arguments:
            student_scores: A (batch x num_docs) tensor
            teacher_scores: A (batch x num_docs) tensor; entries > 0 mark
                positive documents

        Raises:
            ValueError: if the queries of the batch do not all have the same
                number of positives
        """
        # teacher_scores are binary (1 for positive, 0 for negative)
        is_positive = teacher_scores > 0
        batch_size = student_scores.shape[0]

        # The reshapes below regroup the selected scores per query, which is
        # only valid when every row selects the same number of entries.
        # Otherwise scores of different queries would be silently mixed
        # whenever the total count happens to be divisible by the batch size.
        positives_per_query = is_positive.sum(dim=-1)
        if batch_size > 0 and not bool(
            (positives_per_query == positives_per_query[0]).all()
        ):
            raise ValueError(
                "ListwiseHingeLoss requires the same number of positives "
                "for every query of the batch, got "
                f"{positives_per_query.tolist()}"
            )

        pos_scores = student_scores[is_positive].reshape(batch_size, -1)  # (B, P)
        neg_scores = student_scores[~is_positive].reshape(batch_size, -1)  # (B, N)

        # (B, P, 1) - (B, 1, N) -> (B, P, N): all positive/negative pairs
        loss = F.relu(self.margin - pos_scores.unsqueeze(2) + neg_scores.unsqueeze(1))
        return loss.mean()
### Trainer


class DistillationListwiseInputs(TypedDict):
    """Batch handed to the listwise distillation trainer."""

    # One (query, document, score) item per document, query after query
    records: ReadOnly[PointwiseItems]
    # Tokenized version of ``records``; None when no tokenizer was applied
    tokenized_records: ReadOnly[TokenizedTexts]
    # One row of teacher scores per query
    teacher_scores: ReadOnly[Tensor]


def distillation_listwise_collate(
    samples: List[ListwiseDistillationSample],
) -> DistillationListwiseInputs:
    """Collate function for Distillation Listwise trainer

    Arguments:
        samples: The samples of the batch; the width of the score matrix is
            taken from the first sample's document list

    Returns:
        The flattened records, the (num_samples x num_docs) teacher score
        matrix, and ``tokenized_records`` set to None
    """
    num_queries = len(samples)
    list_size = len(samples[0].documents)

    teacher_scores = torch.empty(num_queries, list_size)
    records = PointwiseItems()

    for row, sample in enumerate(samples):
        row_scores = []
        for doc in sample.documents:
            records.add(PointwiseItem(sample.query, doc.document, doc.score))
            row_scores.append(doc.score)
        teacher_scores[row] = torch.tensor(row_scores)

    return DistillationListwiseInputs(
        records=records, tokenized_records=None, teacher_scores=teacher_scores
    )
class DistillationListwiseTrainer(LossTrainer):
    """Listwise trainer for distillation"""

    lossfn: Param[DistillationListwiseLoss]
    """The listwise distillation loss function"""

    def initialize(
        self,
        random: np.random.RandomState,
        context: TrainerContext,
    ):
        """Initialize the losses and the sampler, then build the data loader.

        Arguments:
            random: Random state forwarded to the parent class and the sampler
            context: Trainer context; its ``DistillationListwiseLoss`` hooks
                are initialized with the model as well
        """
        super().initialize(random, context)

        self.lossfn.initialize(self.model)
        for hooked_loss in context.hooks(DistillationListwiseLoss):
            hooked_loss.initialize(self.model)

        self.sampler.initialize(random)
        dataset = self.sampler.as_dataset()

        # Plain collate by default; when the model exposes a tokenizer
        # function, tokenization is performed as part of the collate step.
        collate_fn = distillation_listwise_collate
        if hasattr(self.model, "get_tokenizer_fn"):
            tokenization_fn = self.model.get_tokenizer_fn()

            def collate_fn_with_tokenization(
                samples: List[ListwiseDistillationSample],
            ) -> DistillationListwiseInputs:
                batch = distillation_listwise_collate(samples)
                batch["tokenized_records"] = tokenization_fn(batch["records"])
                return batch

            collate_fn = collate_fn_with_tokenization

        self._create_dataloader(dataset, collate_fn=collate_fn)

    def train_batch(self, inputs: DistillationListwiseInputs):
        """Score one batch with the model and register the listwise loss.

        The process exits with status 1 when the model outputs a NaN or an
        infinite score.
        """
        records = inputs["records"]
        teacher_scores = inputs["teacher_scores"]
        tokenized_records = inputs.get("tokenized_records", None)

        # Compute the scores for each query/document
        scores = self.model(records, tokenized=tokenized_records)

        if not torch.isfinite(scores).all():
            self.logger.error(
                "nan or inf relevance score detected. Aborting (listwise distillation)."
            )
            sys.exit(1)

        # The student scores are given the layout of the teacher score
        # matrix before the loss is called
        teacher_scores = teacher_scores.to(scores.device)
        student_scores = scores.reshape_as(teacher_scores)
        self.lossfn.process(student_scores, teacher_scores, self.context)