# This package contains all rankers
from abc import ABC, abstractmethod
from typing import (
Dict,
Iterable,
List,
Optional,
Protocol,
Tuple,
Union,
TYPE_CHECKING,
TypedDict,
)
from typing_extensions import ReadOnly
from xpmir.text import TokenizedTexts
import torch
import torch.nn as nn
from lightning_fabric import Fabric
from experimaestro import Param, Config, Meta, field, tqdm
from datamaestro_ir.data import (
Documents,
IDTextRecord,
SimpleTextItem,
)
import os
import torch.distributed as dist
from xpm_torch import Module, Random
from xpm_torch.utils.utils import Initializable, to_device
from xpm_torch.utils.logging import EasyLogger
from xpm_torch.datasets import IndexedDataset, ShardedIterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from xpm_torch.learner import TrainerContext
from xpm_torch.losses import ModuleOutputType
from xpmir.letor.records import (
BaseItems,
PairwiseItem,
PairwiseItems,
PointwiseItem,
PointwiseItems,
ProductItems,
)
from datamaestro_ir.data.base import ScoredDocument
from .retriever import Retriever
import logging
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only needed for the "RetrieverFactory" string annotation used by
    # scorer_retriever; not imported at runtime (presumably to avoid an
    # import cycle with xpmir.evaluation — TODO confirm)
    from xpmir.evaluation import RetrieverFactory
class Scorer(Config, Initializable, EasyLogger, ABC):
    """Query-document scorer

    A model able to give a score to a list of documents given a query
    """

    _initialized = False

    outputType: ModuleOutputType = ModuleOutputType.REAL
    """Determines the type of output scalar (log probability, probability, logit) """

    doc: Meta[str] = ""
    """Paper description or title (used in HF Hub README)"""

    bibtex: Meta[str] = ""
    """BibTeX citation (used in HF Hub README)"""

    def __initialize__(self):
        """Initialize the scorer"""
        pass

    def rsv(
        self,
        topic: Union[str, IDTextRecord],
        documents: Union[List[ScoredDocument], ScoredDocument, str, List[str]],
    ) -> List[ScoredDocument]:
        """Compute the Retrieval Status Value (RSV) for a query and a set of documents.

        This method is the primary entry point for scoring a set of documents
        against a single query. It handles input normalization and delegates
        to the :meth:`compute` method.

        :param topic: The query, either as raw text or as a record
        :param documents: The documents to score: a single text, a list of
            texts, a single :class:`ScoredDocument` or a list of them. Raw
            texts are wrapped into records with a ``None`` score.
        :return: The scored documents as returned by :meth:`compute` (an
            empty list when no document is given)

        Note:
            For large-scale evaluation involving multiple queries, using
            :meth:`Retriever.retrieve_all` via a :class:`TwoStageRetriever`
            is preferred as it allows for cross-query batching on GPUs.
        """
        # Convert into document records
        if isinstance(documents, str):
            documents = [ScoredDocument({"text_item": SimpleTextItem(documents)}, None)]
        elif isinstance(documents, ScoredDocument):
            # A single scored document: compute expects an iterable of them
            documents = [documents]
        elif not documents:
            # Nothing to score (also avoids an IndexError on documents[0])
            return []
        elif isinstance(documents[0], str):
            documents = [
                ScoredDocument({"text_item": SimpleTextItem(text)}, None)
                for text in documents
            ]

        # Convert into topic record
        if isinstance(topic, str):
            topic = {"text_item": SimpleTextItem(topic)}

        return self.compute(topic, documents)

    @abstractmethod
    def compute(
        self, topic: IDTextRecord, documents: Iterable[ScoredDocument]
    ) -> List[ScoredDocument]:
        """Score all documents with respect to a single topic.

        This method should be implemented by subclasses to provide the actual
        scoring logic. It is query-atomic (processes one query at a time).

        :param topic: The query record
        :param documents: The documents to score
        :return: The scored documents
        """
        ...

    def getRetriever(
        self,
        retriever: "Retriever",
        batch_size: int,
        top_k=None,
        device=None,
    ):
        """Returns a two stage re-ranker from this retriever and a scorer

        :param retriever: The first-stage retriever whose results are re-ranked
        :param batch_size: The number of documents in each batch
        :param top_k: Number of re-ranked documents to return (or None for
            all); a falsy value such as 0 is treated as None
        :param device: Device for the ranker or None if no change should be
            made. NOTE: this argument is currently ignored.
        :return: A :class:`TwoStageRetriever` configuration
        """
        return TwoStageRetriever.C(
            retriever=retriever,
            scorer=self,
            batchsize=batch_size,
            top_k=top_k or None,
        )
