Source code for xpmir.neural.common

import torch
from experimaestro import Config


[docs]class Similarity(Config): """A similarity between vector representations""" def __call__(self, queries: torch.Tensor, documents: torch.Tensor) -> torch.Tensor: raise NotImplementedError()
[docs]class L2Distance(Similarity): def __call__(self, queries, documents): return ( (-1.0 * ((queries.unsqueeze(2) - documents.unsqueeze(1)) ** 2).sum(-1)) .max(-1) .values.sum(-1) )
[docs]class CosineSimilarity(Similarity): def __call__(self, queries, documents): return (queries @ documents.permute(0, 2, 1)).max(2).values.sum(1)