Source code for xpmir.letor.distillation.samplers

from typing import (
    Iterable,
    Iterator,
    NamedTuple,
    Tuple,
    List,
    Any,
)

import numpy as np
from datamaestro.data import File
from datamaestro_text.data.ir.base import (
    TopicRecord,
    DocumentRecord,
    ScoredItem,
    SimpleTextItem,
    IDItem,
    create_record,
)
from experimaestro import Config, Meta, Param
from xpmir.learning import Sampler
from xpmir.letor.samplers.hydrators import SampleHydrator
from xpmir.rankers import ScoredDocument
from xpmir.utils.iter import (
    SerializableIterator,
    SkippingIterator,
    SerializableIteratorTransform,
)


[docs]class PairwiseDistillationSample(NamedTuple): query: TopicRecord """The query""" documents: Tuple[DocumentRecord, DocumentRecord] """Positive/negative document with teacher scores"""
[docs]class PairwiseDistillationSamples(Config, Iterable[PairwiseDistillationSample]): """Pairwise distillation file""" def __iter__(self) -> Iterator[PairwiseDistillationSample]: raise NotImplementedError()
[docs]class PairwiseHydrator(PairwiseDistillationSamples, SampleHydrator): """Hydrate ID-based samples with document and/or query content""" samples: Param[PairwiseDistillationSamples] """The distillation samples without texts for query and documents""" def transform(self, sample: PairwiseDistillationSample): topic, documents = sample.query, sample.documents if transformed := self.querystore.transforme_topics(): topic = TopicRecord(*transformed) if transformed := self.documentstore.transforme_documents(): documents = tuple( ScoredDocument(d, sd.score) for d, sd in zip(transformed, sample.documents) ) return PairwiseDistillationSample(topic, documents) def __iter__(self) -> Iterator[PairwiseDistillationSample]: iterator = iter(self.samples) return SerializableIteratorTransform( SkippingIterator.make_serializable(iterator), self.transform )
[docs]class PairwiseDistillationSamplesTSV(PairwiseDistillationSamples, File): """A TSV file (Score 1, Score 2, Query, Document 1, Document 2)""" with_docid: Meta[bool] with_queryid: Meta[bool] def __iter__(self) -> Iterator[PairwiseDistillationSample]: return self.iter() def iter(self) -> Iterator[PairwiseDistillationSample]: import csv def iterate(): with self.path.open("rt") as fp: for row in csv.reader(fp, delimiter="\t"): if self.with_queryid: query = create_record(id=row[2]) else: query = create_record(text=row[2]) if self.with_docid: documents = ( DocumentRecord(IDItem(row[3]), ScoredItem(float(row[0]))), DocumentRecord(IDItem(row[4]), ScoredItem(float(row[1]))), ) else: documents = ( DocumentRecord( SimpleTextItem(row[3]), ScoredItem(float(row[0])) ), DocumentRecord( SimpleTextItem(row[4]), ScoredItem(float(row[1])) ), ) yield PairwiseDistillationSample(query, documents) return SkippingIterator(iterate())
[docs]class DistillationPairwiseSampler(Sampler): """Just loops over samples""" samples: Param[PairwiseDistillationSamples] def initialize(self, random: np.random.RandomState): super().initialize(random) def pairwise_iter(self) -> SerializableIterator[PairwiseDistillationSample, Any]: return SkippingIterator.make_serializable(iter(self.samples)) def pairwise_batch_iter( self, size ) -> SerializableIterator[List[PairwiseDistillationSample], Any]: """Batchwise iterator Can be subclassed by some classes to be more efficient""" class BatchIterator(SerializableIterator): def __init__(self, sampler: DistillationPairwiseSampler): self.iter = sampler.pairwise_iter() def state_dict(self): return self.iter.state_dict() def load_state_dict(self, state): self.iter.load_state_dict(state) def __next__(self): batch = [] for _, record in zip(range(size), self.iter): batch.append(record) return batch return BatchIterator(self)