# This package contains all rankers
from abc import ABC, abstractmethod
from experimaestro import tqdm
from typing import (
Dict,
List,
Optional,
TYPE_CHECKING,
)
from experimaestro import Param, Config, field
from datamaestro_ir.data import (
Documents,
DocumentStore,
IDTextRecord,
AdhocRun,
)
from xpm_torch import ModuleContainer
from datamaestro_ir.data.base import ScoredDocument
if TYPE_CHECKING:
    # Placeholder for imports only needed by static type checkers
    # (currently none)
    pass
import logging
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
class Retriever(Config, ModuleContainer, ABC):
    """A retriever is a model to return top-scored documents given a query"""

    store: Param[Optional[DocumentStore]] = field(default=None, ignore_default=True)
    """Give the document store associated with this retriever"""

    def initialize(self):
        """Prepares the retriever before use

        The base implementation does nothing; subclasses may override it
        (e.g. :class:`RunRetriever` loads its run here).
        """
        pass

    def collection(self):
        """Returns the document collection object

        Raises:
            NotImplementedError: always in the base class; subclasses backed
                by a collection must override this method.
        """
        raise NotImplementedError()

    def retrieve_all(
        self, queries: Dict[str, IDTextRecord]
    ) -> Dict[str, List[ScoredDocument]]:
        """Retrieves documents for a set of queries

        By default, iterate using `self.retrieve`, but this leaves some room open
        for optimization (subclasses can override it with a batched version).

        Args:
            queries: A dictionary where the key is the ID of the query, and the
                value is the query record (passed as-is to :meth:`retrieve`)

        Returns:
            A dictionary mapping each query ID to the scored documents
            returned by :meth:`retrieve` for that query
        """
        results = {}
        for key, record in tqdm(list(queries.items())):
            results[key] = self.retrieve(record)
        return results

    @abstractmethod
    def retrieve(self, record: IDTextRecord) -> List[ScoredDocument]:
        """Retrieves documents for a single query

        Args:
            record: The query record

        Returns:
            The retrieved documents, sorted by decreasing score
        """
        ...

    def _store(self) -> Optional[DocumentStore]:
        """Returns the associated document store (if any) that can be
        used to get the full text of the documents

        The base implementation has no store and returns ``None``;
        subclasses can override it to expose a store they own.
        """
        return None

    def get_store(self) -> Optional[DocumentStore]:
        """Returns the document store to use with this retriever

        The explicitly configured :attr:`store` takes precedence; otherwise
        the store reported by :meth:`_store` (possibly ``None``) is returned.
        """
        # Compare against None instead of relying on truthiness: a store
        # object may define ``__len__`` and would then be falsy when empty,
        # which must not make us ignore an explicitly configured store.
        if self.store is not None:
            return self.store
        return self._store()
class RetrieverHydrator(Retriever):
    """Hydrate retrieved results with document text"""

    retriever: Param[Retriever]
    """The retriever to hydrate"""

    store: Param[DocumentStore]
    """The store for document texts"""

    def initialize(self):
        """Initializes the wrapped retriever"""
        return self.retriever.initialize()

    def retrieve(self, record: IDTextRecord) -> List[ScoredDocument]:
        """Retrieves with the wrapped retriever, then replaces each document
        by the record fetched from :attr:`store` (scores and order are kept)

        Args:
            record: The query record, forwarded to the wrapped retriever

        Returns:
            The hydrated scored documents, in the wrapped retriever's order
        """
        scored = self.retriever.retrieve(record)
        if not scored:
            return []

        # One batched store lookup instead of one ``document_ext`` call per
        # hit; like RunRetriever.retrieve, this relies on ``documents_ext``
        # returning documents aligned with the requested IDs.
        documents = self.store.documents_ext([sd.document["id"] for sd in scored])
        return [
            ScoredDocument(document, sd.score)
            for document, sd in zip(documents, scored)
        ]
class RunRetriever(Retriever):
    """A retriever that returns documents from a pre-computed run

    Can be useful to build a two-stage retriever with precomputed first stage (e.g for validation when training a scorer model)
    """

    run: Param[AdhocRun]
    """The pre-computed run"""

    documents: Param[Documents]
    """Associated documents"""

    def initialize(self):
        """Loads the run in memory as query ID -> {document ID: score}

        Must be called before :meth:`retrieve`.
        """
        super().initialize()
        self._run_dict = self.run.get_dict()

    def collection(self):
        """Returns the associated documents"""
        return self.documents

    def retrieve(self, record: IDTextRecord) -> List[ScoredDocument]:
        """Returns the run's documents for the query ``record["id"]``

        Args:
            record: The query record (only its ``"id"`` is used)

        Returns:
            The run's documents sorted by decreasing score (an empty list if
            the query is not in the run). When :attr:`documents` is a
            DocumentStore the full document records are fetched; otherwise
            each document only carries its ID.
        """
        qid = record["id"]
        results = self._run_dict.get(qid, {})

        # Sort by score descending (stable: ties keep the run's order)
        sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
        if not sorted_results:
            # Nothing to hydrate: avoid an empty lookup in the store
            return []

        # Hydrate documents if the documents object is a store
        if isinstance(self.documents, DocumentStore):
            doc_ids = [doc_id for doc_id, _ in sorted_results]
            hydrated_docs = self.documents.documents_ext(doc_ids)
            return [
                ScoredDocument(doc, float(score))
                for doc, (_, score) in zip(hydrated_docs, sorted_results)
            ]

        # Fallback to only ID
        return [
            ScoredDocument({"id": doc_id}, float(score))
            for doc_id, score in sorted_results
        ]

    def _store(self) -> Optional[DocumentStore]:
        """Returns the documents when they form a DocumentStore, else None"""
        return self.documents if isinstance(self.documents, DocumentStore) else None