def scorer_retriever(
    documents: Documents,
    *,
    retrievers: "RetrieverFactory",
    scorer: Scorer,
    key: Optional[str] = None,
    **kwargs,
):
    """Helper function that returns a two stage retriever. This is useful
    when used with partial (when the scorer is not known).

    :param documents: The document collection
    :param retrievers: A retriever factory, called as
        ``retrievers(documents, key=key)``
    :param scorer: The scorer
    :param key: Forwarded to the retriever factory as its ``key`` argument
    :param kwargs: Forwarded to :meth:`Scorer.getRetriever` (e.g.
        ``batch_size``, ``top_k``)
    :return: A retriever, obtained by calling :meth:`Scorer.getRetriever`
    """
    assert retrievers is not None, "The retrievers have not been given"
    assert scorer is not None, "The scorer has not been given"
    return scorer.getRetriever(retrievers(documents, key=key), **kwargs)
class RandomScorer(Scorer):
    """A random scorer

    Each document receives a score drawn from the random state of
    :attr:`random`; documents keep their input order.
    """

    random: Param[Random]
    """The random number generator"""

    def compute(
        self, record: IDTextRecord, scored_documents: Iterable[ScoredDocument]
    ) -> List[ScoredDocument]:
        # One draw per document, in iteration order
        rng = self.random.state
        return [
            ScoredDocument(candidate.document, rng.random())
            for candidate in scored_documents
        ]
class AbstractModuleScorerCall(Protocol):
    """Call signature used to type ``AbstractModuleScorer.__call__``"""

    def __call__(
        self,
        inputs: "BaseItems",
        *,
        tokenized: Optional[TokenizedTexts] = None,
    ) -> torch.Tensor:
        """Scores a batch of items.

        :param inputs: The items to score
        :param tokenized: Pre-tokenized texts for ``inputs`` (if any)
        :return: The scores as a tensor (callers in this module apply
            ``.cpu().float()`` to the result)
        """
        # NOTE(review): AbstractModuleScorer.compute calls ``self(inputs, None)``
        # with a positional second argument, which this keyword-only
        # ``tokenized`` signature does not describe — confirm intended signature
        ...
class AbstractModuleScorer(Scorer, Module):
    """Base class for all torch-based Modules implementing the `xpmir.rankers.Scorer`.

    While :meth:`compute` (inherited from :class:`Scorer`) processes documents
    for a single query, :class:`AbstractModuleScorer` also supports cross-query
    batching when called directly through its `forward` method (aliased as `__call__`).

    When used in a :class:`TwoStageRetriever` with a `batchsize > 0`, the retriever
    will use the :class:`PointwiseItems` batching to maximize GPU utilization across
    multiple queries.
    """

    # Ensures basic operations are redirected to torch.nn.Module methods
    __call__: AbstractModuleScorerCall = nn.Module.__call__
    train = nn.Module.train

    def __init__(self):
        logger.info("Initializing %s", self.__class__.__name__)
        nn.Module.__init__(self)
        super().__init__()
        self._initialized = False

    def __initialize__(self):
        """Initialize a learnable scorer (structure only)"""
        return self

    def get_forward_methods(self) -> list:
        """Returns the names of the forward methods exposed by this scorer.

        By default, this is only ``"rsv"``; subclasses can extend the list to
        support multiple forward methods (e.g. for different scoring strategies).
        """
        return ["rsv"]

    def compute(
        self, topic: IDTextRecord, scored_documents: Iterable[ScoredDocument]
    ) -> List[ScoredDocument]:
        """Atomic scoring for a single query using ProductItems.

        This implementation leverages the :meth:`forward` method by wrapping
        the single query and its documents into a :class:`ProductItems` object.

        :param topic: The query record
        :param scored_documents: The documents to score (any iterable)
        :return: One :class:`ScoredDocument` per input document, in input
            order, carrying the model score
        :raises ValueError: if the model does not return one score per document
        """
        # Materialise: the documents are traversed twice (inputs, then results),
        # which would silently yield nothing on the second pass for a generator
        scored_documents = list(scored_documents)
        if not scored_documents:
            return []

        # Prepare the inputs and call the model
        inputs = ProductItems()
        inputs.add_topics(topic)
        inputs.add_documents(*[sd.document for sd in scored_documents])
        with torch.no_grad():
            scores = self(inputs, None).cpu().float().numpy()

        # A single topic: flatten (n,), (n, 1) or (1, n) outputs alike
        scores = scores.reshape(-1)
        if len(scores) != len(scored_documents):
            raise ValueError(
                f"Expected {len(scored_documents)} scores, got {len(scores)}"
            )

        # Returns the scored documents
        return [
            ScoredDocument(sd.document, float(score.item()))
            for sd, score in zip(scored_documents, scores)
        ]
