Source code for xpmir.neural.interaction.drmm

import math
from typing import Optional
from experimaestro import Config, Param, default
import torch
from torch import nn
from typing_extensions import Annotated
from xpmir.index import Index
from xpmir.neural.interaction import InteractionScorer
import xpmir.neural.modules as modules

# The code below is heavily borrowed from OpenNIR


class CountHistogram(Config, nn.Module):
    """Base histogram class

    Attributes:
        nbins: number of bins in the matching histogram
    """

    nbins: Param[int] = 29

    def forward(self, simmat, dlens, dtoks, qtoks):
        BATCH, CHANNELS, QLEN, DLEN = simmat.shape

        # +1e-5 to nudge scores of 1 to above threshold
        bins = ((simmat + 1.00001) / 2.0 * (self.nbins - 1)).int()
        weights = (
            (dtoks != -1).reshape(BATCH, 1, DLEN).expand(BATCH, QLEN, DLEN)
            * (qtoks != -1).reshape(BATCH, QLEN, 1).expand(BATCH, QLEN, DLEN)
        ).float()

        # Apparently there is no way to batch this...
        # https://discuss.pytorch.org/t/histogram-function-in-pytorch/5350
        # WARNING: this line (and the similar line below) improves performance
        # tenfold when on GPU
        bins, weights = bins.cpu(), weights.cpu()

        histogram = []
        for superbins, w in zip(bins, weights):
            result = []
            for b in superbins:
                result.append(
                    torch.stack(
                        [torch.bincount(q, x, self.nbins) for q, x in zip(b, w)],
                        dim=0,
                    )
                )
            result = torch.stack(result, dim=0)
            histogram.append(result)
        histogram = torch.stack(histogram, dim=0)

        # WARNING: this line (and the similar line above) improves performance
        # tenfold when on GPU
        histogram = histogram.to(simmat.device)
        return histogram
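# Illustrative sketch (not part of the original source): the binning logic
# above, reduced to one query term scored against four document tokens. The
# sizes and values here are made up. Cosine similarities in [-1, 1] map
# linearly to integer bins in [0, nbins - 1], and torch.bincount accumulates
# the mask-weighted counts per bin.
def _binning_sketch():
    nbins = 5
    sims = torch.tensor([-1.0, -0.2, 0.3, 1.0])  # similarities to 4 doc tokens
    mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # last doc position is padding
    bins = ((sims + 1.00001) / 2.0 * (nbins - 1)).int()
    # bins == tensor([0, 1, 2, 4]); the +1e-5 nudges near-1 similarities
    # into the top bin
    return torch.bincount(bins, mask, nbins)  # tensor([1., 1., 1., 0., 0.])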
class NormalizedHistogram(CountHistogram):
    def forward(self, simmat, dlens, dtoks, qtoks):
        result = super().forward(simmat, dlens, dtoks, qtoks)
        # result is (BATCH, CHANNELS, QLEN, BINS): normalize the per-bin
        # counts by the document length
        BATCH = simmat.shape[0]
        return result / dlens.reshape(BATCH, 1, 1, 1)
class LogCountHistogram(CountHistogram):
    def forward(self, simmat, dlens, dtoks, qtoks):
        result = super().forward(simmat, dlens, dtoks, qtoks)
        return (result.float() + 1e-5).log()
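# Illustrative sketch (hypothetical numbers): the effect of the log transform
# on raw bin counts. It compresses large counts and maps empty bins to a large
# negative value rather than -inf, which is what the +1e-5 above is for.
def _log_count_sketch():
    counts = torch.tensor([7.0, 1.0, 0.0])
    return (counts + 1e-5).log()  # ~ tensor([1.9459, 0.0000, -11.5129])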
class Combination(Config, nn.Module):
    pass
class SumCombination(Combination):
    def forward(self, scores, idf):
        return scores.sum(dim=1)
class IdfCombination(Combination):
    def forward(self, scores, idf):
        idf = idf.softmax(dim=1)
        return (scores * idf).sum(dim=1)
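# Illustrative sketch (hypothetical numbers): how the two combination
# strategies aggregate per-query-term scores. SumCombination ignores its idf
# argument, while IdfCombination softmax-normalizes the IDF values into
# weights so that rarer (higher-IDF) terms dominate the document score.
def _combination_sketch():
    scores = torch.tensor([[0.8, 0.2]])  # (BATCH=1, QLEN=2) per-term scores
    idf = torch.tensor([[3.0, 1.0]])  # first term is rarer in the collection
    summed = scores.sum(dim=1)  # tensor([1.0000])
    weights = idf.softmax(dim=1)  # ~ tensor([[0.8808, 0.1192]])
    weighted = (scores * weights).sum(dim=1)  # ~ tensor([0.7285])
    return summed, weighted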
class Drmm(InteractionScorer):
    """Deep Relevance Matching Model (DRMM)

    Implementation of the DRMM model from: Jiafeng Guo, Yixing Fan, Qingyao
    Ai, and William Bruce Croft. 2016. A Deep Relevance Matching Model for
    Ad-hoc Retrieval. In CIKM.
    """

    hist: Annotated[CountHistogram, default(LogCountHistogram())]
    """The histogram type"""

    hidden: Param[int] = 5
    """Hidden layer dimension for the feed-forward matching network"""

    index: Param[Optional[Index]]
    """The index (only used when combining scores with IDF)"""

    combine: Annotated[Combination, default(IdfCombination())]
    """How to combine the query term scores"""

    def __validate__(self):
        super().__validate__()
        assert not isinstance(self.combine, IdfCombination) or (
            self.index is not None
        ), "index must be provided when combining with IDF"

    def __initialize__(self, options):
        if not self.vocab.static():
            self.logger.warning(
                "In most cases, using vocab.train=True will not have an effect on DRMM "
                "because the histogram is not differentiable. An exception might be if "
                "the gradient is propagated back by another means, e.g. the BERT [CLS] token."
            )

        super().__initialize__(options)
        self.simmat = modules.InteractionMatrix(self.vocab.pad_tokenid)

        channels = self.vocab.emb_views()
        self.hidden_1 = nn.Linear(self.hist.nbins * channels, self.hidden)
        self.hidden_2 = nn.Linear(self.hidden, 1)
        self.needs_idf = isinstance(self.combine, IdfCombination)

    def _forward(self, inputs, info):
        simmat, tokq, tokd = self.simmat.encode_query_doc(
            self.vocab, inputs, d_maxlen=self.dlen, q_maxlen=self.qlen
        )

        # Compute the IDF if needed
        query_idf = None
        if self.needs_idf:
            assert self.index is not None
            query_idf = torch.full_like(tokq.ids, float("-inf"), dtype=torch.float)
            log_nd = math.log(self.index.documentcount + 1)
            for i, tok in enumerate(tokq.tokens):
                for j, t in zip(range(self.qlen), tok):
                    query_idf[i, j] = log_nd - math.log(self.index.term_df(t) + 1)

        qterm_features = self.histogram_pool(simmat, tokq, tokd)
        BAT, QLEN, _ = qterm_features.shape
        qterm_scores = self.hidden_2(
            torch.relu(self.hidden_1(qterm_features))
        ).reshape(BAT, QLEN)
        return self.combine(qterm_scores, query_idf)

    def histogram_pool(self, simmat, tokq, tokd):
        histogram = self.hist(simmat, tokd.lens, tokd.ids, tokq.ids)
        BATCH, CHANNELS, QLEN, BINS = histogram.shape
        histogram = histogram.permute(0, 2, 3, 1)
        histogram = histogram.reshape(BATCH, QLEN, BINS * CHANNELS)
        return histogram
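# Illustrative sketch (not part of the original source): the tensor shapes as
# they flow through _forward, with made-up sizes. Each query term's histogram
# features are scored independently by the two-layer MLP, and the per-term
# scores are then combined into a single relevance score per (query, document)
# pair. The IDF values here are random placeholders; the real ones come from
# the index as log(N + 1) - log(df + 1).
def _drmm_shapes_sketch(nbins=29, channels=1, hidden=5):
    BATCH, QLEN = 2, 4
    features = torch.rand(BATCH, QLEN, nbins * channels)  # histogram_pool output
    hidden_1 = nn.Linear(nbins * channels, hidden)
    hidden_2 = nn.Linear(hidden, 1)
    scores = hidden_2(torch.relu(hidden_1(features))).reshape(BATCH, QLEN)
    # e.g. with N = 1000 documents and df = 9: idf = log(1001) - log(10) ~= 4.61
    idf = torch.rand(BATCH, QLEN)
    return (scores * idf.softmax(dim=1)).sum(dim=1)  # shape (BATCH,)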