Source code for xpmir.rankers.full

from typing import List, Optional, Tuple, Dict, Any
from experimaestro import Param, Meta, tqdm
import torch
from import Documents
from xpmir.neural.dual import DualRepresentationScorer
from xpmir.learning.batchers import Batcher
from xpmir.learning import ModuleInitMode
from xpmir.letor import Device
from xpmir.letor.records import DocumentRecord, TopicRecord
from . import Retriever, ScoredDocument

class FullRetriever(Retriever):
    """Retriever that returns every document of the collection.

    Useful to build a small validation set on a subset of the collection:
    in that case, a scorer can be plugged in through a TwoStageRetriever
    that uses this retriever as its base retriever.
    """

    # The collection whose documents are all returned
    documents: Param[Documents]

    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        """Return all the documents, each with a constant score of 0.0.

        :param record: the topic record (not used, every topic gets the
            same list)
        :return: one ScoredDocument per document yielded by iterating over
            ``self.documents``
        """
        results: List[ScoredDocument] = []
        for document in self.documents:
            results.append(ScoredDocument(document, 0.0))
        return results
class FullRetrieverRescorer(Retriever):
    """Scores all the documents from a collection"""

    documents: Param[Documents]
    """The set of documents to consider"""

    scorer: Param[DualRepresentationScorer]
    """The scorer (a dual representation scorer)"""

    batchsize: Param[int] = 0
    """Batch size passed to the batcher for both queries and documents"""

    batcher: Meta[Batcher] = Batcher()
    """Batcher used to split the queries and the documents into batches"""

    device: Meta[Optional[Device]] = None
    """Device on which the scorer is run (left untouched when None)"""

    def initialize(self):
        """Set up the query/document batchers and initialize the scorer."""
        self.query_batcher = self.batcher.initialize(self.batchsize)
        self.document_batcher = self.batcher.initialize(self.batchsize)
        self.scorer.initialize(ModuleInitMode.DEFAULT.to_options())

        # Compute with the scorer on the requested device
        if self.device is not None:
            # NOTE(review): assumes `Device.value` exposes the underlying
            # torch device -- confirm against xpmir.letor.Device
            self.scorer.to(self.device.value)

    def _retrieve(
        self,
        batch: List[ScoredDocument],
        query: str,
        scoredDocuments: List[ScoredDocument],
    ):
        """Score a batch of documents for one query with ``scorer.rsv`` and
        append the results to ``scoredDocuments`` (in place)."""
        scoredDocuments.extend(self.scorer.rsv(query, batch))

    def encode_queries(self, queries: List[Tuple[str, str]], encoded: List[Any], pbar):
        """Encode a batch of queries and append the result to ``encoded``

        Used as the reduce callback of the query batcher in
        :meth:`retrieve_all`.

        Args:
            queries (List[Tuple[str, str]]): The input queries as
                (id, query) pairs; only the second element of each pair is
                given to the scorer
            encoded (List[Any]): Accumulator holding the encoded queries of
                the batches processed so far (one entry per batch)
            pbar: progress bar, advanced by the number of queries encoded

        Returns:
            The ``encoded`` list, with the new batch appended
        """
        encoded.append(self.scorer.encode_queries([text for _, text in queries]))
        pbar.update(len(queries))
        return encoded

    def score(
        self,
        documents: List[DocumentRecord],
        queries: List,
        scored_documents: List[List[ScoredDocument]],
        pbar,
    ):
        """Score documents for a set of queries

        Each call processes one batch of documents against the whole set of
        queries. scored_documents is filled batch after batch with one list
        per document, i.e. it contains

        [
            [s(q_0, d_0), ..., s(q_n, d_0)],
            ...,
            [s(q_0, d_m), ..., s(q_n, d_m)]
        ]

        :param documents: the batch of documents
        :param queries: the encoded queries (must support ``len`` and slicing)
        :param scored_documents: (output) the per-document lists of scored
            documents, each holding one entry per query
        :param pbar: progress bar, advanced by one unit per (query, document)
            pair
        """
        # Encode documents
        encoded = self.scorer.encode_documents(documents)

        # One list per document, filled query after query
        new_scores: List[List[ScoredDocument]] = [[] for _ in documents]

        # Process query by query. `queries` is an encoded batch rather than a
        # plain list, so it is sliced by index instead of being iterated.
        for query_ix in range(len(queries)):
            # Keep the batch dimension by taking a slice of size one
            query = queries[query_ix : (query_ix + 1)]

            # Returns a query x document matrix (here with a single row)
            scores = self.scorer.score_product(query, encoded, None)

            # Convert once to Python floats rather than one element at a time
            values = scores.flatten().detach().tolist()
            for doc_scores, document, value in zip(new_scores, documents, values):
                doc_scores.append(ScoredDocument(document, float(value)))

            # One unit of progress per (query, document) pair
            pbar.update(len(documents))

        # Add each result to the full document list
        scored_documents.extend(new_scores)

    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        """Score the whole collection for a single topic."""
        # Only use retrieve_all
        return self.retrieve_all({"_": record})["_"]

    def retrieve_all(
        self, queries: Dict[str, TopicRecord]
    ) -> Dict[str, List[ScoredDocument]]:
        """Score every document of the collection for every query

        :param queries: dictionary mapping a query ID to its topic record
        :return: dictionary mapping each query ID to the list of scored
            documents for this query, in the order in which the documents
            were processed (the list is not sorted by score)
        """
        self.scorer.eval()
        all_queries = list(queries.items())

        # Nothing to encode or score: avoid merging an empty list of batches
        if not all_queries:
            return {}

        with torch.no_grad():
            # Encode all queries: each call to the batcher encodes one batch
            # of queries, and the batches are merged afterwards
            with tqdm(total=len(all_queries), desc="Encoding queries") as pbar:
                enc_queries = self.query_batcher.reduce(
                    all_queries, self.encode_queries, [], pbar
                )
            enc_queries = self.scorer.merge_queries(enc_queries)

            # Encode documents and score them
            scored_documents: List[List[ScoredDocument]] = []
            with tqdm(
                total=len(all_queries) * self.documents.documentcount,
                desc="Scoring documents",
            ) as pbar:
                self.document_batcher.process(
                    self.documents, self.score, enc_queries, scored_documents, pbar
                )

        # scored_documents holds one list per document (indexed by query
        # position): transpose it into one list per query
        qids = [qid for qid, _ in all_queries]
        return {
            qid: [doc_scores[ix] for doc_scores in scored_documents]
            for ix, qid in enumerate(qids)
        }