class DuoLearnableScorer(AbstractModuleScorer):
    """Base class for models that can score a triplet (query, document 1, document 2)"""

    def forward(self, inputs: "PairwiseItems", info: Optional[TrainerContext]):
        """Returns scores for pairs of documents (given a query)

        :param inputs: The (query, document 1, document 2) items to score
        :param info: The trainer context (callers in this module pass None)
        :raises NotImplementedError: always; subclasses must override this method
        """
        raise NotImplementedError(f"abstract forward in {self.__class__}")
class AbstractTwoStageRetriever(Retriever):
    """Abstract class for all two stage retrievers (i.e. scorers and duo-scorers)

    A first-stage :attr:`retriever` produces candidates that the
    :attr:`scorer` re-ranks.
    """

    retriever: Param[Retriever]
    """The base retriever"""

    scorer: Param[Scorer]
    """The scorer used to re-rank the documents"""

    top_k: Param[Optional[int]]
    """The number of returned documents (if None, returns all the documents)"""

    batchsize: Meta[int] = field(default=0, ignore_default=True)
    """The batch size for the re-ranker"""

    def initialize(self):
        # Initialize the first stage, then the re-ranking stage
        for stage in (self.retriever, self.scorer):
            stage.initialize()
class RerankingInputs(TypedDict):
    """Inputs for re-ranking

    One collated batch of pointwise items, as built by ``reranking_collate``.
    """

    # The items gathered into a single PointwiseItems batch (fed to the scorer)
    records: ReadOnly[PointwiseItems]
    """The pointwise records"""
    # The original items; TwoStageRetriever.retrieve_all zips the scores with
    # this list, so it is expected to follow the order of ``records``
    batch: ReadOnly[List[PointwiseItem]]
    """The original batch of items"""
    # None unless the scorer provides a tokenizer function
    tokenized_records: ReadOnly[Optional[TokenizedTexts]]
    """The tokenized records (if any)"""
def reranking_collate(
    batch: List[PointwiseItem],
) -> RerankingInputs:
    """Collate PointwiseItems into a PointwiseItems batch.

    :param batch: The items produced by the re-ranking dataset
    :return: The collated inputs, with ``tokenized_records`` left to None
    """
    records = PointwiseItems()
    for pointwise_item in batch:
        records.add(pointwise_item)
    return RerankingInputs(
        records=records,
        batch=batch,
        tokenized_records=None,
    )
