from typing import Any, Dict, List, Tuple
import torch
from datamaestro_ir.data import Documents
from experimaestro import field, Meta, Param
from experimaestro import tqdm
from xpm_torch.batchers import Batcher
from xpmir.letor.records import DocumentRecord, TopicRecord
from xpmir.neural import DualRepresentationScorer
from xpmir.rankers import Retriever, ScoredDocument
[docs]
class FullRetriever(Retriever):
"""Retrieves all the documents of the collection
This can be used to build a small validation set on a subset of the
collection - in that case, the scorer can be used through a
TwoStageRetriever, with this retriever as the base retriever.
"""
documents: Param[Documents]
def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
return [ScoredDocument(doc, 0.0) for doc in self.documents]
[docs]
class FullRetrieverRescorer(Retriever):
"""Scores all the documents from a collection
Encodes all queries at once, then processes documents in batches,
scoring the full query×document matrix each batch. This is more
efficient than the TwoStageRetriever approach for small collections.
"""
documents: Param[Documents]
"""The set of documents to consider"""
scorer: Param[DualRepresentationScorer]
"""The scorer (a dual representation scorer)"""
batchsize: Param[int] = field(default=0, ignore_default=True)
batcher: Meta[Batcher] = field(default=Batcher.C(), ignore_default=True)
def initialize(self):
self.query_batcher = self.batcher.initialize(self.batchsize)
self.document_batcher = self.batcher.initialize(self.batchsize)
def encode_queries(self, queries: List[Tuple[str, str]], encoded: List[Any], pbar):
encoded.append(self.scorer.encode_queries([text for _, text in queries]))
pbar.update(len(queries))
return encoded
def score(
self,
documents: List[DocumentRecord],
queries: List,
scored_documents: List[List[ScoredDocument]],
pbar,
):
encoded = self.scorer.encode_documents(documents)
new_scores = [[] for _ in documents]
for ix in range(len(queries)):
query = queries[ix : (ix + 1)]
scores = self.scorer.score_product(query.to(encoded.device), encoded, None)
scores = scores.flatten().detach()
for doc_ix, (document, score) in enumerate(zip(documents, scores)):
new_scores[doc_ix].append(ScoredDocument(document, float(score)))
pbar.update(1)
scored_documents.extend(new_scores)
def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
return self.retrieve_all({"_": record})["_"]
def retrieve_all(
self, queries: Dict[str, TopicRecord]
) -> Dict[str, List[ScoredDocument]]:
self.scorer.eval()
all_queries = list(queries.items())
with torch.no_grad():
with tqdm(total=len(all_queries), desc="Encoding queries") as pbar:
enc_queries = self.query_batcher.reduce(
all_queries, self.encode_queries, [], pbar
)
enc_queries = self.scorer.merge_queries(enc_queries)
scored_documents: List[List[ScoredDocument]] = []
with tqdm(
total=len(all_queries) * self.documents.documentcount,
desc="Scoring documents",
) as pbar:
self.document_batcher.process(
self.documents, self.score, enc_queries, scored_documents, pbar
)
qids = [qid for qid, _ in all_queries]
return {qid: [sd[ix] for sd in scored_documents] for ix, qid in enumerate(qids)}