# Source code for xpmir.letor.records

import torch
import itertools
from datamaestro_ir.data import (
    IDTextRecord,
    SimpleTextItem,
)

from typing import (
    Generic,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    TypedDict,
    Union,
)
from typing_extensions import ReadOnly


# Topics (queries) and documents share the same record type
TopicRecord = IDTextRecord
DocumentRecord = IDTextRecord


## TypedDicts


class ScoreRecord(TypedDict):
    """Typed dictionary holding a single, read-only ``score`` entry"""

    # The (read-only) score of the record
    score: ReadOnly[float]


class ScoreDocumentRecord(IDTextRecord, ScoreRecord):
    """Record combining an ID, a text item and a score"""


## Dataclasses / generics

# DocT / QueryT: document and query types currently held by an item;
# DocT2 / QueryT2: the types after a ``with_documents`` / ``with_queries`` call
DocT = TypeVar("DocT")
DocT2 = TypeVar("DocT2")
QueryT = TypeVar("QueryT")
QueryT2 = TypeVar("QueryT2")


class SampleItem(Generic[DocT, QueryT]):
    """Common interface of pointwise, pairwise and listwise sample items.

    Every item exposes its documents and queries as lists, and can build a
    copy of itself where those lists are replaced. This uniform access lets
    generic code (e.g. RecordsProcessor) transform documents and queries in
    batch, whatever the item type.
    """

    def get_documents(self) -> List[DocT]:
        """Return the documents of this item (to be overridden)"""
        raise NotImplementedError

    def with_documents(self, ds: "List[DocT2]") -> "SampleItem[DocT2, QueryT]":
        """Return a copy of this item using the documents ``ds`` (to be overridden)"""
        raise NotImplementedError

    def get_queries(self) -> List[QueryT]:
        """Return the queries of this item (to be overridden)"""
        raise NotImplementedError

    def with_queries(self, qs: "List[QueryT2]") -> "SampleItem[DocT, QueryT2]":
        """Return a copy of this item using the queries ``qs`` (to be overridden)"""
        raise NotImplementedError


class PointwiseItem(SampleItem[DocT, QueryT]):
    """A single (query, document, relevance) sample from a pointwise sampler"""

    # The query (also reachable through the ``query`` property)
    topic: QueryT
    # The document
    document: DocT
    # The relevance; may be None (the constructor default)
    relevance: Optional[float]

    def __init__(
        self,
        topic: QueryT,
        document: DocT,
        relevance: Optional[float] = None,
    ):
        self.topic, self.document, self.relevance = topic, document, relevance

    @property
    def query(self):
        """Alias of ``topic``"""
        return self.topic

    def get_queries(self) -> List[QueryT]:
        return [self.topic]

    def with_queries(self, qs: "List[QueryT2]") -> "PointwiseItem[DocT, QueryT2]":
        # Only the first query is used: a pointwise item holds one query
        return PointwiseItem(
            topic=qs[0], document=self.document, relevance=self.relevance
        )

    def get_documents(self) -> List[DocT]:
        return [self.document]

    def with_documents(self, ds: "List[DocT2]") -> "PointwiseItem[DocT2, QueryT]":
        # Only the first document is used: a pointwise item holds one document
        return PointwiseItem(
            topic=self.topic, document=ds[0], relevance=self.relevance
        )
RT = TypeVar("RT")


