Source code for xpmir.letor.samplers.hydrators

from abc import ABC, abstractmethod
from typing import Iterator, Optional, List, Any
from experimaestro import Config, Param
import numpy as np
import datamaestro_text.data.ir.base as ir
from datamaestro_text.data.ir import DocumentStore, IDItem
from xpmir.datasets.adapters import TextStore
from xpmir.letor.samplers import PairwiseSampler
from xpmir.letor.records import (
    PairwiseRecords,
    PairwiseRecord,
)
from xpmir.utils.iter import (
    SerializableIterator,
    SkippingIterator,
    SerializableIteratorTransform,
)


class SampleTransform(Config, ABC):
    """Interface for components that rewrite topic and document records

    Each hook receives a collection of records and returns either the
    transformed records or ``None``. The callers in this module keep the
    records they passed in whenever the result is ``None`` (or otherwise
    falsy).
    """

    @abstractmethod
    def transform_topics(
        self, topics: Iterator[ir.TopicRecord]
    ) -> Optional[List[ir.TopicRecord]]:
        """Transform topic records

        :param topics: the topic records to process
        :return: the transformed topics, or ``None`` when nothing was changed
        """

    @abstractmethod
    def transform_documents(
        self, documents: Iterator[ir.DocumentRecord]
    ) -> Optional[List[ir.DocumentRecord]]:
        """Transform document records

        :param documents: the document records to process
        :return: the transformed documents, or ``None`` when nothing was
            changed
        """
class SampleHydrator(SampleTransform):
    """Base class for hydrators of topics and documents

    Records are rebuilt from their identifier using the configured stores; a
    store left to ``None`` disables the corresponding transformation.
    """

    documentstore: Param[Optional[DocumentStore]]
    """The store for document texts if needed"""

    querystore: Param[Optional[TextStore]]
    """The store for query texts if needed"""

    def transform_topics(self, topics: List[ir.TopicRecord]):
        """Rebuild each topic with the text found in the query store

        :param topics: topics carrying an ``IDItem``
        :return: one record per topic (same order) created from the topic ID
            and the text stored under that ID, or ``None`` when no query
            store is configured
        """
        store = self.querystore
        if store is None:
            return None

        hydrated = []
        for topic in topics:
            topic_id = topic[IDItem].id
            hydrated.append(ir.create_record(id=topic_id, text=store[topic_id]))
        return hydrated

    def transform_documents(self, documents: List[ir.DocumentRecord]):
        """Fetch the documents from the document store by identifier

        :param documents: documents carrying an ``IDItem``
        :return: the result of ``documents_ext`` for the collected IDs, or
            ``None`` when no document store is configured
        """
        store = self.documentstore
        if store is None:
            return None

        doc_ids = [doc[IDItem].id for doc in documents]
        return store.documents_ext(doc_ids)
class SamplePrefixAdding(SampleTransform):
    """Prepend a fixed prefix to the text of queries and documents

    The record class is decided by looking at the first element of the batch
    only; batches whose first element is of an unsupported type are left
    untouched (``None`` is returned).
    """

    query_prefix: Param[str] = ""
    """The prefix for the query"""

    document_prefix: Param[str] = ""
    """The prefix for the document"""

    def transform_topics(
        self, topics: List[ir.TopicRecord]
    ) -> Optional[List[ir.TopicRecord]]:
        """Return copies of the topics with the query prefix prepended

        :param topics: the topics to process
        :return: new ``GenericTopic`` / ``TextTopic`` objects, or ``None``
            when the prefix is empty, the list is empty or the topic type is
            not handled
        """
        prefix = self.query_prefix
        if prefix == "" or len(topics) == 0:
            return None

        first = topics[0]
        if isinstance(first, ir.GenericTopic):
            prefixed = []
            for topic in topics:
                prefixed.append(ir.GenericTopic(topic[IDItem].id, prefix + topic.text))
            return prefixed

        if isinstance(first, ir.TextTopic):
            prefixed = []
            for topic in topics:
                prefixed.append(ir.TextTopic(prefix + topic.text))
            return prefixed

        # Unsupported topic type: signal "no transformation"
        return None

    def transform_documents(
        self, documents: List[ir.DocumentRecord]
    ) -> Optional[List[ir.DocumentRecord]]:
        """Return copies of the documents with the document prefix prepended

        :param documents: the documents to process
        :return: new ``GenericDocument`` / ``TextDocument`` objects, or
            ``None`` when the prefix is empty, the list is empty or the
            document type is not handled
        """
        prefix = self.document_prefix
        if prefix == "" or len(documents) == 0:
            return None

        first = documents[0]
        if isinstance(first, ir.GenericDocument):
            prefixed = []
            for document in documents:
                prefixed.append(
                    ir.GenericDocument(document[IDItem].id, prefix + document.text)
                )
            return prefixed

        if isinstance(first, ir.TextDocument):
            prefixed = []
            for document in documents:
                prefixed.append(ir.TextDocument(prefix + document.text))
            return prefixed

        # Unsupported document type: signal "no transformation"
        return None
