Source code for xpmir.letor.samplers.hydrators

from typing import Iterator, Optional, List, Any
from experimaestro import Config, Param

import datamaestro_text.data.ir.base as ir
from datamaestro_text.data.ir import DocumentStore
from xpmir.datasets.adapters import TextStore
from xpmir.letor.samplers import PairwiseSampler
from xpmir.letor.records import (
    PairwiseRecords,
    PairwiseRecord,
    TopicRecord,
    DocumentRecord,
)
from xpmir.utils.iter import (
    SerializableIterator,
    SkippingIterator,
    SerializableIteratorTransform,
)


class SampleTransform(Config):
    """Base configuration type for sample transformations.

    This class declares no parameters or methods of its own.
    ``PairwiseTransformAdapter`` calls ``transform_topics`` and
    ``transform_documents`` on its adapter, so concrete subclasses are
    expected to provide both.
    """


class SampleHydrator(SampleTransform):
    """Base class for document/topic hydrators

    Each ``transform_*`` method returns ``None`` when the corresponding store
    is not configured, which lets callers tell "nothing to transform" apart
    from a transformed result.
    """

    documentstore: Param[Optional[DocumentStore]]
    """The store for document texts if needed"""

    querystore: Param[Optional[TextStore]]
    """The store for query texts if needed"""

    def transform_topics(self, topics: List[ir.Topic]):
        """Rebuild topics from the query store.

        :param topics: the topics to hydrate (only their IDs are used)
        :returns: ``None`` if no query store is set, otherwise a lazy
            generator yielding one ``ir.GenericTopic`` per input topic, built
            from the topic ID and the value stored under that ID
        """
        store = self.querystore
        if store is None:
            return None

        # Kept lazy on purpose: store lookups happen only when the result is
        # consumed
        topic_ids = (topic.get_id() for topic in topics)
        return (ir.GenericTopic(topic_id, store[topic_id]) for topic_id in topic_ids)

    def transform_documents(self, documents: List[ir.Document]):
        """Fetch documents from the document store.

        :param documents: the documents to hydrate (only their ``id`` is used)
        :returns: ``None`` if no document store is set, otherwise the result
            of ``documentstore.documents_ext`` called with the list of IDs
        """
        store = self.documentstore
        if store is None:
            return None

        doc_ids = [document.id for document in documents]
        return store.documents_ext(doc_ids)


class PairwiseTransformAdapter(PairwiseSampler):
    """Transforms pairwise samples using an adapter

    It is interesting to use this adapter since the transformation is only
    performed if the samples are used. Without it, when using a
    SkippingIterator or when recovering a checkpoint, all the records might
    have to be processed.
    """

    sampler: Param[PairwiseSampler]
    """The distillation samples without texts for query and documents"""

    adapter: Param[SampleTransform]
    """The transformation"""

    def transform_record(self, record: PairwiseRecord) -> PairwiseRecord:
        """Transform a single pairwise record.

        The adapter may return ``None`` for topics and/or documents (e.g.
        ``SampleHydrator`` does so when the matching store is not set). In
        that case the corresponding part of ``record`` is kept as is, which
        mirrors what :meth:`transform_records` does for batches.

        :param record: the record to transform
        :returns: a new record built from the transformed (or original) parts
        """
        topics = self.adapter.transform_topics([record.query.topic])
        if topics is None:
            query = record.query
        else:
            # Exactly one topic was submitted, so exactly one is expected back
            (topic,) = topics
            query = TopicRecord(topic)

        documents = self.adapter.transform_documents(
            [record.positive.document, record.negative.document]
        )
        if documents is None:
            positive, negative = record.positive, record.negative
        else:
            # Results are expected in submission order: positive, negative
            pos, neg = documents
            positive, negative = DocumentRecord(pos), DocumentRecord(neg)

        return PairwiseRecord(query, positive, negative)

    def pairwise_iter(self) -> Iterator[PairwiseRecord]:
        """Iterate over the wrapped sampler's records, transforming each one.

        :returns: a ``SerializableIteratorTransform`` wrapping the sampler's
            iterator (made serializable through
            ``SkippingIterator.make_serializable``) with
            :meth:`transform_record` as its transformation
        """
        iterator = self.sampler.pairwise_iter()
        return SerializableIteratorTransform(
            SkippingIterator.make_serializable(iterator), self.transform_record
        )

    def transform_records(self, records: PairwiseRecords) -> PairwiseRecords:
        """Transform a batch of pairwise records in place.

        Unique topics (resp. documents) are replaced only when the adapter
        returns a truthy result; otherwise they are left untouched.

        :param records: the batch to transform
        :returns: the same ``records`` object, after the in-place update
        """
        if topics := self.adapter.transform_topics(
            [tr.topic for tr in records.unique_topics]
        ):
            records.set_unique_topics([TopicRecord(topic) for topic in topics])

        if documents := self.adapter.transform_documents(
            [dr.document for dr in records.unique_documents]
        ):
            records.set_unique_documents(
                [DocumentRecord(document) for document in documents]
            )

        return records

    def pairwise_batch_iter(self, size) -> SerializableIterator[PairwiseRecords, Any]:
        """Iterate over the wrapped sampler's batches, transforming each one.

        :param size: batch size, forwarded to the wrapped sampler
        :returns: a ``SerializableIteratorTransform`` wrapping the sampler's
            batch iterator (made serializable through
            ``SkippingIterator.make_serializable``) with
            :meth:`transform_records` as its transformation
        """
        iterator = self.sampler.pairwise_batch_iter(size)
        return SerializableIteratorTransform(
            SkippingIterator.make_serializable(iterator), self.transform_records
        )