class BaseItems(List[RT]):
    """Base class for collections of sample items.

    A collection exposes aligned iterables of queries (``topics``) and
    documents (``documents``), one entry per (query, document) pair to score.
    Items can be structured, i.e. the same queries and documents can be used
    more than once. To allow optimizations (e.g. pre-computing each
    document/query representation only once), subclasses can expose
    de-duplicated ``unique_topics`` / ``unique_documents`` together with
    ``pairs()``, which maps each pair to indices into those unique lists.
    """

    topics: Iterable[IDTextRecord]
    documents: Iterable[IDTextRecord]

    # Overridden (True) by collections scoring every query against every document
    is_product = False

    def __repr__(self):
        return f"<{self.__class__.__name__}(count={len(self)})>"

    @property
    def unique_topics(self) -> List[IDTextRecord]:
        """Queries to process; by default all (possibly repeated) topics"""
        return list(self.topics)

    unique_queries = unique_topics

    @property
    def unique_documents(self) -> List[IDTextRecord]:
        """Documents to process; by default all (possibly repeated) documents"""
        return list(self.documents)

    @property
    def queries(self):
        """Deprecated: use topics"""
        return self.topics

    def to(self, *args, **kwargs):
        """Moves the records to a device (e.g. for the relevance matrix in ProductItems)

        Default implementation does nothing, but can be implemented by
        specific records that have tensors as attributes (e.g. ProductItems)
        """
        return self

    def pairs(self) -> Tuple[Iterable[int], Iterable[int]]:
        """Returns the query and document indices of the pairs to score

        The two returned sequences are aligned: the k-th pair is made of the
        query at the k-th query index and the document at the k-th document
        index. Indices refer to `unique_topics` and `unique_documents`.
        """
        raise NotImplementedError(f"pairs() in {self.__class__}")

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sample"""
        raise NotImplementedError(f"__getitem__() in {self.__class__}")

    def __len__(self):
        """Returns the number of records

        The length is dependant on the type of records, and is mainly used to
        divide the data into batches
        """
        raise NotImplementedError(f"__len__() in {self.__class__}")

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]


class PointwiseItems(BaseItems[PointwiseItem]):
    """Pointwise items are a set of triples (query, document, relevance)"""

    # The queries
    topics: List[IDTextRecord]
    # The documents (aligned with topics)
    documents: List[IDTextRecord]
    # The relevances (aligned with topics)
    relevances: List[float]

    def __init__(self):
        super().__init__()
        self.topics = []
        self.documents = []
        self.relevances = []

    def add(self, record: PointwiseItem):
        self.topics.append(record.query)
        # A missing (None) relevance is stored as 0
        self.relevances.append(record.relevance or 0)
        self.documents.append(record.document)

    def __len__(self):
        return len(self.topics)

    def __getitem__(self, ix: Union[slice, int]):
        if isinstance(ix, slice):
            records = PointwiseItems()
            # slice.indices resolves missing bounds (e.g. items[:n]),
            # negative indices and out-of-range stops
            for i in range(*ix.indices(len(self))):
                records.add(
                    PointwiseItem(self.topics[i], self.documents[i], self.relevances[i])
                )
            return records

        return PointwiseItem(self.topics[ix], self.documents[ix], self.relevances[ix])

    def pairs(self) -> Tuple[List[int], List[int]]:
        # The i-th query is scored against the i-th document
        ix = list(range(len(self)))
        return (ix, ix)

    @staticmethod
    def from_texts(
        topics: List[str],
        documents: List[str],
        relevances: Optional[List[float]] = None,
    ):
        """Builds pointwise items from raw query and document texts

        :param topics: The query texts
        :param documents: The document texts (aligned with topics)
        :param relevances: The relevances (aligned with topics); when None,
            every pair gets a zero relevance (as ``add`` does for missing
            relevances), so that indexing the result does not fail
        """
        records = PointwiseItems()
        records.topics = [{"text_item": SimpleTextItem(t)} for t in topics]
        records.documents = [{"text_item": SimpleTextItem(t)} for t in documents]
        records.relevances = (
            [0.0] * len(records.topics) if relevances is None else relevances
        )
        return records
class PairwiseItem(SampleItem[DocT, QueryT]):
    """Sample made of a query together with a positive and a negative document"""

    query: QueryT
    positive: DocT
    negative: DocT

    def __init__(self, query: QueryT, positive: DocT, negative: DocT):
        self.query, self.positive, self.negative = query, positive, negative

    def get_queries(self) -> List[QueryT]:
        return [self.query]

    def with_queries(self, qs: "List[QueryT2]") -> "PairwiseItem[DocT, QueryT2]":
        # Only the first query is used
        return PairwiseItem(query=qs[0], positive=self.positive, negative=self.negative)

    def get_documents(self) -> List[DocT]:
        # Order matters: positive first, then negative
        return [self.positive, self.negative]

    def with_documents(self, ds: "List[DocT2]") -> "PairwiseItem[DocT2, QueryT]":
        # Expects the same (positive, negative) order as get_documents
        positive, negative = ds[0], ds[1]
        return PairwiseItem(query=self.query, positive=positive, negative=negative)
class PairwiseItemWithTarget(PairwiseItem):
    """A pairwise item composed of a query, a positive and a negative document,
    plus a target identifier saying whether the first document is the
    positive or the negative one
    """

    target: int

    def __init__(
        self,
        query: IDTextRecord,
        positive: IDTextRecord,
        negative: IDTextRecord,
        target: int,
    ):
        super().__init__(query, positive, negative)
        self.target = target


class PairwiseItems(BaseItems):
    """Pairwise records of queries associated with (positive, negative) pairs"""

    # The queries
    _topics: List[IDTextRecord]
    # The positive documents (aligned with _topics)
    positives: List[IDTextRecord]
    # The negative documents (aligned with _topics)
    negatives: List[IDTextRecord]

    def __init__(self):
        self._topics = []
        self.positives = []
        self.negatives = []

    def add(self, record: PairwiseItem):
        self._topics.append(record.query)
        self.positives.append(record.positive)
        self.negatives.append(record.negative)

    @property
    def topics(self):
        # Each query appears twice: once for its positive document, once for
        # its negative one (matching the order of `documents`)
        return itertools.chain(self._topics, self._topics)

    def set_unique_topics(self, topics: List[IDTextRecord]):
        assert len(topics) == len(self._topics), (
            f"Number of topics do not match ({len(topics)} vs {len(self._topics)})"
        )
        self._topics = topics

    def set_unique_documents(self, documents: List[IDTextRecord]):
        """Replaces the documents, given as all positives followed by all negatives"""
        N = len(self._topics)
        assert len(documents) == N * 2, (
            f"Expected {N * 2} documents, got {len(documents)}"
        )
        self.positives = documents[:N]
        self.negatives = documents[N:]

    queries = topics

    @property
    def unique_topics(self):
        return self._topics

    unique_queries = unique_topics

    @property
    def documents(self):
        # All positives first, then all negatives
        return itertools.chain(self.positives, self.negatives)

    def pairs(self):
        """Returns the list of query/document indices for which we should
        compute the score.

        Query indices refer to `unique_topics`; document indices refer to the
        positives followed by the negatives. This method should be used with
        `unique_topics`"""
        indices = list(range(len(self._topics)))
        return indices * 2, list(range(2 * len(self.positives)))

    def __len__(self):
        return len(self._topics)

    def __getitem__(self, ix: Union[slice, int]):
        if isinstance(ix, slice):
            records = PairwiseItems()
            # slice.indices resolves missing bounds and negative indices
            for i in range(*ix.indices(len(self._topics))):
                records.add(
                    PairwiseItem(self._topics[i], self.positives[i], self.negatives[i])
                )
            return records

        return PairwiseItem(self._topics[ix], self.positives[ix], self.negatives[ix])


class PairwiseItemsWithTarget(PairwiseItems):
    """Pairwise items associated with a label (saying which document is better)"""

    # The target of each item (aligned with _topics)
    target: List[int]

    def __init__(self):
        super().__init__()
        self.target = []

    def add(self, record: PairwiseItemWithTarget):
        super().add(record)
        self.target.append(record.target)

    def get_target(self):
        return self.target

    def __getitem__(self, ix: Union[slice, int]):
        if isinstance(ix, slice):
            records = PairwiseItemsWithTarget()
            # slice.indices resolves missing bounds and negative indices
            for i in range(*ix.indices(len(self._topics))):
                records.add(
                    PairwiseItemWithTarget(
                        self._topics[i],
                        self.positives[i],
                        self.negatives[i],
                        self.target[i],
                    )
                )
            return records

        return PairwiseItemWithTarget(
            self._topics[ix], self.positives[ix], self.negatives[ix], self.target[ix]
        )
class ListwiseItem(SampleItem[DocT, QueryT]):
    """Generic sample made of one query and a list of documents"""

    query: QueryT
    documents: List[DocT]

    def __init__(self, query: QueryT, documents: List[DocT]):
        self.query, self.documents = query, documents

    def get_queries(self) -> List[QueryT]:
        return [self.query]

    def with_queries(self, qs: "List[QueryT2]") -> "ListwiseItem[DocT, QueryT2]":
        # Only the first query is used; the document list is shared, not copied
        return ListwiseItem(query=qs[0], documents=self.documents)

    def get_documents(self) -> List[DocT]:
        # The stored list itself is returned (no copy)
        return self.documents

    def with_documents(self, ds: "List[DocT2]") -> "ListwiseItem[DocT2, QueryT]":
        new_documents = list(ds)
        return ListwiseItem(query=self.query, documents=new_documents)
class ListwiseItems(BaseItems):
    """Listwise items of queries associated with lists of documents"""

    # The queries
    _topics: List[IDTextRecord]
    # The list of documents of each query (aligned with _topics)
    _documents: List[List[IDTextRecord]]

    def __init__(self):
        self._topics = []
        self._documents = []

    def add(self, record: ListwiseItem):
        self._topics.append(record.query)
        self._documents.append(record.documents)

    @property
    def topics(self):
        """Each query repeated once per associated document

        Aligned one-to-one with `documents`, which flattens the per-query
        document lists in the same (query-major) order.
        """
        return itertools.chain.from_iterable(
            itertools.repeat(topic, len(docs))
            for topic, docs in zip(self._topics, self._documents)
        )

    def set_unique_topics(self, topics: List[IDTextRecord]):
        assert len(topics) == len(self._topics), (
            f"Number of topics do not match ({len(topics)} vs {len(self._topics)})"
        )
        self._topics = topics

    def set_unique_documents(self, documents: List[IDTextRecord]):
        raise NotImplementedError(
            f"set_unique_documents() in {self.__class__.__name__}"
        )

    queries = topics

    @property
    def unique_topics(self):
        return self._topics

    unique_queries = unique_topics

    @property
    def documents(self):
        # All documents, flattened query by query
        return itertools.chain.from_iterable(self._documents)

    def pairs(self):
        """Returns the query/document indices for which we should compute the score

        Query indices refer to `unique_topics` (each query index is repeated
        once per document of that query). Document indices refer to the
        flattened `documents` / `unique_documents`.
        """
        query_indices = [
            query_ix for query_ix, docs in enumerate(self._documents) for _ in docs
        ]
        return query_indices, list(range(len(query_indices)))

    def __len__(self):
        return len(self._topics)

    def __getitem__(self, ix: Union[slice, int]):
        if isinstance(ix, slice):
            records = ListwiseItems()
            # slice.indices resolves missing bounds and negative indices
            for i in range(*ix.indices(len(self._topics))):
                records.add(ListwiseItem(self._topics[i], self._documents[i]))
            return records

        return ListwiseItem(self._topics[ix], self._documents[ix])
class BatchwiseItems(BaseItems):
    """Items where each query comes with several documents, each having an
    associated (pseudo-)relevance.

    The number of documents is assumed to be the same for every query, even
    though the documents themselves may differ from one query to another.
    """

    # The relevances, as a tensor
    relevances: torch.Tensor

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-sampling is not supported at this level"""
        class_name = type(self).__name__
        raise NotImplementedError(f"__getitem__() in {class_name}")
