Source code for xpmir.neural.interaction.drmm

import math
import numpy as np
from typing import Optional, List
from experimaestro import Config, Param, default
import torch
from torch import nn
from typing_extensions import Annotated
from xpmir.index import Index
import xpmir.neural.modules as modules
from xpmir.neural.interaction import (
    InteractionScorer,
    SimilarityOutput,
    TrainerContext,
    TokenizedTextEncoderBase,
    TokenizerOptions,
    TokensEncoderOutput,
)
from .common import SimilarityInputWithTokens

# The code below is heavily borrowed from OpenNIR


class CountHistogram(Config, nn.Module):
    """Base histogram class

    Attributes:
        nbins: number of bins in matching histogram
    """

    nbins: Param[int] = 29

    def forward(
        self, simmat: torch.Tensor, dlens: torch.Tensor, mask: torch.BoolTensor
    ):
        """Computes the histograms for each query term

        :param simmat: A (B... x Lq x Ld) matrix
        :param mask: A (B... x Lq x Ld) mask
        :param dlens: The document lengths (one value per document in the batch)
        :return: A (B... x Lq x nbins) matrix containing counts
        """
        # +1e-5 to nudge scores of 1 to above threshold
        bins = ((simmat + 1.00001) / 2.0 * (self.nbins - 1)).long()
        weights = mask.float()
        hist = torch.zeros(
            *simmat.shape[:-1], self.nbins, device=simmat.device, dtype=simmat.dtype
        )
        return hist.scatter_add_(simmat.ndim - 1, bins, weights)

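
# A minimal sketch (not part of the original module) of the binning performed
# in CountHistogram.forward: cosine similarities in [-1, 1] are mapped onto
# bin indices 0..nbins-1, and masked counts are accumulated with scatter_add_.
# All tensors below are made-up toy inputs.
def _demo_count_histogram():
    sims = torch.tensor([[[-1.0, 0.0, 0.5, 1.0]]])  # B=1 x Lq=1 x Ld=4
    mask = torch.ones_like(sims, dtype=torch.bool)
    nbins = 29
    # The +1e-5 offset keeps scores within float error of 1 in the top bin
    bins = ((sims + 1.00001) / 2.0 * (nbins - 1)).long()
    # bins is now tensor([[[0, 14, 21, 28]]])
    hist = torch.zeros(*sims.shape[:-1], nbins)
    hist.scatter_add_(sims.ndim - 1, bins, mask.float())
    return hist  # hist[0, 0, b] counts the document terms falling in bin b
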
class NormalizedHistogram(CountHistogram):
    def forward(self, simmat, dlens, mask):
        result = super().forward(simmat, dlens, mask)
        BATCH, QLEN, _ = simmat.shape
        # Normalize the counts by the document length; reshaping to
        # (B x 1 x 1) broadcasts over query terms and bins
        return result / dlens.reshape(BATCH, 1, 1)

class LogCountHistogram(CountHistogram):
    def forward(self, simmat, dlens, mask):
        result = super().forward(simmat, dlens, mask)
        # Smoothed log-counts; the 1e-5 avoids log(0) for empty bins
        return (result.float() + 1e-5).log()

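
# An illustrative sketch (not in the original module) of what the two
# subclasses do to a toy count vector: NormalizedHistogram divides counts by
# the document length, while LogCountHistogram takes smoothed log-counts
# (the variant the DRMM paper reports as working best).
def _demo_histogram_variants():
    counts = torch.tensor([5.0, 1.0, 0.0])  # toy counts for one query term
    dlen = 6.0                              # toy document length
    normalized = counts / dlen              # tensor([0.8333, 0.1667, 0.0000])
    log_counts = (counts + 1e-5).log()      # approx. tensor([1.61, 0.00, -11.51])
    return normalized, log_counts
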
class Combination(Config, nn.Module):
    def forward(self, scores: torch.Tensor, idf: torch.Tensor):
        """Combines term scores with IDF

        :param scores: A (B... x Lq) tensor
        :param idf: A (B... x Lq) tensor
        """
        ...

class SumCombination(Combination):
    def forward(self, scores, idf):
        return scores.sum(dim=-1)

class IdfCombination(Combination):
    def forward(self, scores: torch.Tensor, idf: torch.Tensor):
        idf = idf.softmax(dim=-1)
        return (scores * idf).sum(dim=-1)

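
# An illustrative sketch (not in the original module): IdfCombination acts as
# the term gating network of the DRMM paper, turning per-term IDF values into
# softmax weights before aggregating the per-term scores. Toy values below.
def _demo_idf_combination():
    scores = torch.tensor([[0.5, 2.0, -1.0]])  # B=1 x Lq=3 per-term scores
    idf = torch.tensor([[1.0, 3.0, 0.5]])      # B=1 x Lq=3 term IDF values
    weights = idf.softmax(dim=-1)              # rarer terms get larger weights
    return (scores * weights).sum(dim=-1)      # one relevance score per query
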
class Drmm(InteractionScorer):
    """Deep Relevance Matching Model (DRMM)

    Implementation of the DRMM model from:

      Jiafeng Guo, Yixing Fan, Qingyao Ai, and William Bruce Croft. 2016.
      A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM.
    """

    hist: Annotated[CountHistogram, default(LogCountHistogram())]
    """The histogram type"""

    hidden: Param[int] = 5
    """Hidden layer dimension for the feed forward matching network"""

    index: Param[Optional[Index]]
    """The index (only used when using IDF to combine)"""

    combine: Annotated[Combination, default(IdfCombination())]
    """How to combine the query term scores"""

    def __validate__(self):
        super().__validate__()
        # An index is required to compute IDF values when combining with IDF
        assert not isinstance(self.combine, IdfCombination) or (
            self.index is not None
        ), "index must be provided if using IDF"

    def __initialize__(self, options):
        super().__initialize__(options)
        self.simmat = modules.InteractionMatrix(self.encoder.pad_tokenid)
        self.hidden_1 = nn.Linear(self.hist.nbins, self.hidden)
        self.hidden_2 = nn.Linear(self.hidden, 1)
        self.needs_idf = isinstance(self.combine, IdfCombination)

    def _encode(
        self,
        texts: List[str],
        encoder: TokenizedTextEncoderBase[str, TokensEncoderOutput],
        options: TokenizerOptions,
    ) -> SimilarityInputWithTokens:
        encoded = encoder(texts, options=options)

        # Pad the token lists to the maximum length in the batch
        max_len = max(encoded.tokenized.lens)
        padded_tokens = [
            (t + [""] * (max_len - len(t))) for t in encoded.tokenized.tokens
        ]

        return self.similarity.preprocess(
            SimilarityInputWithTokens(
                encoded.value,
                encoded.tokenized.mask,
                np.array(padded_tokens, dtype=str),
            )
        )

    def compute_scores(
        self,
        queries: SimilarityInputWithTokens,
        documents: SimilarityInputWithTokens,
        value: SimilarityOutput,
        info: Optional[TrainerContext] = None,
    ):
        """Compute the scores given the tensor of similarities
        (B x Lq x Ld) or (Bq x Lq x Bd x Ld)"""
        # Computes the IDF if needed
        query_idf = None
        if self.needs_idf:
            assert self.index is not None
            query_idf = torch.full_like(
                queries.mask, float("-inf"), dtype=torch.float
            )
            log_nd = math.log(self.index.documentcount + 1)
            for i, tokens_i in enumerate(queries.tokens):
                for j, t in enumerate(tokens_i):
                    query_idf[i, j] = log_nd - math.log(self.index.term_df(t) + 1)

        mask = value.q_view(queries.mask) * value.d_view(documents.mask)
        similarity = value.similarity
        dlens = torch.LongTensor([len(tokens) for tokens in documents.tokens])

        if value.similarity.ndim == 4:
            # Transform into B... x Lq x Ld shape
            similarity = value.similarity.transpose(1, 2)
            mask = mask.transpose(1, 2)
            if query_idf is not None:
                query_idf = query_idf.unsqueeze(0)

        qterm_features = self.histogram_pool(similarity, dlens, mask)
        qterm_scores = self.hidden_2(
            torch.relu(self.hidden_1(qterm_features))
        ).squeeze(-1)
        return self.combine(qterm_scores, query_idf)

    def histogram_pool(self, simmat, dlens, mask):
        histogram = self.hist(simmat, dlens, mask)
        return histogram
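
# An illustrative sketch (not in the original module) of the feed forward
# matching network used in Drmm.compute_scores: each query term's
# nbins-dimensional histogram is mapped by a small MLP to a scalar score,
# and the per-term scores are then aggregated (here with a plain sum, as in
# SumCombination). Shapes and layer sizes follow the class defaults above.
def _demo_drmm_scoring():
    B, Lq, nbins, hidden = 2, 4, 29, 5
    hidden_1 = nn.Linear(nbins, hidden)
    hidden_2 = nn.Linear(hidden, 1)
    qterm_features = torch.rand(B, Lq, nbins)  # histogram features
    qterm_scores = hidden_2(torch.relu(hidden_1(qterm_features))).squeeze(-1)
    return qterm_scores.sum(dim=-1)            # one score per query-document pair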