Knowledge distillation

Knowledge distillation trains a student model to mimic the output distribution of a stronger teacher model. This is commonly used to transfer the accuracy of a cross-encoder teacher into a faster bi-encoder or sparse student.

Samplers

Samplers that pair documents with teacher scores for distillation training.

XPM Config xpmir.letor.distillation.samplers.DistillationPairwiseSampler(*, samples)

Bases: Sampler

Simply iterates over the given samples, with no further processing.

samples: datamaestro_ir.data.distillation.PairwiseDistillationSamples
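The behaviour described above ("simply iterates over the given samples") can be sketched in plain Python. This is a minimal illustration, not the xpmir implementation. It assumes that the sampler restarts from the beginning once the samples are exhausted, since step-based trainers usually expect an endless stream. The function name `loop_over_samples` is hypothetical.

```python
from itertools import islice
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def loop_over_samples(samples: Sequence[T]) -> Iterator[T]:
    """Yield the samples in order, restarting when the end is reached.

    Hypothetical stand-in for DistillationPairwiseSampler: no shuffling,
    no filtering, just an endless pass over the stored samples.
    """
    if not samples:
        # Raised on the first next() call, since this is a generator
        raise ValueError("cannot loop over an empty sample collection")
    while True:
        yield from samples


# Take the first five items from a two-sample collection
first_five = list(islice(loop_over_samples(["s1", "s2"]), 5))
print(first_five)  # ['s1', 's2', 's1', 's2', 's1']
```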
XPM Task xpmir.letor.samplers.TeacherModelBasedHardNegativesTripletSampler(*, sampler, document_store, topic_store, teacher_model)

Bases: Task, Sampler

Submit type: xpm_torch.base.Sampler[xpmir.letor.records.PairwiseItem]

Builds a teacher file, that is, (query, positive, negative) triplets scored by the teacher model, for use with pairwise distillation losses.

sampler: xpm_torch.base.Sampler[xpmir.letor.records.PairwiseItem]

The existing hard negatives to sample from

document_store: datamaestro_ir.data.DocumentStore

The document store

topic_store: xpmir.datasets.adapters.TextStore

The query store

teacher_model: xpmir.rankers.scorer.Scorer

The teacher model that scores the positive and negative documents

hard_negative_triplet: Path (generated)

The path to store the generated triplets
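A plain-Python sketch of what building such a teacher file involves: each (query, positive, negative) triplet is looked up in the query and document stores, both documents are scored by the teacher, and the scored rows are serialized. Everything here is illustrative. Dictionaries stand in for `topic_store` and `document_store`, a callable stands in for `teacher_model`, the toy word-overlap teacher is made up, and the tab-separated line format is a hypothetical choice rather than the file format xpmir actually writes.

```python
from typing import Callable, Dict, Iterable, List, Tuple

# (query_id, positive_doc_id, negative_doc_id), as drawn from the hard-negative sampler
Triplet = Tuple[str, str, str]
# (positive_score, negative_score, query_id, positive_doc_id, negative_doc_id)
ScoredTriplet = Tuple[float, float, str, str, str]


def score_triplets(
    triplets: Iterable[Triplet],
    topics: Dict[str, str],
    documents: Dict[str, str],
    teacher: Callable[[str, str], float],
) -> List[ScoredTriplet]:
    """Attach teacher scores to each (query, positive, negative) triplet."""
    rows: List[ScoredTriplet] = []
    for qid, pos_id, neg_id in triplets:
        query = topics[qid]  # stand-in for the topic_store lookup
        pos_score = teacher(query, documents[pos_id])  # stand-in for document_store + teacher_model
        neg_score = teacher(query, documents[neg_id])
        rows.append((pos_score, neg_score, qid, pos_id, neg_id))
    return rows


def to_tsv_lines(rows: Iterable[ScoredTriplet]) -> List[str]:
    """Serialize scored triplets as tab-separated lines (hypothetical file format)."""
    return ["\t".join(str(field) for field in row) for row in rows]


def overlap_teacher(query: str, doc: str) -> float:
    """Toy teacher: number of distinct query words that also appear in the document."""
    return float(len(set(query.split()) & set(doc.split())))


topics = {"q1": "distill ranking models"}
documents = {"d1": "distill dense ranking models", "d2": "cooking pasta at home"}
rows = score_triplets([("q1", "d1", "d2")], topics, documents, overlap_teacher)
print(to_tsv_lines(rows))  # ['3.0\t0.0\tq1\td1\td2']
```

Scoring the teacher once, offline, is what makes this a task rather than a plain sampler: the expensive teacher (for example a cross-encoder) is not run again during student training.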

Trainer

XPM Config xpmir.letor.distillation.pairwise.DistillationPairwiseTrainer(*, hooks, model, sampler, batch_size, num_workers, lossfn)

Bases: LossTrainer

Pairwise trainer for distillation

hooks: List[xpm_torch.trainers.context.TrainingHook] = []

Hooks for this trainer. They include the losses, but can be adapted for other uses; the exact list of hooks depends on the specific trainer.

model: xpm_torch.module.Module

If the model to optimize differs from the model passed to Learn, it can be set here; initialization is still expected to be done at the learner level.

batcher: xpm_torch.batchers.Batcher (generated)

How to batch samples together

sampler: xpm_torch.base.Sampler

The sampler to use

batch_size: int = 16

Number of samples per batch

num_workers: int = 2

Number of DataLoader workers

lossfn: xpmir.letor.distillation.pairwise.DistillationPairwiseLoss

The pairwise distillation loss function
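A rough, framework-free sketch of one training step of a pairwise distillation trainer: samples are grouped into batches of `batch_size`, the student scores both documents of each pair, and the loss compares the resulting (batch x 2) student scores with the teacher scores. The following are assumptions, not the xpmir code. Each sample is taken to be a query with two (document text, teacher score) tuples. The student is any callable scoring a (query, document) pair. The loss `weight` is assumed to simply multiply the loss value. The DataLoader workers and the generated `batcher` are left out, and the names `batches` and `training_loss` are hypothetical.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple

# Hypothetical sample layout: (query, [(doc_text, teacher_score), (doc_text, teacher_score)])
Sample = Tuple[str, Sequence[Tuple[str, float]]]
PairLoss = Callable[[List[List[float]], List[List[float]]], float]


def batches(samples: Iterable[Sample], batch_size: int = 16) -> Iterator[List[Sample]]:
    """Group a (possibly endless) stream of samples into lists of batch_size."""
    it = iter(samples)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def training_loss(
    batch: List[Sample],
    student: Callable[[str, str], float],
    lossfn: PairLoss,
    weight: float = 1.0,
) -> float:
    """Build (batch x 2) student/teacher score matrices and apply the weighted loss."""
    student_scores = [[student(query, doc) for doc, _ in pair] for query, pair in batch]
    teacher_scores = [[score for _, score in pair] for _, pair in batch]
    return weight * lossfn(student_scores, teacher_scores)


def word_count_student(query: str, doc: str) -> float:
    """Toy student: scores a document by its number of words."""
    return float(len(doc.split()))


def abs_error(student: List[List[float]], teacher: List[List[float]]) -> float:
    """Toy loss: summed absolute difference between student and teacher scores."""
    return sum(abs(s - t) for rs, rt in zip(student, teacher) for s, t in zip(rs, rt))


samples = [("q", [("a b", 2.0), ("c", 0.0)])] * 3
first_batch = next(batches(samples, batch_size=2))
print(training_loss(first_batch, word_count_student, abs_error, weight=0.5))  # 1.0
```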

Pairwise losses

Losses operating on pairs of documents with teacher scores.

XPM Config xpmir.letor.distillation.pairwise.DistillationPairwiseLoss(*, weight)

Bases: Config, Module

The abstract loss for pairwise distillation

weight: float = 1.0
compute(student_scores: Tensor, teacher_scores: Tensor, context: TrainerContext) → Tensor

Compute the loss

Parameters:
  • student_scores – A (batch x 2) tensor of student scores, one column per document of the pair

  • teacher_scores – A (batch x 2) tensor of teacher scores for the same pairs

XPM Config xpmir.letor.distillation.pairwise.MSEDifferenceLoss(*, weight)

Bases: DistillationPairwiseLoss

Computes the MSE between the score differences

For each pair, computes ((student_1 - student_2) - (teacher_1 - teacher_2))**2, where student_i and teacher_i are the student and teacher scores of document i.

weight: float = 1.0
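A framework-free sketch of the MSE-difference formula above, written over nested lists instead of PyTorch tensors. Averaging over the batch is an assumption about the reduction, and `mse_difference_loss` is a hypothetical name, not the xpmir method.

```python
from typing import Sequence


def mse_difference_loss(
    student_scores: Sequence[Sequence[float]],
    teacher_scores: Sequence[Sequence[float]],
) -> float:
    """Mean over the batch of ((s1 - s2) - (t1 - t2)) ** 2.

    Each row holds the scores of the two documents of one pair.
    """
    if len(student_scores) != len(teacher_scores) or not student_scores:
        raise ValueError("expected two non-empty batches of the same size")
    total = 0.0
    for (s1, s2), (t1, t2) in zip(student_scores, teacher_scores):
        total += ((s1 - s2) - (t1 - t2)) ** 2
    return total / len(student_scores)


print(mse_difference_loss([[2.0, 1.0], [0.5, 0.5]], [[3.0, 1.0], [0.0, 0.0]]))  # 0.5
```

Because only score differences enter the loss, the student is free to use a different score scale offset than the teacher: shifting both student scores by the same constant leaves the loss unchanged.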
XPM Config xpmir.letor.distillation.pairwise.DistillationKLLoss(*, weight)

Bases: DistillationPairwiseLoss

Distillation loss from "Distilling Dense Representations for Ranking using Tightly-Coupled Teachers", https://arxiv.org/abs/2010.11386

weight: float = 1.0
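A framework-free sketch of a KL-based pairwise distillation loss, assuming it takes a softmax over the two scores of each pair for both teacher and student and computes KL(teacher || student), averaged over the batch, with no temperature. These choices are assumptions, not a description of the xpmir code; in PyTorch, the same quantity is commonly obtained with `torch.nn.KLDivLoss` applied to the student's log-softmax and the teacher's softmax. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of scores."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


def pairwise_kl_loss(
    student_scores: Sequence[Sequence[float]],
    teacher_scores: Sequence[Sequence[float]],
) -> float:
    """Batch mean of KL(teacher || student) between softmax distributions over each pair."""
    if not student_scores:
        raise ValueError("empty batch")
    total = 0.0
    for s_row, t_row in zip(student_scores, teacher_scores):
        log_p_s = log_softmax(s_row)
        log_p_t = log_softmax(t_row)
        # KL(p_t || p_s) = sum_i p_t[i] * (log p_t[i] - log p_s[i])
        total += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    return total / len(student_scores)


print(pairwise_kl_loss([[1.0, 0.0]], [[1.0, 0.0]]))  # 0.0
```

Unlike the MSE-difference loss, this only matches the teacher's relative preference between the two documents after normalization, so it is also invariant to adding a constant to both student scores.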

Listwise losses

XPM Config xpmir.letor.distillation.listwise.ListwiseInfoNCE(*, weight)

Bases: DistillationListwiseLoss

Standard InfoNCE loss for listwise supervised training.

This loss expects binary relevance labels (1 for positive, 0 for negative). If multiple positives are present, the cross-entropy loss is averaged over them.

weight: float = 1.0
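A framework-free sketch of listwise InfoNCE with binary labels, following the behaviour described above. For each query, a log-softmax is taken over all candidate scores, and the negative log-probabilities of the positives are averaged. The batch-level mean and the name `info_nce_loss` are assumptions, not the xpmir implementation.

```python
import math
from typing import Sequence


def info_nce_loss(
    scores: Sequence[Sequence[float]],
    labels: Sequence[Sequence[int]],
) -> float:
    """Listwise InfoNCE with binary labels.

    For each query, the log-softmax is taken over all candidate scores, and the
    negative log-probability is averaged over the positives (label 1). The
    result is then averaged over the queries of the batch.
    """
    if not scores:
        raise ValueError("empty batch")
    total = 0.0
    for row, row_labels in zip(scores, labels):
        m = max(row)
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        positives = [s - log_z for s, y in zip(row, row_labels) if y == 1]
        if not positives:
            raise ValueError("each query needs at least one positive")
        total += -sum(positives) / len(positives)
    return total / len(scores)


print(round(info_nce_loss([[0.0, 0.0]], [[1, 0]]), 4))  # 0.6931, i.e. log 2
```

With two positives among three equally scored candidates, each positive has probability 1/3, so the averaged loss is log 3; an alternative formulation, -log of the summed positive probabilities, would give log 1.5 instead.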

Data types

class xpmir.letor.distillation.samplers.PairwiseDistillationSample(query: QueryT, documents: Tuple[DocT, DocT])

Bases: Generic[DocT, QueryT]

documents: Tuple[DocT, DocT]

The positive and negative documents, with their teacher scores

query: QueryT

The query
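A sketch of a data type with the same shape as PairwiseDistillationSample, written as a generic frozen dataclass. The class names `PairwiseDistillationSampleSketch` and `ScoredDocument` are hypothetical; in particular, `ScoredDocument` is only an assumed example of a document type that carries a teacher score.

```python
from dataclasses import dataclass
from typing import Generic, Tuple, TypeVar

DocT = TypeVar("DocT")
QueryT = TypeVar("QueryT")


@dataclass(frozen=True)
class ScoredDocument:
    """Hypothetical document record carrying the teacher's score."""
    doc_id: str
    score: float


@dataclass(frozen=True)
class PairwiseDistillationSampleSketch(Generic[DocT, QueryT]):
    """Mirrors the documented fields: a query and a (positive, negative) document pair."""
    query: QueryT
    documents: Tuple[DocT, DocT]


sample: PairwiseDistillationSampleSketch[ScoredDocument, str] = PairwiseDistillationSampleSketch(
    query="what is knowledge distillation",
    documents=(ScoredDocument("d1", 7.5), ScoredDocument("d2", 1.25)),
)
positive, negative = sample.documents
print(positive.score - negative.score)  # 6.25
```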