class ProductItems(BatchwiseItems):
    """Computes the score for all the documents and queries (cartesian product)

    Attributes:
        _topics: The list of queries
        _documents: The list of documents
        relevances: (query x document) matrix with relevance score (between 0
            and 1); only present once `set_relevances` has been called
    """

    # The list of queries to score
    _topics: List[IDTextRecord]
    # The list of documents to score (shared by every query)
    _documents: List[IDTextRecord]
    # A 2D tensor (query x document) with the relevance of each pair
    relevances: torch.Tensor

    is_product = True

    def __init__(self):
        self._topics = []
        self._documents = []

    def add_topics(self, *topics: IDTextRecord):
        self._topics.extend(topics)

    def add_documents(self, *documents: IDTextRecord):
        self._documents.extend(documents)

    def set_relevances(self, relevances: torch.Tensor):
        """Sets the (query x document) relevance matrix, checking its shape"""
        assert relevances.shape[0] == len(self._topics), (
            f"The number of queries {len(self._topics)} "
            f"does not match the number of rows {relevances.shape[0]}"
        )
        assert relevances.shape[1] == len(self._documents), (
            f"The number of documents {len(self._documents)} "
            f"does not match the number of columns {relevances.shape[1]}"
        )
        self.relevances = relevances

    def _get_relevances(self) -> Optional[torch.Tensor]:
        """Returns the relevance matrix, or None if it was never set"""
        return getattr(self, "relevances", None)

    def to(self, *args, **kwargs):
        """Moves the relevance matrix (if set) using ``torch.Tensor.to``

        The records are modified in place and returned.
        """
        relevances = self._get_relevances()
        if relevances is not None:
            self.relevances = relevances.to(*args, **kwargs)
        return self

    def __len__(self):
        return len(self._topics)

    @property
    def topics(self):
        # Each query is repeated once per document (query-major order)
        for q in self._topics:
            for _ in self._documents:
                yield q

    queries = topics

    @property
    def unique_topics(self):
        return self._topics

    unique_queries = unique_topics

    @property
    def documents(self):
        # The full document list is repeated once per query
        for _ in self._topics:
            for d in self._documents:
                yield d

    @property
    def unique_documents(self) -> Iterable[IDTextRecord]:
        return self._documents

    def pairs(self) -> Tuple[Iterable[int], Iterable[int]]:
        """Returns the indices of every (query, document) pair, query-major"""
        n_documents = len(self._documents)
        topics = [q for q in range(len(self._topics)) for _ in range(n_documents)]
        documents = list(range(n_documents)) * len(self._topics)
        return topics, documents

    def __getitem__(self, ix: Union[slice, int]):
        """Sub-samples the queries

        A slice returns a ProductItems with the selected queries, all the
        documents and the matching relevance rows (when set). An integer
        returns a PointwiseItems pairing that query with every document.
        """
        relevances = self._get_relevances()

        if isinstance(ix, slice):
            rows = list(range(*ix.indices(len(self._topics))))
            records = ProductItems()
            records.add_topics(*(self._topics[i] for i in rows))
            # Documents are shared by all queries: keep them all
            records.add_documents(*self._documents)
            if relevances is not None:
                # Index with a list of rows (negative slice steps are not
                # supported by tensor slicing)
                records.set_relevances(relevances[rows])
            return records

        if ix < 0:
            ix += len(self._topics)
        if not 0 <= ix < len(self._topics):
            raise IndexError("ProductItems index out of range")

        records = PointwiseItems()
        topic = self._topics[ix]
        if relevances is None:
            for d in self._documents:
                records.add(PointwiseItem(topic, d, None))
        else:
            row = relevances[ix]
            rels = row.tolist() if hasattr(row, "tolist") else list(row)
            for d, r in zip(self._documents, rels):
                records.add(PointwiseItem(topic, d, float(r)))
        return records


class DocumentRecords(List[IDTextRecord]):
    """Masked Language Modeling Records are a set of documents"""

    # The documents (the underlying list itself stays empty)
    documents: List[IDTextRecord]

    def __init__(self):
        super().__init__()
        self.documents = []

    def add(self, record: IDTextRecord):
        self.documents.append(record)

    def __len__(self):
        return len(self.documents)

    def __iter__(self):
        # Records live in `documents`, not in the underlying list, so the
        # inherited list iterator would yield nothing
        return iter(self.documents)

    def __getitem__(self, ix: Union[slice, int]):
        if isinstance(ix, slice):
            records = DocumentRecords()
            records.documents = self.documents[ix]
            return records
        return self.documents[ix]

    @staticmethod
    def from_texts(
        documents: List[str],
    ):
        """Builds records from raw texts (wrapped as ``{"text_item": ...}``)"""
        records = DocumentRecords()
        records.documents = [{"text_item": SimpleTextItem(t)} for t in documents]
        return records

    def to_texts(self) -> List[str]:
        """Returns the text of each document"""
        return [
            doc if isinstance(doc, str) else doc["text_item"].text
            for doc in self.documents
        ]