class SampleTransformList(SampleTransform):
    """Chains several sample transforms, applied in the order they are listed

    A transform returning ``None`` (or any falsy value) is skipped: the
    records it received are forwarded unchanged to the next transform.
    """

    adapters: Param[List[SampleTransform]]
    """The list of sample transform to be applied"""

    def transform_topics(self, topics: List[ir.TopicRecord]) -> List[ir.TopicRecord]:
        """Run the topics through every adapter in turn

        :param topics: the topics to process
        :return: the output of the last adapter that produced a truthy
            result, or the input itself if none did
        """
        current = topics
        for adapter in self.adapters:
            transformed = adapter.transform_topics(current)
            if transformed:
                current = transformed
        return current

    def transform_documents(
        self, documents: List[ir.DocumentRecord]
    ) -> List[ir.DocumentRecord]:
        """Run the documents through every adapter in turn

        :param documents: the documents to process
        :return: the output of the last adapter that produced a truthy
            result, or the input itself if none did
        """
        current = documents
        for adapter in self.adapters:
            transformed = adapter.transform_documents(current)
            if transformed:
                current = transformed
        return current
class PairwiseTransformAdapter(PairwiseSampler):
    """Transforms pairwise samples using an adapter

    The transformation is attached to the iterators returned by the wrapped
    sampler, so it is only performed on the samples that are actually used.
    This matters when recovering from a checkpoint with a
    ``SkippingIterator``: otherwise, all the skipped records might have to be
    processed as well.
    """

    sampler: Param[PairwiseSampler]
    """The distillation samples without texts for query and documents"""

    adapter: Param[SampleTransform]
    """The transformation"""

    def initialize(self, random: Optional[np.random.RandomState] = None):
        """Initialize this sampler and the wrapped one with the same state

        :param random: the random state forwarded to both samplers
        """
        super().initialize(random)
        self.sampler.initialize(random)

    def transform_record(self, record: PairwiseRecord) -> PairwiseRecord:
        """Apply the adapter to a single pairwise record

        :param record: the record to transform
        :return: a new record built from the transformed query, positive and
            negative documents; parts for which the adapter returns a falsy
            value are kept as they were
        """
        topics = [record.query]
        docs = [record.positive, record.negative]

        topics = self.adapter.transform_topics(topics) or topics
        docs = self.adapter.transform_documents(docs) or docs

        return PairwiseRecord(topics[0], docs[0], docs[1])

    def pairwise_iter(self) -> Iterator[PairwiseRecord]:
        """Iterate over the wrapped sampler's records, transformed one by one"""
        iterator = self.sampler.pairwise_iter()
        return SerializableIteratorTransform(
            SkippingIterator.make_serializable(iterator), self.transform_record
        )

    def transform_records(self, records: PairwiseRecords) -> PairwiseRecords:
        """Apply the adapter to the unique topics and documents of a batch

        The batch is updated in place through ``set_unique_topics`` and
        ``set_unique_documents`` when the adapter returns a truthy result.

        :param records: the batch to transform
        :return: the same ``records`` object
        """
        # The adapters expect real lists (e.g. SamplePrefixAdding uses len()
        # and indexing, SampleTransformList may hand its input back): a
        # generator would either raise a TypeError or end up stored, already
        # consumed, in the batch. Materialize the inputs first.
        unique_topics = list(records.unique_topics)
        if topics := self.adapter.transform_topics(unique_topics):
            records.set_unique_topics(topics)

        unique_documents = list(records.unique_documents)
        if documents := self.adapter.transform_documents(unique_documents):
            records.set_unique_documents(documents)

        return records

    def pairwise_batch_iter(self, size) -> SerializableIterator[PairwiseRecords, Any]:
        """Iterate over the wrapped sampler's batches, transformed one by one

        :param size: the batch size requested from the wrapped sampler
        """
        iterator = self.sampler.pairwise_batch_iter(size)
        return SerializableIteratorTransform(
            SkippingIterator.make_serializable(iterator), self.transform_records
        )