Source code for xpmir.rankers.full

from typing import List, Optional, Tuple, Dict, Any
from experimaestro import Param, Meta, tqdm
import torch
from datamaestro_text.data.ir import Documents
from xpmir.neural.dual import DualRepresentationScorer
from xpmir.learning.batchers import Batcher
from xpmir.learning import ModuleInitMode
from xpmir.letor import Device
from xpmir.letor.records import DocumentRecord, TopicRecord
from . import Retriever, ScoredDocument


class FullRetriever(Retriever):
    """Retrieves all the documents of the collection

    This can be used to build a small validation set on a subset of the
    collection - in that case, the scorer can be used through a
    TwoStageRetriever, with this retriever as the base retriever.
    """

    documents: Param[Documents]

    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        # Return the full collection; every document gets a neutral score
        # of 0.0 so that a downstream scorer decides the actual ranking.
        return [ScoredDocument(document, 0.0) for document in self.documents]
class FullRetrieverRescorer(Retriever):
    """Scores all the documents from a collection"""

    documents: Param[Documents]
    """The set of documents to consider"""

    scorer: Param[DualRepresentationScorer]
    """The scorer (a dual representation scorer)"""

    batchsize: Param[int] = 0
    batcher: Meta[Batcher] = Batcher()
    device: Meta[Optional[Device]] = None

    def initialize(self):
        """Initialize the batchers and the scorer (moved to the device, if any)"""
        self.query_batcher = self.batcher.initialize(self.batchsize)
        self.document_batcher = self.batcher.initialize(self.batchsize)
        self.scorer.initialize(ModuleInitMode.DEFAULT.to_options())

        # Compute with the scorer on the configured device
        if self.device is not None:
            self.scorer.to(self.device.value)

    def _retrieve(
        self,
        batch: List[ScoredDocument],
        query: str,
        scoredDocuments: List[ScoredDocument],
    ):
        # Score one batch of documents against a single query text
        scoredDocuments.extend(self.scorer.rsv(query, batch))

    def encode_queries(self, queries: List[Tuple[str, str]], encoded: List[Any], pbar):
        """Encode one batch of queries, appending the resulting tensor to `encoded`

        :param queries: The input queries as (id, text) pairs
        :param encoded: (output) the list of per-batch encoded query tensors
        :param pbar: progress bar, advanced by the number of processed queries
        """
        encoded.append(self.scorer.encode_queries([text for _, text in queries]))
        pbar.update(len(queries))
        return encoded

    def score(
        self,
        documents: List[DocumentRecord],
        queries: torch.Tensor,
        scored_documents: List[List[ScoredDocument]],
        pbar,
    ):
        """Score a batch of documents against the whole set of encoded queries

        `scored_documents` is filled with one list per document, i.e. it
        contains [[s(q_0, d_0), ..., s(q_n, d_0)], ...,
        [s(q_0, d_m), ..., s(q_n, d_m)]]

        :param documents: the batch of documents
        :param queries: the merged encoded queries (one row per query)
        :param scored_documents: (output) current lists of scored documents
            (one list per document)
        :param pbar: progress bar counting (query, document) pairs
        """
        # Encode documents
        encoded = self.scorer.encode_documents(documents)

        # Process query by query; one output list per document in the batch
        new_scores: List[List[ScoredDocument]] = [[] for _ in documents]
        for query_ix in range(len(queries)):
            # Slice (instead of indexing) to keep a 2D query batch of size 1
            query = queries[query_ix : (query_ix + 1)]

            # Returns a query x document score matrix
            scores = self.scorer.score_product(query.to(encoded.device), encoded, None)

            # Adds up to the lists. NOTE: the inner loop previously reused
            # `ix`, shadowing the query-loop variable; use a distinct name.
            scores = scores.flatten().detach()
            for doc_ix, (document, score) in enumerate(zip(documents, scores)):
                new_scores[doc_ix].append(ScoredDocument(document, float(score)))

            # The progress bar total is #queries x #documents, so advance by
            # the number of documents scored for this query (was update(1),
            # which under-counted).
            pbar.update(len(documents))

        # Add each result to the full document list
        scored_documents.extend(new_scores)

    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        # Only use retrieve_all, with a single placeholder query id
        return self.retrieve_all({"_": record})["_"]

    def retrieve_all(
        self, queries: Dict[str, TopicRecord]
    ) -> Dict[str, List[ScoredDocument]]:
        """Score all the documents of the collection for each query

        :param queries: a dictionary mapping query ids to topic records
        :return: a dictionary mapping each query id to its list of scored
            documents
        """
        self.scorer.eval()
        all_queries = list(queries.items())

        with torch.no_grad():
            # Encode all queries: the batcher encodes a batch of queries at a
            # time, and the per-batch tensors are then merged together
            with tqdm(total=len(all_queries), desc="Encoding queries") as pbar:
                enc_queries = self.query_batcher.reduce(
                    all_queries, self.encode_queries, [], pbar
                )
                enc_queries = self.scorer.merge_queries(
                    enc_queries
                )  # shape (len(queries), dimension)

            # Encode documents and score them; scored_documents ends up with
            # one list per document (each list holding one score per query)
            scored_documents: List[List[ScoredDocument]] = []
            with tqdm(
                total=len(all_queries) * self.documents.documentcount,
                desc="Scoring documents",
            ) as pbar:
                self.document_batcher.process(
                    self.documents, self.score, enc_queries, scored_documents, pbar
                )

        # Transpose the per-document lists into per-query lists
        qids = [qid for qid, _ in all_queries]
        return {
            qid: [per_document[ix] for per_document in scored_documents]
            for ix, qid in enumerate(qids)
        }