Source code for xpmir.rankers

# This package contains all rankers

from abc import ABC, abstractmethod
from experimaestro import tqdm
from enum import Enum
from typing import (
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Protocol,
    Tuple,
    TypeVar,
    Union,
    TYPE_CHECKING,
)
import torch
import torch.nn as nn
import attrs
from experimaestro import Param, Config, Meta
from datamaestro_text.data.ir import (
    Documents,
    DocumentStore,
    create_record,
    IDItem,
)
from datamaestro_text.data.ir.base import DocumentRecord
from xpmir.utils.utils import Initializable
from xpmir.letor import Device, Random
from xpmir.learning import ModuleInitMode, ModuleInitOptions
from xpmir.learning.batchers import Batcher
from xpmir.learning.context import TrainerContext
from xpmir.learning.optim import Module
from xpmir.letor.records import (
    TopicRecord,
    BaseRecords,
    PairwiseRecord,
    PairwiseRecords,
    ProductRecords,
)
from xpmir.utils.utils import EasyLogger, easylog

if TYPE_CHECKING:
    from xpmir.evaluation import RetrieverFactory

logger = easylog()


@attrs.define()
class ScoredDocument:
    """A data structure that associates a score with a document"""

    document: DocumentRecord
    """The document"""

    score: float
    """The associated score"""

    def __repr__(self):
        return f"document({self.document}, {self.score})"

    def __lt__(self, other):
        return self.score < other.score

class ScorerOutputType(Enum):
    REAL = 0
    """An unbounded scalar value"""

    LOG_PROBABILITY = 1
    """A log probability, bounded by 0"""

    PROBABILITY = 2
    """A probability, in ]0, 1["""

class Scorer(Config, Initializable, EasyLogger, ABC):
    """Query-document scorer

    A model able to give a score to a list of documents given a query
    """

    outputType: ScorerOutputType = ScorerOutputType.REAL
    """Determines the type of output scalar (log probability, probability,
    logit)"""

    def __initialize__(self, options: ModuleInitOptions):
        """Initialize the scorer

        :param options: Options for initialization
        """
        pass

    def rsv(
        self,
        topic: Union[str, TopicRecord],
        documents: Union[List[ScoredDocument], ScoredDocument, str, List[str]],
    ) -> List[ScoredDocument]:
        # Convert into document records
        if isinstance(documents, str):
            documents = [ScoredDocument(create_record(text=documents), None)]
        elif isinstance(documents, ScoredDocument):
            # A single scored document (as allowed by the signature)
            documents = [documents]
        elif isinstance(documents[0], str):
            documents = [
                ScoredDocument(create_record(text=document_text), None)
                for document_text in documents
            ]

        # Convert into a topic record
        if isinstance(topic, str):
            topic = create_record(text=topic)

        return self.compute(topic, documents)

    @abstractmethod
    def compute(
        self, topic: TopicRecord, documents: Iterable[ScoredDocument]
    ) -> List[ScoredDocument]:
        """Score all documents with respect to the topic"""
        ...

    def eval(self):
        """Put the model in inference/evaluation mode"""
        pass

    def to(self, device):
        """Move the scorer to another device"""
        pass

    def getRetriever(
        self,
        retriever: "Retriever",
        batch_size: int,
        batcher: Batcher = Batcher(),
        top_k=None,
        device=None,
    ):
        """Returns a two-stage re-ranker that uses this scorer on top of the
        given retriever

        :param device: Device for the ranker, or None if no change should be made
        :param batch_size: The number of documents in each batch
        :param batcher: How to provide batches of documents
        :param top_k: Number of documents to re-rank (or None for all)
        """
        return TwoStageRetriever(
            retriever=retriever,
            scorer=self,
            batchsize=batch_size,
            batcher=batcher,
            device=device,
            top_k=top_k if top_k else None,
        )

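# Usage sketch (illustrative, not part of the original module): scoring raw
# strings with `Scorer.rsv` and wrapping a first-stage retriever with
# `getRetriever`; `my_scorer` and `bm25_retriever` are hypothetical instances
# of concrete Scorer / Retriever subclasses.
#
#   ranked = my_scorer.rsv("neural ranking", ["first document", "second document"])
#   reranker = my_scorer.getRetriever(bm25_retriever, batch_size=32, top_k=100)
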
def scorer_retriever(
    documents: Documents,
    *,
    retrievers: "RetrieverFactory",
    scorer: Scorer,
    **kwargs,
):
    """Helper function that returns a two-stage retriever. This is useful
    when used with partial (when the scorer is not known).

    :param documents: The document collection
    :param retrievers: A retriever factory
    :param scorer: The scorer
    :return: A retriever, built by calling :meth:`Scorer.getRetriever`
    """
    assert retrievers is not None, "The retrievers have not been given"
    assert scorer is not None, "The scorer has not been given"
    return scorer.getRetriever(retrievers(documents), **kwargs)

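# Usage sketch (illustrative, not part of the original module): `scorer_retriever`
# is designed to be combined with `functools.partial`, binding the retriever
# factory and the scorer now and supplying the document collection later;
# `my_factory`, `my_scorer` and `my_documents` are hypothetical objects.
#
#   from functools import partial
#
#   make_retriever = partial(
#       scorer_retriever, retrievers=my_factory, scorer=my_scorer, batch_size=64
#   )
#   retriever = make_retriever(my_documents)
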
class RandomScorer(Scorer):
    """A random scorer"""

    random: Param[Random]
    """The random number generator"""

    def compute(
        self, record: TopicRecord, scored_documents: Iterable[ScoredDocument]
    ) -> List[ScoredDocument]:
        result = []
        random = self.random.state
        for scored_document in scored_documents:
            result.append(ScoredDocument(scored_document.document, random.random()))
        return result

class AbstractModuleScorerCall(Protocol):
    def __call__(self, inputs: "BaseRecords", info: Optional[TrainerContext]):
        ...

class AbstractModuleScorer(Scorer, Module):
    """Base class for all learnable scorers

    This class provides a `compute` method that calls the forward method.
    """

    # Ensures basic operations are redirected to torch.nn.Module methods
    __call__: AbstractModuleScorerCall = nn.Module.__call__
    to = nn.Module.to
    train = nn.Module.train

    def __init__(self):
        self.logger.info("Initializing %s", self)
        nn.Module.__init__(self)
        super().__init__()
        self._initialized = False

    def __str__(self):
        return f"scorer {self.__class__.__qualname__}"

    def eval(self):
        """Put the model in evaluation mode"""
        self.train(False)

    def __initialize__(self, options: ModuleInitOptions):
        """Initialize a learnable scorer

        Initialization can either be determined by a checkpoint (if set) or
        otherwise (random or pre-trained checkpoint, depending on the model)
        """
        # Sets the current random seed
        if options.random is not None:
            seed = options.random.randint((2**32) - 1)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

        return self

    def compute(
        self, topic: TopicRecord, scored_documents: Iterable[ScoredDocument]
    ) -> List[ScoredDocument]:
        # Materialize the iterable so it can be traversed twice
        scored_documents = list(scored_documents)

        # Prepare the inputs and call the model
        inputs = ProductRecords()
        inputs.add_topics(topic)
        inputs.add_documents(*[sd.document for sd in scored_documents])

        with torch.no_grad():
            scores = self(inputs, None).cpu().numpy()

        # Returns the scored documents
        return [
            ScoredDocument(sd.document, float(score))
            for sd, score in zip(scored_documents, scores)
        ]

class LearnableScorer(AbstractModuleScorer):
    """Learnable scorer

    A scorer with parameters that can be learnt"""

    def forward(self, inputs: "BaseRecords", info: Optional[TrainerContext]):
        """Computes the score of all (query, document) pairs

        Different subclasses can process the input more or less efficiently
        based on the `BaseRecords` instance (pointwise, pairwise, or structured)
        """
        raise NotImplementedError(f"forward in {self.__class__}")

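# Subclassing sketch (illustrative, not part of the original module): a concrete
# learnable scorer only needs to implement `forward` over a batch of records.
#
#   class MyScorer(LearnableScorer):
#       def forward(self, inputs: BaseRecords, info: Optional[TrainerContext]):
#           # Encode the topics/documents of `inputs` and return a 1-D tensor
#           # with one score per (topic, document) pair
#           ...
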
class DuoLearnableScorer(LearnableScorer):
    """Base class for models that can score a triplet (query, document 1,
    document 2)"""

    def forward(self, inputs: "PairwiseRecords", info: Optional[TrainerContext]):
        """Returns scores for pairs of documents (given a query)"""
        raise NotImplementedError(f"abstract forward in {self.__class__}")

class Retriever(Config, ABC):
    """A retriever is a model that returns top-scored documents given a query"""

    store: Param[Optional[DocumentStore]] = None
    """The document store associated with this retriever"""

    def initialize(self):
        pass

    def collection(self):
        """Returns the document collection object"""
        raise NotImplementedError()

    def retrieve_all(
        self, queries: Dict[str, TopicRecord]
    ) -> Dict[str, List[ScoredDocument]]:
        """Retrieves documents for a set of queries

        By default, iterates using `self.retrieve`, but this leaves some room
        open for optimization

        Args:
            queries: A dictionary where the key is the ID of the query, and
                the value is the topic record
        """
        results = {}
        for key, record in tqdm(list(queries.items())):
            results[key] = self.retrieve(record)
        return results

    @abstractmethod
    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        """Retrieves documents, returning a list sorted by decreasing score"""
        ...

    def _store(self) -> Optional[DocumentStore]:
        """Returns the associated document store (if any) that can be used
        to get the full text of the documents"""

    def get_store(self) -> Optional[DocumentStore]:
        return self.store or self._store()

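# Subclassing sketch (illustrative, not part of the original module): a concrete
# retriever only needs to implement `retrieve`; `my_index` is a hypothetical
# search backend returning (document record, score) pairs.
#
#   class MyRetriever(Retriever):
#       def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
#           return [
#               ScoredDocument(doc, score) for doc, score in my_index.search(record)
#           ]
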
class AbstractTwoStageRetriever(Retriever):
    """Abstract class for all two-stage retrievers (i.e. scorers and duo-scorers)"""

    retriever: Param[Retriever]
    """The base retriever"""

    scorer: Param[Scorer]
    """The scorer used to re-rank the documents"""

    top_k: Param[Optional[int]] = None
    """The number of returned documents (if None, returns all the documents)"""

    batchsize: Meta[int] = 0
    """The batch size for the re-ranker"""

    batcher: Meta[Batcher] = Batcher()
    """How to provide batches of documents"""

    device: Meta[Optional[Device]] = None
    """Device on which the model is run"""

    def initialize(self):
        self.retriever.initialize()
        self._batcher = self.batcher.initialize(self.batchsize)
        self.scorer.initialize(ModuleInitMode.DEFAULT.to_options())

        # Move the scorer to the requested device (if any)
        if self.device is not None:
            self.scorer.to(self.device.value)

class TwoStageRetriever(AbstractTwoStageRetriever):
    """Uses a first-stage retriever to select the top-K documents, which are
    then re-ranked by a scorer"""

    def _retrieve(
        self,
        batch: List[ScoredDocument],
        query: TopicRecord,
        scoredDocuments: List[ScoredDocument],
    ):
        scoredDocuments.extend(self.scorer.rsv(query, batch))

    def retrieve(self, record: TopicRecord):
        # Calls the retriever
        scoredDocuments = self.retriever.retrieve(record)

        # Scorer in evaluation mode
        self.scorer.eval()

        _scoredDocuments = []
        self._batcher.process(
            scoredDocuments, self._retrieve, record, _scoredDocuments
        )

        _scoredDocuments.sort(reverse=True)
        return _scoredDocuments[: (self.top_k or len(_scoredDocuments))]

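# Usage sketch (illustrative, not part of the original module): building and
# querying a two-stage pipeline; `bm25` and `my_scorer` are hypothetical
# first-stage retriever / scorer instances.
#
#   reranker = TwoStageRetriever(
#       retriever=bm25, scorer=my_scorer, batchsize=32, top_k=100
#   )
#   reranker.initialize()
#   results = reranker.retrieve(create_record(text="neural ranking models"))
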
class DuoTwoStageRetriever(AbstractTwoStageRetriever):
    """The two-stage retriever for pairwise scorers.

    For a pairwise scorer, we need to aggregate the pairwise scores in
    some way.
    """

    def _retrieve(
        self,
        batch: List[Tuple[ScoredDocument, ScoredDocument]],
        query: TopicRecord,
        scoredDocuments: List[float],
    ):
        """Calls `rsv` to score each batch

        Since the batch size is independent of k, triplets belonging to the
        same query may be split across different batches.
        """
        scoredDocuments.extend(self.rsv(query, batch))

    def retrieve(self, query: TopicRecord):
        """Calls `_retrieve` through the batcher and aggregates all the scores"""
        # Get the documents from the first-stage retriever
        scoredDocuments_previous = self.retriever.retrieve(query)

        # Transform them into the ordered pairs (i, j) with i != j,
        # i and j ranging over all retrieved documents
        pairs = []
        for i in range(len(scoredDocuments_previous)):
            for j in range(len(scoredDocuments_previous)):
                if i != j:
                    pairs.append(
                        (scoredDocuments_previous[i], scoredDocuments_previous[j])
                    )

        # Scorer in evaluation mode
        self.scorer.eval()

        _scores_pairs = []  # the scores for each pair of documents
        self._batcher.process(pairs, self._retrieve, query, _scores_pairs)

        # Use the sum aggregation strategy
        _scores_pairs = torch.Tensor(_scores_pairs).reshape(
            len(scoredDocuments_previous), -1
        )
        _scores_per_document = torch.sum(_scores_pairs, dim=1)  # score per document

        # Build the ScoredDocument objects from the aggregated scores
        scoredDocuments = []
        for i in range(len(scoredDocuments_previous)):
            scoredDocuments.append(
                ScoredDocument(
                    scoredDocuments_previous[i].document,
                    float(_scores_per_document[i]),
                )
            )
        scoredDocuments.sort(reverse=True)
        return scoredDocuments[: (self.top_k or len(scoredDocuments))]

    def rsv(
        self,
        record: TopicRecord,
        documents: List[Tuple[ScoredDocument, ScoredDocument]],
    ) -> List[float]:
        """Given the query and the document pairs, returns the score for each
        triplet"""
        inputs = PairwiseRecords()
        for doc1, doc2 in documents:
            inputs.add(PairwiseRecord(record, doc1, doc2))

        with torch.no_grad():
            scores = self.scorer(inputs, None).cpu().float()  # shape: (batch size,)

        return scores.tolist()

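# Aggregation note (illustrative): with n=3 retrieved documents, the 6 ordered
# pairs (i, j), i != j, are scored; after reshaping the scores to (3, 2), the
# aggregated score of document i is the sum over the pairs where i comes first.
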
ARGS = TypeVar("ARGS")
KWARGS = TypeVar("KWARGS")
T = TypeVar("T")


class DocumentsFunction(Protocol, Generic[KWARGS, ARGS, T]):
    def __call__(self, documents: Documents, *args: ARGS, **kwargs: KWARGS) -> T:
        ...


def document_cache(fn: DocumentsFunction[KWARGS, ARGS, T]):
    """Decorator that caches the result of a function depending on the
    document dataset ID
    """
    retrievers = {}

    def _fn(*args: ARGS, **kwargs: KWARGS):
        def cached(documents: Documents) -> T:
            dataset_id = documents.__identifier__().all
            if dataset_id not in retrievers:
                retrievers[dataset_id] = fn(documents, *args, **kwargs)

            return retrievers[dataset_id]

        return cached

    return _fn

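# Usage sketch (illustrative, not part of the original module): caching a
# per-collection factory so that it is built only once per dataset identifier;
# `build_index` and `my_documents` are hypothetical.
#
#   @document_cache
#   def build_index(documents: Documents, threads: int = 4):
#       ...
#
#   factory = build_index(threads=8)  # returns a cached factory
#   index = factory(my_documents)     # built once per dataset ID, then reused
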
class RetrieverHydrator(Retriever):
    """Hydrate retrieved results with document text"""

    retriever: Param[Retriever]
    """The retriever to hydrate"""

    store: Param[DocumentStore]
    """The store for document texts"""

    def initialize(self):
        return self.retriever.initialize()

    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        return [
            ScoredDocument(self.store.document_ext(sd.document[IDItem].id), sd.score)
            for sd in self.retriever.retrieve(record)
        ]

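# Usage sketch (illustrative, not part of the original module): wrapping a
# retriever so that the returned documents are fetched from a document store;
# `sparse_retriever` and `my_store` are hypothetical instances.
#
#   hydrated = RetrieverHydrator(retriever=sparse_retriever, store=my_store)
#   hydrated.initialize()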