# Source code for xpmir.letor.records

import torch
import itertools
from datamaestro_text.data.ir import (
    TopicRecord,
    DocumentRecord,
    TextItem,
    create_record,
)
from typing import (
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)


class PointwiseRecord:
    """One sample produced by a pointwise sampler: a topic, a document and
    an optional relevance value"""

    # The query
    topic: TopicRecord

    # The document
    document: DocumentRecord

    # The relevance
    relevance: Optional[float]

    def __init__(
        self,
        topic: TopicRecord,
        document: DocumentRecord,
        relevance: Optional[float] = None,
    ):
        self.topic, self.document, self.relevance = topic, document, relevance

    @property
    def query(self):
        """Alias returning :attr:`topic`"""
        return self.topic


RT = TypeVar("RT")


class BaseRecords(List[RT]):
    """Base records just exposes iterables on (query, document) pairs

    Records can be structured, i.e. the same queries and documents can be
    used more than once. To allow optimization (e.g. pre-computing
    document/query representation), subclasses can expose each distinct
    topic/document through `unique_topics` / `unique_documents`, and
    `pairs()` gives the indices of the (topic, document) pairs to score.
    """

    topics: Iterable[TopicRecord]
    documents: Iterable[DocumentRecord]
    is_product = False

    @property
    def unique_topics(self) -> Iterable[TopicRecord]:
        """The topics; by default the same iterable as `topics`"""
        return self.topics

    # Alias kept for the older "queries" naming
    unique_queries = unique_topics

    @property
    def unique_documents(self) -> Iterable[DocumentRecord]:
        """The documents; by default the same iterable as `documents`"""
        return self.documents

    @property
    def queries(self):
        """Deprecated: use topics"""
        return self.topics

    def pairs(self) -> Tuple[Iterable[int], Iterable[int]]:
        """Returns two iterators (over queries and documents)

        Returns the list of query/document indices for which we should
        compute the score, or None if all (cartesian product). This method
        should be used with `unique` set to true to get the
        queries/documents
        """
        raise NotImplementedError(f"pairs() in {self.__class__}")

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample"""
        raise NotImplementedError(f"__getitem__() in {self.__class__}")

    def __len__(self):
        """Returns the number of records

        The length is dependent on the type of records, and is mainly used
        to divide the data into batches
        """
        raise NotImplementedError(f"__len__() in {self.__class__}")


class PointwiseRecords(BaseRecords[PointwiseRecord]):
    """Pointwise records are a set of triples (query, document, relevance)"""

    # The queries
    topics: List[TopicRecord]

    # Text of the documents
    documents: List[DocumentRecord]

    # The relevances
    relevances: List[float]

    def __init__(self):
        self.topics = []
        self.documents = []
        self.relevances = []

    def add(self, record: PointwiseRecord):
        """Appends one record; a missing (or zero) relevance is stored as 0"""
        self.topics.append(record.query)
        self.relevances.append(record.relevance or 0)
        self.documents.append(record.document)

    def __len__(self):
        """Number of (topic, document, relevance) triples"""
        return len(self.topics)

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample

        Args:
            ix: an integer index or a slice

        Returns:
            A new `PointwiseRecords` for a slice, a single
            `PointwiseRecord` for an integer index.
        """
        if isinstance(ix, slice):
            records = PointwiseRecords()
            # slice.indices() resolves open bounds (None), negative values
            # and clamps to the length, so `records[:5]` or `records[2:]`
            # work like on a regular list
            for i in range(*ix.indices(len(self.topics))):
                records.add(
                    PointwiseRecord(
                        self.topics[i], self.documents[i], self.relevances[i]
                    )
                )
            return records

        return PointwiseRecord(
            self.topics[ix], self.documents[ix], self.relevances[ix]
        )

    def pairs(self) -> Tuple[List[int], List[int]]:
        """Record i pairs topic i with document i: both index lists are
        0..len-1"""
        ix = list(range(len(self.topics)))
        return (ix, ix)

    @staticmethod
    def from_texts(
        topics: List[str],
        documents: List[str],
        relevances: Optional[List[float]] = None,
    ):
        """Builds records from raw texts, wrapping each text with
        `create_record(text=...)`

        Args:
            topics: the query texts
            documents: the document texts
            relevances: the relevance values, stored as given
        """
        records = PointwiseRecords()
        records.topics = [create_record(text=text) for text in topics]
        records.documents = [create_record(text=text) for text in documents]
        # NOTE(review): stays None when no relevances are given; indexing
        # the records then fails on `self.relevances[...]` - confirm whether
        # a default (e.g. zeros) is expected here
        records.relevances = relevances
        return records


