from abc import abstractmethod
import itertools
from typing import Iterable, Union, List, Optional, TypeVar, Generic, Sequence
import torch
from datamaestro_text.data.ir import TextItem
from xpmir.learning.context import TrainerContext
from xpmir.letor.records import BaseRecords, ProductRecords, TopicRecord, DocumentRecord
from xpmir.rankers import LearnableScorer
# Type variable for the encoded queries handled by a dual scorer; bound to
# Sequence so that the representation can be indexed
QueriesRep = TypeVar("QueriesRep", bound=Sequence)
# Type variable for the encoded documents handled by a dual scorer; bound to
# Sequence so that the representation can be indexed
DocsRep = TypeVar("DocsRep", bound=Sequence)
class DualRepresentationScorer(LearnableScorer, Generic[QueriesRep, DocsRep]):
    """Neural scorer based on (at least a partially) independent representation
    of the document and the question.

    This is the base class for all scorers that depend on a map
    of cosine/inner products between query and document tokens.

    Sub-classes must implement :meth:`score_product` and :meth:`score_pairs`,
    and either :meth:`encode` or both :meth:`encode_queries` and
    :meth:`encode_documents`.
    """

    def forward(self, inputs: BaseRecords, info: Optional[TrainerContext] = None):
        """Encode the unique queries and documents of a batch, then score them

        Args:
            inputs (BaseRecords): The records to score
            info (Optional[TrainerContext]): The training context (if learning)

        Returns:
            The flattened scores: the scores of the full query x document
            product when `inputs` is a `ProductRecords`, otherwise the scores
            of the pairs given by `inputs.pairs()`
        """
        # Each unique query/document is only encoded once
        enc_queries = self.encode_queries(list(inputs.unique_queries))
        enc_documents = self.encode_documents(list(inputs.unique_documents))

        # Score product
        if isinstance(inputs, ProductRecords):
            return self.score_product(enc_queries, enc_documents, info).flatten()

        # Score pairs: select the representations matching each pair
        # (assumes the representations can be indexed by what `pairs()`
        # returns -- TODO confirm for non-tensor representations)
        q_ix, d_ix = inputs.pairs()
        return self.score_pairs(enc_queries[q_ix], enc_documents[d_ix], info).flatten()

    def encode(self, texts: Iterable[str]) -> Union[DocsRep, QueriesRep]:
        """Encode a list of texts (document or query)

        The return value is model dependent

        Raises:
            NotImplementedError: always, unless overridden by a sub-class
        """
        raise NotImplementedError()

    def encode_documents(self, records: Iterable[DocumentRecord]) -> DocsRep:
        """Encode a list of document records

        The return value is model dependent

        By default, uses `encode` on the text of each record
        """
        return self.encode([record[TextItem].text for record in records])

    def encode_queries(self, records: Iterable[TopicRecord]) -> QueriesRep:
        """Encode a list of topic (query) records

        The return value is model dependent, but should be sequence

        By default, uses `encode` on the text of each record
        """
        return self.encode([record[TextItem].text for record in records])

    def merge_queries(self, queries: QueriesRep):
        """Merge query batches encoded with `encode_queries`

        By default, uses `merge`
        """
        return self.merge(queries)

    def merge_documents(self, documents: DocsRep):
        """Merge document batches encoded with `encode_documents`

        By default, uses `merge`
        """
        return self.merge(documents)

    def merge(self, objects: Union[DocsRep, QueriesRep]):
        """Merge a list of encoded batches into a single batch

        - a single-element list returns its only element unchanged
        - for tensors, uses torch.cat
        - for lists, concatenate all of them
        - for `TextsRepresentationOutput`, concatenates the values and merges
          the tokenized texts (ids and masks are padded to a common length)

        Args:
            objects: The list of batches to merge

        Raises:
            AssertionError: if `objects` is not a list
            RuntimeError: if the type of the batches is not handled
        """
        assert isinstance(
            objects, list
        ), f"Merging can only be done with lists, got {type(objects)}"

        # Just returns the only object to merge
        if len(objects) == 1:
            return objects[0]

        first = objects[0]

        if isinstance(first, torch.Tensor):
            return torch.cat(objects)

        if isinstance(first, list):
            # chain.from_iterable flattens one level; chain(objects) would
            # just iterate over the (unmerged) batches themselves
            return list(itertools.chain.from_iterable(objects))

        # Kept local as in the original code (presumably to avoid an import
        # cycle -- TODO confirm)
        from xpmir.text.encoders import TextsRepresentationOutput
        from xpmir.text.tokenizers import TokenizedTexts

        if isinstance(first, TextsRepresentationOutput):

            def merge_mask(masks):
                """Stack 2D (batch x length) tensors, padding on the length"""
                min_batch_size = min(m.shape[0] for m in masks)

                # Transpose to (length x batch) and cut every batch into
                # chunks of `min_batch_size` columns.
                # NOTE(review): pad_sequence needs all the chunks to share
                # their trailing dimension, so this appears to assume that
                # every batch size is a multiple of the smallest -- confirm
                batch_equalized = [
                    chunk
                    for m in masks
                    for chunk in m.t().split(min_batch_size, dim=1)
                ]

                # (max length x number of chunks x min_batch_size)
                pad = torch.nn.utils.rnn.pad_sequence(batch_equalized)

                # Back to (total batch size x max length)
                return pad.reshape(pad.shape[0], -1).t()

            tokenized = [x.tokenized for x in objects]

            # Batches without tokens are ignored
            token_lists = [t.tokens for t in tokenized if t.tokens is not None]
            tokens = (
                list(itertools.chain.from_iterable(token_lists))
                if token_lists
                else None
            )

            ids = merge_mask([t.ids for t in tokenized])
            lens = list(itertools.chain.from_iterable(t.lens for t in tokenized))

            # No mask at all gives no mask in the merged batch
            masks = [t.mask for t in tokenized]
            mask = None if all(m is None for m in masks) else merge_mask(masks)

            # Batches without token type ids are ignored
            token_type_ids = [
                t.token_type_ids for t in tokenized if t.token_type_ids is not None
            ]
            token_type_ids = torch.cat(token_type_ids) if token_type_ids else None

            return TextsRepresentationOutput(
                torch.cat([x.value for x in objects]),
                TokenizedTexts(tokens, ids, lens, mask, token_type_ids),
            )

        raise RuntimeError(f"Cannot deal with objects of type {type(objects[0])}")

    @abstractmethod
    def score_product(
        self,
        queries: QueriesRep,
        documents: DocsRep,
        info: Optional[TrainerContext] = None,
    ) -> torch.Tensor:
        """Computes the score of all possible pairs of query and document

        Args:
            queries (QueriesRep): The encoded queries
            documents (DocsRep): The encoded documents
            info (Optional[TrainerContext]): The training context (if learning)

        Returns:
            torch.Tensor:
                A tensor of dimension (N, P) where N is the number of queries
                and P the number of documents
        """
        ...

    @abstractmethod
    def score_pairs(
        self,
        queries: QueriesRep,
        documents: DocsRep,
        info: Optional[TrainerContext] = None,
    ) -> torch.Tensor:
        """Score the specified pairs of queries/documents.

        There are as many queries as documents. The exact type of
        queries and documents depends on the specific instance of the
        dual representation scorer.

        Args:
            queries (QueriesRep): The list of encoded queries
            documents (DocsRep): The matching list of encoded documents
            info (Optional[TrainerContext]): The training context (if learning)

        Returns:
            torch.Tensor:
                A tensor of dimension (N) where N is the number of documents/queries
        """
        ...