Source code for xpmir.neural.splade

from typing import List, Optional, Generic
from experimaestro import Config, Param
import torch.nn as nn
import torch
from xpmir.learning import ModuleInitOptions
from xpmir.distributed import DistributableModel
from xpmir.text.huggingface import (
    OneHotHuggingFaceEncoder,
    TransformerTokensEncoderWithMLMOutput,
)
from xpmir.text import TokenizerOptions
from xpmir.text.huggingface import HFTokenizerBase
from xpmir.text.encoders import (
    TextEncoder,
    TextEncoderBase,
    InputType as EncoderInputType,
    TextsRepresentationOutput,
)
from xpmir.neural.dual import DotDense, ScheduledFlopsRegularizer
from xpmir.text.huggingface.base import HFMaskedLanguageModel
from xpmir.utils.utils import easylog

logger = easylog()


class Aggregation(Config):
    """The aggregation function for Splade"""

    def get_output_module(self, linear: nn.Module):
        return AggregationModule(linear, self)


class MaxAggregation(Aggregation):
    """Aggregate using a max"""

    def __call__(self, logits, mask):
        # Get the maximum (masking the values)
        values, _ = torch.max(
            torch.relu(logits) * mask.to(logits.device).unsqueeze(-1),
            dim=1,
        )

        # Computes log(1+x)
        return torch.log1p(values.clamp(min=0))


class SumAggregation(Aggregation):
    """Aggregate using a sum"""

    def __call__(self, logits, mask):
        return torch.sum(
            torch.log1p(torch.relu(logits) * mask.to(logits.device).unsqueeze(-1)),
            dim=1,
        )


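# Illustrative sketch (not part of the original module): how the two
# aggregations above turn masked MLM logits into a single sparse lexical
# vector per text. The tensors are toy values; the vocabulary size 30522
# matches distilbert-base-uncased, and the aggregation configs are called
# directly, as SpladeTextEncoderModel does below.
#
#     logits = torch.randn(2, 4, 30522)       # (batch, tokens, vocab) MLM logits
#     mask = torch.ones(2, 4)                 # attention mask from the tokenizer
#     w_max = MaxAggregation()(logits, mask)  # log1p(max_i relu(w_ij)) -> (2, 30522)
#     w_sum = SumAggregation()(logits, mask)  # sum_i log1p(relu(w_ij)) -> (2, 30522)

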
class AggregationModule(nn.Module):
    def __init__(self, linear: nn.Linear, aggregation: Aggregation):
        super().__init__()
        self.linear = linear
        self.aggregation = aggregation

    def forward(self, input: torch.Tensor, mask: torch.Tensor):
        return self.aggregation(self.linear(input), mask)


class SpladeTextEncoderModel(nn.Module):
    def __init__(
        self, encoder: TransformerTokensEncoderWithMLMOutput, aggregation: Aggregation
    ):
        super().__init__()
        self.encoder = encoder
        self.aggregation = aggregation

    def forward(self, tokenized):
        # We keep all the outputs so that we can access the MLM head output.
        # The automodel is not the plain AutoModel, so its output has a
        # `logits` attribute (the w_ij in the paper), of shape
        # (batch size, len(texts), vocab_size)
        out = self.encoder(tokenized, all_outputs=True)
        out = self.aggregation(out.logits, tokenized.mask)
        return out


class SpladeTextEncoder(TextEncoder, DistributableModel):
    """Splade model

    It is only a text encoder since we use `xpmir.neural.dual.DotDense`
    as the scorer class
    """

    encoder: Param[TransformerTokensEncoderWithMLMOutput]
    """The encoder from Hugging Face"""

    aggregation: Param[Aggregation]
    """How to aggregate the vectors"""

    maxlen: Param[Optional[int]] = None
    """Max length for texts"""

    def __initialize__(self, options: ModuleInitOptions):
        self.encoder.initialize(options)
        self.model = SpladeTextEncoderModel(self.encoder, self.aggregation)

    def forward(self, texts: List[str]) -> torch.Tensor:
        """Returns a batch x vocab tensor"""
        tokenized = self.encoder.batch_tokenize(texts, mask=True, maxlen=self.maxlen)
        out = self.model(tokenized)
        return out

    @property
    def dimension(self):
        return self.encoder.model.config.vocab_size

    def static(self):
        return False

    def distribute_models(self, update):
        self.model = update(self.model)


class SpladeTextEncoderV2(
    TextEncoderBase[EncoderInputType, TextsRepresentationOutput],
    DistributableModel,
    Generic[EncoderInputType],
):
    """Splade model text encoder (V2)

    It is only a text encoder since we use `xpmir.neural.dual.DotDense`
    as the scorer class. Compared to V1, it uses the new text HF encoder
    abstractions.
    """

    # TODO: use "SpladeTextEncoder" identifier until
    # https://github.com/experimaestro/experimaestro-python/issues/56 is fixed
    __xpmid__ = str(SpladeTextEncoder.__getxpmtype__().identifier)

    tokenizer: Param[HFTokenizerBase[EncoderInputType]]
    """The tokenizer from Hugging Face"""

    encoder: Param[HFMaskedLanguageModel]
    """The encoder from Hugging Face"""

    aggregation: Param[Aggregation]
    """How to aggregate the vectors"""

    maxlen: Param[Optional[int]] = None
    """Max length for texts"""

    def __initialize__(self, options: ModuleInitOptions):
        self.encoder.initialize(options)
        self.tokenizer.initialize(options)

        # Add the aggregation head right away: this allows optimizations,
        # e.g. for the max aggregation method
        output_embeddings = self.encoder.model.get_output_embeddings()
        assert isinstance(
            output_embeddings, nn.Linear
        ), f"Cannot handle output embeddings of class {type(output_embeddings)}"
        self.encoder.model.set_output_embeddings(nn.Identity())
        self.aggregation = self.aggregation.get_output_module(output_embeddings)

    def forward(self, texts: EncoderInputType) -> TextsRepresentationOutput:
        """Returns a batch x vocab tensor"""
        tokenized = self.tokenizer.tokenize(
            texts, options=TokenizerOptions(self.maxlen)
        )
        value = self.aggregation(self.encoder(tokenized).logits, tokenized.mask)
        return TextsRepresentationOutput(value, tokenized)

    @property
    def dimension(self):
        return self.encoder.model.config.vocab_size

    def static(self):
        return False

    def distribute_models(self, update):
        self.encoder = update(self.encoder)


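# Illustrative sketch (not part of the original module): the head rewiring
# done in __initialize__ above. The MLM output embedding (an nn.Linear) is
# detached from the encoder and folded into an AggregationModule, so the
# vocabulary projection and the aggregation run as a single module.
# `vocab_projection` and the shapes below are hypothetical stand-ins.
#
#     hidden = torch.randn(2, 4, 768)           # encoder output (MLM head replaced by Identity)
#     mask = torch.ones(2, 4)
#     vocab_projection = nn.Linear(768, 30522)  # plays the role of get_output_embeddings()
#     head = MaxAggregation().get_output_module(vocab_projection)
#     sparse = head(hidden, mask)               # same as MaxAggregation()(vocab_projection(hidden), mask)

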
def _splade(
    lambda_q: float,
    lambda_d: float,
    aggregation: Aggregation,
    lambda_warmup_steps: int = 0,
    hf_id: str = "distilbert-base-uncased",
):
    # Unlike the cross-encoder, here the encoder returns the whole last layer.
    # In the paper, a DistilBERT checkpoint is used.
    encoder = TransformerTokensEncoderWithMLMOutput(model_id=hf_id, trainable=True)

    # Use the BERT output and aggregate it over the sequence
    doc_encoder = SpladeTextEncoder(
        aggregation=aggregation, encoder=encoder, maxlen=200
    )
    query_encoder = SpladeTextEncoder(
        aggregation=aggregation, encoder=encoder, maxlen=30
    )

    return DotDense(
        encoder=doc_encoder, query_encoder=query_encoder
    ), ScheduledFlopsRegularizer(
        lambda_q=lambda_q,
        lambda_d=lambda_d,
        lambda_warmup_steps=lambda_warmup_steps,
    )


def _splade_doc(
    lambda_q: float,
    lambda_d: float,
    aggregation: Aggregation,
    lambda_warmup_steps: int = 0,
    hf_id: str = "distilbert-base-uncased",
):
    # Unlike the cross-encoder, here the encoder returns the whole last layer.
    # The document encoder is the traditional one, while the query encoder
    # returns a vector containing only 0s and 1s.
    # In the paper, a DistilBERT checkpoint is used.
    encoder = TransformerTokensEncoderWithMLMOutput(model_id=hf_id, trainable=True)

    # Use the BERT output and aggregate it over the sequence
    doc_encoder = SpladeTextEncoder(
        aggregation=aggregation, encoder=encoder, maxlen=256
    )
    query_encoder = OneHotHuggingFaceEncoder(model_id=hf_id, maxlen=30)

    return DotDense(
        encoder=doc_encoder, query_encoder=query_encoder
    ), ScheduledFlopsRegularizer(
        lambda_q=lambda_q,
        lambda_d=lambda_d,
        lambda_warmup_steps=lambda_warmup_steps,
    )


def spladeV1(
    lambda_q: float,
    lambda_d: float,
    lambda_warmup_steps: int = 0,
    hf_id: str = "distilbert-base-uncased",
):
    """Returns the Splade architecture"""
    return _splade(lambda_q, lambda_d, SumAggregation(), lambda_warmup_steps, hf_id)


def spladeV2_max(
    lambda_q: float,
    lambda_d: float,
    lambda_warmup_steps: int = 0,
    hf_id: str = "distilbert-base-uncased",
):
    """Returns the Splade-max architecture

    SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval
    (arXiv:2109.10086)
    """
    return _splade(lambda_q, lambda_d, MaxAggregation(), lambda_warmup_steps, hf_id)


def spladeV2_doc(
    lambda_q: float,
    lambda_d: float,
    lambda_warmup_steps: int = 0,
    hf_id: str = "distilbert-base-uncased",
):
    """Returns the Splade-doc architecture

    SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval
    (arXiv:2109.10086)
    """
    return _splade_doc(
        lambda_q, lambda_d, MaxAggregation(), lambda_warmup_steps, hf_id
    )


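# Illustrative sketch (not part of the original module): the factory helpers
# above return a (scorer, regularizer) pair. The lambda values below are
# arbitrary example values, not recommended settings.
#
#     scorer, flops_regularizer = spladeV2_max(
#         lambda_q=3e-4, lambda_d=1e-4, lambda_warmup_steps=50_000
#     )
#     # `scorer` is a DotDense model computing a dot product between the
#     # sparse query and document representations; `flops_regularizer` is a
#     # ScheduledFlopsRegularizer to be added to the training loss.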