class PairwiseRecord:
    """A pairwise record is composed of a query, a positive and a negative
    document"""

    # The query
    query: TopicRecord

    # The positive document
    positive: DocumentRecord

    # The negative document
    negative: DocumentRecord

    def __init__(
        self,
        query: TopicRecord,
        positive: DocumentRecord,
        negative: DocumentRecord,
    ):
        self.query, self.positive, self.negative = query, positive, negative


class PairwiseRecordWithTarget(PairwiseRecord):
    """A pairwise record is composed of a query, a positive and a negative
    document, and the identifier which says the one on the first is pos or
    neg"""

    target: int

    def __init__(
        self,
        query: TopicRecord,
        positive: DocumentRecord,
        negative: DocumentRecord,
        target: int,
    ):
        super().__init__(query, positive, negative)
        self.target = target


class PairwiseRecords(BaseRecords):
    """Pairwise records of queries associated with (positive, negative) pairs"""

    # The queries (one per pair)
    _topics: List[TopicRecord]

    # The positive documents
    positives: List[DocumentRecord]

    # The negative documents
    negatives: List[DocumentRecord]

    def __init__(self):
        self._topics = []
        self.positives = []
        self.negatives = []

    def add(self, record: PairwiseRecord):
        """Appends one (query, positive, negative) triple"""
        self._topics.append(record.query)
        self.positives.append(record.positive)
        self.negatives.append(record.negative)

    @property
    def topics(self):
        """The topics, repeated twice (once for the positives, once for the
        negatives) so that they align with `documents`"""
        return itertools.chain(self._topics, self._topics)

    def set_unique_topics(self, topics: List[TopicRecord]):
        """Replaces the stored topics by a list of the same length"""
        assert len(topics) == len(
            self._topics
        ), f"Number of topics do not match ({len(topics)} vs {len(self._topics)})"
        self._topics = topics

    def set_unique_documents(self, documents: List[DocumentRecord]):
        """Replaces the documents: the first half of `documents` becomes the
        positives, the second half the negatives"""
        N = len(self._topics)
        assert len(documents) == N * 2
        self.positives = documents[:N]
        self.negatives = documents[N:]

    queries = topics

    @property
    def unique_topics(self):
        """The stored topics, without the duplication done by `topics`"""
        return self._topics

    unique_queries = unique_topics

    @property
    def documents(self):
        """All the positives followed by all the negatives"""
        return itertools.chain(self.positives, self.negatives)

    def pairs(self):
        """Topic indices (0..N-1 twice) and document indices (0..2N-1, in
        the order given by `documents`)"""
        indices = list(range(len(self._topics)))
        return indices * 2, list(range(2 * len(self.positives)))

    def __len__(self):
        """Number of (query, positive, negative) triples"""
        return len(self._topics)

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample

        Returns a new `PairwiseRecords` for a slice, a single
        `PairwiseRecord` for an integer index.
        """
        if isinstance(ix, slice):
            records = PairwiseRecords()
            # slice.indices() resolves open bounds (None), negative values
            # and clamps to the length
            for i in range(*ix.indices(len(self._topics))):
                records.add(
                    PairwiseRecord(
                        self._topics[i], self.positives[i], self.negatives[i]
                    )
                )
            return records

        return PairwiseRecord(
            self._topics[ix], self.positives[ix], self.negatives[ix]
        )


class PairwiseRecordsWithTarget(PairwiseRecords):
    """Pairwise records associated with a label (saying which document is
    better)"""

    # One target per (query, positive, negative) triple
    target: List[int]

    def __init__(self):
        super().__init__()
        self.target = []

    def add(self, record: PairwiseRecordWithTarget):
        """Appends one (query, positive, negative, target) record"""
        super().add(record)
        self.target.append(record.target)

    def get_target(self):
        """Returns the list of targets"""
        return self.target

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample

        Returns a new `PairwiseRecordsWithTarget` for a slice, a single
        `PairwiseRecordWithTarget` for an integer index.
        """
        if isinstance(ix, slice):
            records = PairwiseRecordsWithTarget()
            # slice.indices() resolves open bounds (None), negative values
            # and clamps to the length
            for i in range(*ix.indices(len(self._topics))):
                records.add(
                    PairwiseRecordWithTarget(
                        self._topics[i],
                        self.positives[i],
                        self.negatives[i],
                        self.target[i],
                    )
                )
            return records

        return PairwiseRecordWithTarget(
            self._topics[ix], self.positives[ix], self.negatives[ix], self.target[ix]
        )