class ReRankingDataset(ShardedIterableDataset):
    """A dataset that yields PointwiseItem records for re-ranking

    Queries are distributed over shards in a round-robin fashion, and the
    first-stage retriever is called lazily, one query at a time.
    """

    def __init__(self, queries: Dict[str, IDTextRecord], retriever: Retriever):
        super().__init__()
        # Kept as an indexable list of (query id, query record) pairs
        self.queries = list(queries.items())
        self.retriever = retriever

    def iter_shard(self, shard_id: int, num_shards: int):
        """Yields the items of the queries assigned to one shard

        :param shard_id: The shard index; this shard handles the queries at
            positions ``shard_id, shard_id + num_shards, ...``
        :param num_shards: The total number of shards
        :return: A generator of :class:`PointwiseItem` (query, document,
            first-stage score)
        """
        # Rank is only used for logging (SLURM first, then torch env variables)
        rank = os.environ.get(
            "SLURM_PROCID", os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))
        )
        logger.debug(
            "ReRankingDataset (rank %s): shard %d/%d starting (total queries: %d)",
            rank,
            shard_id,
            num_shards,
            len(self.queries),
        )

        for _qid, query in self.queries[shard_id::num_shards]:
            # Pull first-stage results on-the-fly to avoid materialising everything
            for sd in self.retriever.retrieve(query):
                yield PointwiseItem(query, sd.document, sd.score)