class BatchwiseRecords(BaseRecords):
    """Several documents (with associated [pseudo]relevance) per query

    Assumes that the number of documents per query is always the same (even
    though documents themselves can be different)
    """

    relevances: torch.Tensor

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample"""
        raise NotImplementedError(f"__getitem__() in {self.__class__}")


class ProductRecords(BatchwiseRecords):
    """Computes the score for all the documents and queries

    The relevance matrix

    Attributes:
        _topics: The list of queries
        _documents: The list of documents
        relevances: (query x document) matrix with relevance score
            (between 0 and 1)
    """

    _topics: List[TopicRecord]
    """The list of queries to score"""

    _documents: List[DocumentRecord]
    """The list of documents to score"""

    relevances: torch.Tensor
    """A 2D tensor (query x document) indicating the relevance of the each
    query/document pair"""

    is_product = True

    def __init__(self):
        self._topics = []
        self._documents = []

    def add_topics(self, *topics: TopicRecord):
        """Appends the given topics"""
        self._topics.extend(topics)

    def add_documents(self, *documents: DocumentRecord):
        """Appends the given documents"""
        self._documents.extend(documents)

    def set_relevances(self, relevances: torch.Tensor):
        """Sets the relevance matrix

        Args:
            relevances: a tensor whose first dimension matches the number
                of topics and whose second dimension matches the number of
                documents
        """
        # Both halves of each message must be f-strings, otherwise the
        # placeholder is reported verbatim instead of the actual shape
        assert relevances.shape[0] == len(self._topics), (
            f"The number of queries {len(self._topics)} "
            f"does not match the number of rows {relevances.shape[0]}"
        )
        assert relevances.shape[1] == len(self._documents), (
            f"The number of documents {len(self._documents)} "
            f"does not match the number of columns {relevances.shape[1]}"
        )
        self.relevances = relevances

    def __len__(self):
        """Number of topics"""
        return len(self._topics)

    @property
    def topics(self):
        """Each topic repeated once per document (topic-major order)"""
        for q in self._topics:
            for _ in self._documents:
                yield q

    queries = topics

    @property
    def unique_topics(self):
        return self._topics

    unique_queries = unique_topics

    @property
    def documents(self):
        """The whole document list repeated once per topic"""
        for _ in self._topics:
            for d in self._documents:
                yield d

    @property
    def unique_documents(self):
        return self._documents

    def pairs(self) -> Tuple[Iterable[int], Iterable[int]]:
        """Index lists for the full (topic x document) product, in
        topic-major order"""
        n_topics = len(self._topics)
        n_documents = len(self._documents)
        topics = [q for q in range(n_topics) for _ in range(n_documents)]
        documents = [d for _ in range(n_topics) for d in range(n_documents)]
        return topics, documents


class DocumentRecords(List[DocumentRecord]):
    """Masked Language Modeling Records are a set of documents"""

    # Text of the documents
    documents: List[DocumentRecord]

    def __init__(self):
        super().__init__()
        self.documents = []

    def add(self, record: DocumentRecord):
        """Appends one document"""
        self.documents.append(record)

    def __len__(self):
        """Number of documents"""
        return len(self.documents)

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample

        Returns a new `DocumentRecords` for a slice, the stored document
        for an integer index.
        """
        if isinstance(ix, slice):
            # The container type (plural) is needed here: it is the one
            # providing add()
            records = DocumentRecords()
            for i in range(*ix.indices(len(self.documents))):
                records.add(self.documents[i])
            return records

        # __init__ takes no argument, so the element is returned directly
        # (as the other record containers of this module do)
        return self.documents[ix]

    @staticmethod
    def from_texts(
        documents: List[str],
    ):
        """Builds records holding a copy of the given list of texts"""
        records = DocumentRecords()
        # NOTE(review): the strings are stored as-is, while to_texts() reads
        # `doc.document[TextItem].text` - confirm whether they should be
        # wrapped into records first
        records.documents = list(documents)
        return records

    def to_texts(self) -> List[str]:
        """Returns `doc.document[TextItem].text` for each stored document"""
        return [doc.document[TextItem].text for doc in self.documents]