class TwoStageRetriever(AbstractTwoStageRetriever):
    """Use one retriever to select the top-K documents which are then re-ranked
    given a scorer.

    Multi-GPU support:
        When set up with a :class:`lightning.Fabric` instance, :meth:`retrieve_all`
        shards the re-ranking task across GPUs and gathers the results. It uses
        efficient cross-query batching to maximize GPU throughput.
    """

    def retrieve(self, record: IDTextRecord):
        """Retrieves with the first stage, then re-ranks with the scorer

        :param record: The query record
        :return: The re-ranked documents (best first), truncated to ``top_k``
            documents (all of them when ``top_k`` is None or 0)
        """
        # Calls the retriever
        scoredDocuments = self.retriever.retrieve(record)
        if not scoredDocuments:
            # Nothing to re-rank
            return []

        # Score all documents
        _scoredDocuments = self.scorer.rsv(record, scoredDocuments)
        _scoredDocuments.sort(reverse=True)
        return _scoredDocuments[: (self.top_k or len(_scoredDocuments))]

    def build_reranking_dataloader(self, queries: Dict[str, IDTextRecord]):
        """Builds a dataloader for re-ranking all documents for a set of queries

        Allows for efficient two-stage retrieval with cross-query batching on
        GPU when the scorer supports it (i.e. is an AbstractModuleScorer and
        batchsize > 0).

        :param queries: The queries, indexed by query ID
        :return: A dataloader of :class:`RerankingInputs` (set up through
            Fabric when a ``fabric`` attribute is present)
        """
        # We don't materialise everything, but iterate on the fly
        dataset = ReRankingDataset(queries, self.retriever)

        # get underlying module if wrapped (e.g. Fabric)
        scorer = self.scorer.module if hasattr(self.scorer, "module") else self.scorer
        if hasattr(scorer, "get_tokenizer_fn"):
            tokenization_fn = scorer.get_tokenizer_fn()

            def collate_fn(batch: List[PointwiseItem]) -> RerankingInputs:
                inputs = reranking_collate(batch)
                # RerankingInputs keys are ReadOnly: build a new mapping
                # rather than assigning into the collated one
                return RerankingInputs(
                    records=inputs["records"],
                    batch=inputs["batch"],
                    tokenized_records=tokenization_fn(inputs["records"]),
                )

        else:
            collate_fn = reranking_collate

        dataloader = StatefulDataLoader(
            dataset,
            batch_size=self.batchsize,
            num_workers=min(int(os.environ.get("SLURM_CPUS_PER_TASK", 4)), 4),
            pin_memory=torch.cuda.is_available(),
            collate_fn=collate_fn,
        )

        fabric: Fabric = getattr(self, "fabric", None)
        if fabric:
            dataloader = fabric.setup_dataloaders(dataloader)
        return dataloader

    @staticmethod
    def _local_query_count(dataloader, fabric: Fabric, n_queries: int) -> int:
        """Number of queries this rank is expected to process (progress bar only)"""
        # Assumes ReRankingDataset shards such that rank R handles indices i
        # where (i // num_workers) % world_size == R — TODO confirm against
        # ShardedIterableDataset
        _dataloader = (
            dataloader.dataloader if hasattr(dataloader, "dataloader") else dataloader
        )
        num_workers = max(1, getattr(_dataloader, "num_workers", 0))
        return sum(
            1
            for i in range(n_queries)
            if (i // num_workers) % fabric.world_size == fabric.global_rank
        )

    @staticmethod
    def _gather_results(
        fabric: Fabric,
        queries: Dict[str, IDTextRecord],
        scored_results: Dict[str, List[ScoredDocument]],
    ) -> Optional[Dict[str, List[ScoredDocument]]]:
        """Gathers the per-rank results; returns the merged results on the
        global rank 0 and None on the other ranks"""
        # We use a list of tuples and filter empty results to reduce
        # communication overhead. Using torch.distributed directly for
        # objects is more robust than fabric.all_gather which is tensor-oriented.
        local_results = [(qid, docs) for qid, docs in scored_results.items() if docs]
        gathered_data = [None] * fabric.world_size
        dist.all_gather_object(gathered_data, local_results)

        if not fabric.is_global_zero:
            return None

        # Merge them back into one master dictionary
        final_results = {qid: [] for qid in queries}
        for rank_results in gathered_data:
            for qid, docs in rank_results:
                final_results[qid].extend(docs)
        return final_results

    @torch.no_grad()
    def retrieve_all(
        self, queries: Dict[str, IDTextRecord]
    ) -> Dict[str, List[ScoredDocument]]:
        """Retrieves documents for all queries in an efficient two-stage fashion:

        - populate a `PointwiseItem` dataset with the documents from first stage
        - reranks them on the fly with the scorer with given batch size
        - if self.batchsize is 0, scores all documents from the same query at
          once (may cause OOM with large top_k first stages)

        :param queries: The queries, indexed by query ID
        :return: The re-ranked documents (best first, truncated to ``top_k``)
            for each query. With a multi-process Fabric setup, only the global
            rank 0 returns the merged results; other ranks return ``{}``.
        """
        # Scorer in evaluation mode; plain (non-torch) scorers have no eval()
        eval_fn = getattr(self.scorer, "eval", None)
        if callable(eval_fn):
            eval_fn()

        if self.batchsize == 0:
            # Fallback to per-query retrieval if no batchsize
            return super().retrieve_all(queries)

        # get underlying module if wrapped (e.g. Fabric)
        scorer = self.scorer.module if hasattr(self.scorer, "module") else self.scorer
        scorer_type = type(scorer)
        batched = issubclass(scorer_type, AbstractModuleScorer)

        fabric: Fabric = getattr(self, "fabric", None)
        dataloader = self.build_reranking_dataloader(queries)
        distributed = bool(fabric) and fabric.world_size > 1

        # Local total so that the progress bar reaches 100% on each rank
        if distributed:
            total_to_process = self._local_query_count(
                dataloader, fabric, len(queries)
            )
            desc = f"(Rank {fabric.global_rank}): Re-Ranking"
        else:
            total_to_process = len(queries)
            desc = "Re-Ranking"

        batch_size_info = (
            f"batch size {self.batchsize}" if batched else "rsv (one-by-one)"
        )
        logger.info(
            f"{desc} with '{scorer_type.__name__}' using {batch_size_info}... "
            f"({total_to_process}/{len(queries)} queries on this rank)"
        )

        disable_tqdm = fabric is not None and not fabric.is_global_zero
        pbar = (
            None
            if disable_tqdm
            else tqdm(total=total_to_process, desc=desc, unit="query")
        )

        scored_results = {qid: [] for qid in queries}
        seen_qids = set()

        def mark_seen(qid):
            # Advance the progress bar the first time a query shows up
            if qid not in seen_qids:
                seen_qids.add(qid)
                if pbar is not None:
                    pbar.update(1)

        try:
            for inputs in dataloader:
                batch = inputs["batch"]
                if batched:
                    # Use the module forward to batch across queries
                    scores = (
                        self.scorer(
                            inputs["records"],
                            tokenized=inputs.get("tokenized_records"),
                        )
                        .cpu()
                        .float()
                        .numpy()
                    )
                    for score, item in zip(scores, batch):
                        qid = item.topic["id"]
                        mark_seen(qid)
                        scored_results[qid].append(
                            to_device(
                                ScoredDocument(item.document, float(score.item())),
                                "cpu",
                            )
                        )
                else:
                    # Fallback: group by query and use rsv (score one by one)
                    by_query = {}
                    for item in batch:
                        qid = item.topic["id"]
                        mark_seen(qid)
                        by_query.setdefault(qid, []).append(
                            to_device(
                                ScoredDocument(item.document, item.relevance),
                                "cpu",
                            )
                        )
                    for qid, docs in by_query.items():
                        scored_results[qid].extend(self.scorer.rsv(queries[qid], docs))
        finally:
            # Always release the progress bar, even if scoring fails
            if pbar is not None:
                pbar.close()

        if distributed:
            # Gather results from all GPUs
            merged = self._gather_results(fabric, queries, scored_results)
            if merged is None:
                return {}
            scored_results = merged

        # Sort and truncate
        for qid in scored_results:
            scored_results[qid].sort(reverse=True)
            if self.top_k:
                scored_results[qid] = scored_results[qid][: self.top_k]
        return scored_results
class DuoTwoStageRetriever(AbstractTwoStageRetriever):
    """The two stage retriever for pairwise scorers.

    For pairwise scorer, we need to aggregate the pairwise scores in some
    way: here, the score of a document is the sum of the scores of all the
    ordered pairs in which it comes first.
    """

    def retrieve(self, query: IDTextRecord):
        """Re-ranks the first-stage documents with the pairwise scorer

        :param query: The query record
        :return: The re-ranked documents (best first), truncated to ``top_k``
            documents (all of them when ``top_k`` is None or 0). With fewer
            than two candidates, no pair exists and every score is 0.0.
        """
        # get the documents from the retriever
        candidates = self.retriever.retrieve(query)
        n_docs = len(candidates)

        # Scorer in evaluation mode
        self.scorer.eval()

        if n_docs < 2:
            # No pair to score: the (empty) sum is 0 for every document
            return [ScoredDocument(sd.document, 0.0) for sd in candidates][
                : (self.top_k or n_docs)
            ]

        # All ordered pairs (i, j), i != j, in row-major order so that the
        # n_docs - 1 pairs starting with document i are contiguous.
        # Duo-reranking involves a small number of docs (N=50, 100), but N^2
        # can still be large (10k); materialisation is limited to ONE query.
        pairs = [
            (first, second)
            for i, first in enumerate(candidates)
            for j, second in enumerate(candidates)
            if i != j
        ]

        if self.batchsize > 0:
            # IndexedDataset needs a sequence
            dataloader = StatefulDataLoader(
                IndexedDataset(pairs), batch_size=self.batchsize, collate_fn=lambda x: x
            )
            pair_scores = []
            for batch in dataloader:
                pair_scores.extend(self.rsv(query, batch))
        else:
            pair_scores = self.rsv(query, pairs)

        # Use the sum aggregation strategy: row i holds the scores of (i, j)
        per_document = (
            torch.tensor(pair_scores, dtype=torch.float32)
            .reshape(n_docs, n_docs - 1)
            .sum(dim=1)
        )

        # Wrap the underlying document (not the first-stage ScoredDocument)
        scoredDocuments = [
            ScoredDocument(candidate.document, float(score))
            for candidate, score in zip(candidates, per_document)
        ]
        scoredDocuments.sort(reverse=True)
        return scoredDocuments[: (self.top_k or len(scoredDocuments))]

    def rsv(
        self,
        record: IDTextRecord,
        documents: List[Tuple[ScoredDocument, ScoredDocument]],
    ) -> List[float]:
        """Given the query and documents in tuple
        return the score for each triplets

        :param record: The query record
        :param documents: (document 1, document 2) pairs, either as
            :class:`ScoredDocument` or as document records
        :return: One score per pair, in input order (empty for no pair)
        """
        if not documents:
            return []

        def as_record(document):
            # Pairwise items take document records, like PointwiseItem does
            # with ``sd.document``; unwrap ScoredDocument, keep records as-is
            if isinstance(document, ScoredDocument):
                return document.document
            return document

        inputs = PairwiseItems()
        for doc1, doc2 in documents:
            inputs.add(PairwiseItem(record, as_record(doc1), as_record(doc2)))

        with torch.no_grad():
            scores = self.scorer(inputs, None).cpu().float()  # shape (batchsizes)
        return scores.tolist()