Source code for xpmir.neural.colbert

"""ColBERT: late interaction over per-token BERT embeddings.

This module implements a ColBERT-style dual scorer using the
:class:`~xpmir.text.encoders.TokensRepresentationOutput` abstraction. The query
and document encoders return one vector per input token; scoring is done using
the late-interaction "MaxSim" operator: for each query token we take the
maximum similarity against all document tokens, then sum over the query
tokens.

Reference: Khattab & Zaharia, "ColBERT: Efficient and Effective Passage Search
via Contextualized Late Interaction over BERT" (SIGIR 2020).
"""

from pathlib import Path
from typing import Iterable, List, Optional, Tuple
from attrs import evolve

import torch
import torch.nn as nn
from experimaestro import Param, field, LightweightTask
from datamaestro_ir.data import IDTextRecord
from xpm_torch.learner import TrainerContext

from xpmir.letor.records import BaseItems, ProductItems
from xpmir.neural import DocsRep, QueriesRep
from xpmir.neural.dual import DualVectorScorer
from xpmir.rankers.scorer import AbstractModuleScorer
from xpmir.text.encoders import TokensRepresentationOutput
from xpmir.text.huggingface.tokenizers import get_default_max_len  # noqa: F401
from xpmir.text.tokenizers import TokenizedTexts, TokenizerOptions


# PyLate is treated as an optional dependency: when the import fails for any
# reason the module still loads and ``models`` is bound to ``None``.
# PylateColBERT.__initialize__ performs its own import and raises an
# ImportError with an explicit message if PyLate is unavailable.
try:
    from pylate import models
except Exception:  # deliberately broad: ImportError, or pylate failing while being imported
    models = None

import logging

# NOTE(review): calling ``basicConfig`` at import time configures the *root*
# logger of whatever application imports this module. Library modules usually
# only create ``logging.getLogger(__name__)`` and leave configuration to the
# application — confirm whether this side effect is intended.
logging.basicConfig(level=logging.INFO)


class ColBERTEncoder(
    DualVectorScorer[TokensRepresentationOutput, TokensRepresentationOutput]
):
    """ColBERT-style dual scorer with late interaction MaxSim.

    The document (and optional query) encoder must return
    :class:`~xpmir.text.encoders.TokensRepresentationOutput`, i.e. a
    ``(batch, max_tokens, hidden_dim)`` tensor together with the tokenized
    inputs (providing the attention mask). A trainable linear projection
    reduces the per-token vectors to ``dim`` and the vectors are L2-normalised
    so the dot product amounts to a cosine similarity.

    The :attr:`encoder` (and :attr:`query_encoder`) inherited from
    :class:`~xpmir.neural.dual.DualVectorScorer` must in practice be a
    :class:`~xpmir.text.encoders.TokenizedTextEncoder` returning
    :class:`~xpmir.text.encoders.TokensRepresentationOutput`. The
    :class:`~xpmir.text.encoders.TokenizedTextEncoder` exposes the
    ``tokenize`` / ``forward_tokenized`` split that query augmentation needs.
    """

    dim: Param[int] = field(default=128, ignore_default=True)
    """Output dimension of the per-token projection."""

    query_maxlen: Param[int] = field(default=32, ignore_default=True)
    """Maximum number of tokens kept for a query."""

    doc_maxlen: Param[int] = field(default=180, ignore_default=True)
    """Maximum number of tokens kept for a document."""

    query_augmentation: Param[bool] = field(default=True)
    """Whether to apply ColBERT's query augmentation: queries shorter than
    ``query_maxlen`` are right-padded with ``[MASK]`` tokens (instead of
    ``[PAD]``) and every position participates in MaxSim. This mirrors the
    original ColBERT implementation. Disable to use plain padded queries with
    padding excluded from MaxSim."""

    def __initialize__(self):
        """Create the projection layer and, when query augmentation is
        enabled, resolve the ``[MASK]`` token id of the query encoder."""
        super().__initialize__()
        hidden = self.encoder.dimension
        self._projection = nn.Linear(hidden, self.dim, bias=False)
        if self.query_augmentation:
            self._mask_token_id = self._lookup_mask_token_id(self._query_encoder)

    @staticmethod
    def _lookup_mask_token_id(encoder) -> int:
        """Locate the underlying HF tokenizer's ``mask_token_id`` by walking
        nested ``tokenizer`` attributes.

        :param encoder: object whose ``tokenizer`` attribute chain is explored
        :returns: the first non-``None`` ``mask_token_id`` found
        :raises ValueError: if no object in the chain exposes one
        """
        obj = encoder
        # Track visited objects so a self-referencing ``tokenizer`` attribute
        # cannot make the walk loop forever.
        seen = set()
        while obj is not None and id(obj) not in seen:
            seen.add(id(obj))
            mask_id = getattr(obj, "mask_token_id", None)
            if mask_id is not None:
                return mask_id
            obj = getattr(obj, "tokenizer", None)
        raise ValueError(
            "Could not locate mask_token_id on the query encoder; set "
            "query_augmentation=False or use an encoder backed by an HF "
            "tokenizer exposing [MASK]."
        )

    @property
    def dimension(self) -> int:
        """Projection dimension (returned per token)."""
        return self.dim

    # ------------------------------------------------------------------ utils

    def _project(
        self, output: TokensRepresentationOutput
    ) -> TokensRepresentationOutput:
        """Project the per-token vectors to ``dim`` and L2-normalise them.

        Returns a copy of ``output`` whose ``value`` is replaced; the other
        fields (including the tokenized inputs) are carried over unchanged.
        """
        value = self._projection(output.value)
        value = torch.nn.functional.normalize(value, p=2, dim=-1)
        return evolve(output, value=value)

    @staticmethod
    def _token_mask(output: TokensRepresentationOutput) -> Optional[torch.Tensor]:
        """Return the attention mask as a boolean tensor on the same device as
        the embeddings, or ``None`` when the tokenized input has no mask."""
        mask = output.tokenized.mask
        if mask is None:
            return None
        return mask.to(output.value.device).bool()

    def document_token_embeddings(
        self, records: List[IDTextRecord]
    ) -> List[torch.Tensor]:
        """Encode a batch of documents and return the list of per-token
        embeddings, one tensor ``(num_tokens, dim)`` per document.

        Padding positions are filtered out.
        """
        output = self.encode_documents(records)
        mask = self._token_mask(output)
        if mask is None:
            # No mask available: every position is kept.
            return list(output.value.unbind(0))
        return [vectors[keep] for vectors, keep in zip(output.value, mask)]

    def query_token_embeddings(self, records: List[IDTextRecord]) -> torch.Tensor:
        """Encode a batch of queries and return a dense
        ``(batch, query_maxlen, dim)`` tensor suitable for fast-plaid search.
        """
        return self.encode_queries(records).value

    # ------------------------------------------------------------- encoding

    def encode_queries(self, records: List[IDTextRecord]) -> TokensRepresentationOutput:
        """Encode queries into projected, L2-normalised per-token vectors.

        With :attr:`query_augmentation` the queries are tokenized first, padded
        with ``[MASK]`` (see :meth:`_mask_pad_query`) and only then passed
        through the encoder; otherwise the query encoder is called directly.
        """
        options = TokenizerOptions(max_length=self.query_maxlen)
        if self.query_augmentation:
            tokenized = self._query_encoder.tokenize(records, options=options)
            tokenized = self._mask_pad_query(tokenized)
            output = self._query_encoder.forward_tokenized(tokenized)
        else:
            output = self._query_encoder(records, options=options)
        return self._project(output)

    def _mask_pad_query(self, tokenized: TokenizedTexts) -> TokenizedTexts:
        """Apply ColBERT query augmentation: right-pad to ``query_maxlen``,
        replace every padding position with ``[MASK]`` and set the attention
        mask to 1 over those positions so they participate in MaxSim.

        :param tokenized: the tokenized queries (``ids`` of shape
            ``(batch, length)``; ``mask`` and ``token_type_ids`` may be
            ``None``)
        :returns: a new :class:`TokenizedTexts`; when the input carried no
            attention mask, none is returned either
        """
        ids = tokenized.ids
        mask = tokenized.mask
        token_type_ids = tokenized.token_type_ids
        batch_size, current_len = ids.shape

        pad_len = self.query_maxlen - current_len
        if pad_len > 0:
            # Pad directly with [MASK] so that the new positions are correct
            # even when the tokenizer did not provide an attention mask
            # (padding with 0 and relying on the mask left them as id 0).
            id_pad = torch.full(
                (batch_size, pad_len),
                self._mask_token_id,
                dtype=ids.dtype,
                device=ids.device,
            )
            ids = torch.cat([ids, id_pad], dim=1)
            if mask is not None:
                mask_pad = torch.zeros(
                    (batch_size, pad_len), dtype=mask.dtype, device=mask.device
                )
                mask = torch.cat([mask, mask_pad], dim=1)
            if token_type_ids is not None:
                tt_pad = torch.zeros(
                    (batch_size, pad_len),
                    dtype=token_type_ids.dtype,
                    device=token_type_ids.device,
                )
                token_type_ids = torch.cat([token_type_ids, tt_pad], dim=1)

        if mask is not None:
            # Padding introduced by the tokenizer itself also becomes [MASK],
            # and every position is then marked as attended.
            ids = ids.masked_fill(mask == 0, self._mask_token_id)
            mask = torch.ones_like(mask)

        return TokenizedTexts(
            tokens=tokenized.tokens,
            ids=ids,
            # Report the real width of ``ids``: it equals ``query_maxlen``
            # whenever the input was not longer than that.
            lens=[ids.shape[1]] * batch_size,
            mask=mask,
            token_type_ids=token_type_ids,
        )

    def encode_documents(
        self, records: List[IDTextRecord]
    ) -> TokensRepresentationOutput:
        """Encode documents into projected, L2-normalised per-token vectors,
        asking the tokenizer for at most ``doc_maxlen`` tokens."""
        options = TokenizerOptions(max_length=self.doc_maxlen)
        output = self.encoder(records, options=options)
        return self._project(output)

    # --------------------------------------------------------------- scoring

    def _max_sim(
        self,
        queries: TokensRepresentationOutput,
        documents: TokensRepresentationOutput,
        all_pairs: bool,
    ) -> torch.Tensor:
        """Compute the MaxSim operator.

        When ``all_pairs`` is True, returns an ``(Nq, Nd)`` matrix of scores
        between every query and every document; otherwise returns a vector of
        ``Nq == Nd`` scores for the aligned query/document pairs.
        """
        q = queries.value
        d = documents.value
        doc_mask = self._token_mask(documents)
        query_mask = self._token_mask(queries)
        # Smallest finite value of the dtype: masked document tokens can then
        # never win the max.
        lowest = torch.finfo(q.dtype).min

        if all_pairs:
            # scores: (Nq, Lq, Nd, Ld)
            scores = torch.einsum("qmd,nkd->qmnk", q, d)
            if doc_mask is not None:
                scores = scores.masked_fill(
                    ~doc_mask.unsqueeze(0).unsqueeze(0), lowest
                )
            # max over doc tokens -> (Nq, Lq, Nd)
            max_scores = scores.max(dim=-1).values
            if query_mask is not None:
                # Masked query tokens contribute 0 to the sum.
                max_scores = max_scores * query_mask.unsqueeze(-1).to(
                    max_scores.dtype
                )
            return max_scores.sum(dim=1)

        # Paired scoring: expects Nq == Nd
        scores = torch.einsum("nmd,nkd->nmk", q, d)
        if doc_mask is not None:
            scores = scores.masked_fill(~doc_mask.unsqueeze(1), lowest)
        max_scores = scores.max(dim=-1).values  # (N, Lq)
        if query_mask is not None:
            max_scores = max_scores * query_mask.to(max_scores.dtype)
        return max_scores.sum(dim=-1)

    def score_product(
        self,
        queries: TokensRepresentationOutput,
        documents: TokensRepresentationOutput,
        info: Optional[TrainerContext] = None,
    ) -> torch.Tensor:
        """Score every query against every document (``info`` is unused)."""
        return self._max_sim(queries, documents, all_pairs=True)

    def score_pairs(
        self,
        queries: TokensRepresentationOutput,
        documents: TokensRepresentationOutput,
        info: Optional[TrainerContext] = None,
    ) -> torch.Tensor:
        """Score aligned query/document pairs (``info`` is unused)."""
        return self._max_sim(queries, documents, all_pairs=False)

    # ------------------------------------------------------ (de)serialisation

    def save_model(self, path: Path):
        """Save the encoder(s) and the projection weights under ``path``.

        The query encoder is only saved separately when it is set and is a
        different object from the document encoder.
        """
        path.mkdir(parents=True, exist_ok=True)
        self.encoder.save_model(path / "encoder")
        if self.query_encoder is not None and self.query_encoder is not self.encoder:
            self._query_encoder.save_model(path / "query_encoder")
        torch.save(self._projection.state_dict(), path / "projection.pth")

    def load_model(self, path: Path):
        """Load whatever :meth:`save_model` stored under ``path``.

        Each component is restored only if its file/directory exists; missing
        components are left untouched.
        """
        if (path / "encoder").exists():
            self.encoder.load_model(path / "encoder")
        if (path / "query_encoder").exists() and self.query_encoder is not None:
            self._query_encoder.load_model(path / "query_encoder")
        proj_path = path / "projection.pth"
        if proj_path.exists():
            # The file only holds a tensor state dict, so restrict unpickling
            # to tensors/containers instead of arbitrary Python objects.
            state = torch.load(proj_path, map_location="cpu", weights_only=True)
            self._projection.load_state_dict(state)
class PylateColBERT(AbstractModuleScorer):
    """Interface with Pylate to use a ColBERT model as a scorer.

    .. warning::

        This class isn't working as of right now. It needs specific changes to
        the toml file to accommodate pylate requirements.
    """

    model_id: Param[str]
    """The HuggingFace model ID or path."""

    dim: Param[int] = field(default=128, ignore_default=True)
    """Output dimension of the per-token projection."""

    query_maxlen: Param[int] = field(default=32, ignore_default=True)
    """Maximum number of tokens kept for a query."""

    doc_maxlen: Param[int] = field(default=180, ignore_default=True)
    """Maximum number of tokens kept for a document."""

    def __initialize__(self):
        """Instantiate and compile the PyLate ColBERT model.

        :raises ImportError: if ``pylate`` cannot be imported (the original
            failure is chained as the cause)
        """
        super().__initialize__()
        try:
            from pylate import models
        except Exception as err:  # ImportError, or pylate failing at import time
            raise ImportError(
                "Pylate is not available. Please install pylate to use PylateColBERT."
            ) from err

        self.pl_model = models.ColBERT(
            self.model_id,
            document_length=self.doc_maxlen,
            query_length=self.query_maxlen,
            embedding_size=self.dim,
        )
        self.pl_model.compile()
        self._initialized = True

    def _encode_texts(self, records: Iterable[IDTextRecord], is_query: bool):
        """Extract the text of each record and run the PyLate encoder on it,
        asking for normalised embeddings returned as tensors."""
        return self.pl_model.encode(
            [record["text_item"].text for record in records],
            normalize_embeddings=True,
            convert_to_tensor=True,
            is_query=is_query,
        )

    def _ensure_tensor_batch(self, representations: object) -> torch.Tensor:
        """Turn the encoder output into a single batched tensor.

        A tensor is returned as is. A list/tuple of tensors is stacked when
        all shapes agree; otherwise (variable number of tokens) the sequences
        are right-padded with zeros to the longest one.

        :raises TypeError: for any other input type
        """
        if isinstance(representations, torch.Tensor):
            return representations
        if isinstance(representations, (list, tuple)):
            tensors = list(representations)
            if len({tensor.shape for tensor in tensors}) <= 1:
                return torch.stack(tensors)
            from torch.nn.utils.rnn import pad_sequence

            # NOTE(review): zero padding vectors have similarity 0 with any
            # query token; confirm this is acceptable for MaxSim when real
            # similarities can be negative.
            return pad_sequence(tensors, batch_first=True)
        raise TypeError(
            "Expected a torch.Tensor or list[torch.Tensor] from the Pylate model"
        )

    @property
    def dimension(self) -> int:
        """Projection dimension (returned per token)."""
        return self.dim

    def document_token_embeddings(
        self, records: List[IDTextRecord]
    ) -> List[torch.Tensor]:
        """Encode a batch of documents and return the list of per-token
        embeddings, one tensor per document.

        The record texts are extracted the same way as in
        :meth:`encode_documents`; no padding is added here.
        """
        representations = self._encode_texts(records, is_query=False)
        if isinstance(representations, torch.Tensor):
            return list(representations.unbind(0))
        return list(representations)

    def query_token_embeddings(self, records: List[IDTextRecord]) -> torch.Tensor:
        """Encode a batch of queries and return a dense batched tensor
        suitable for fast-plaid search.
        """
        return self.encode_queries(records)

    def encode_documents(self, records: Iterable[IDTextRecord]) -> DocsRep:
        """Encode the text of each document record and return the embeddings
        as one batched tensor."""
        return self._ensure_tensor_batch(self._encode_texts(records, is_query=False))

    def encode_queries(self, records: Iterable[IDTextRecord]) -> QueriesRep:
        """Encode the text of each query record and return the embeddings as
        one batched tensor."""
        return self._ensure_tensor_batch(self._encode_texts(records, is_query=True))

    # --------------------------------------------------------------- scoring

    def _max_sim(
        self,
        queries: torch.Tensor,
        documents: torch.Tensor,
        all_pairs: bool,
    ) -> torch.Tensor:
        """Compute the MaxSim operator.

        When ``all_pairs`` is True, returns an ``(Nq, Nd)`` matrix of scores
        between every query and every document; otherwise returns a vector of
        ``Nq == Nd`` scores for the aligned query/document pairs.
        """
        # ``similarity`` scores every query against every document, while
        # ``similarity_pairwise`` scores aligned pairs only (Sentence
        # Transformers convention, which the PyLate model follows).
        if all_pairs:
            return self.pl_model.similarity(queries, documents)
        return self.pl_model.similarity_pairwise(queries, documents)

    def score_product(
        self,
        queries: torch.Tensor,
        documents: torch.Tensor,
        info: Optional[TrainerContext] = None,
    ) -> torch.Tensor:
        """Score every query against every document (``info`` is unused)."""
        return self._max_sim(queries, documents, all_pairs=True)

    def score_pairs(
        self,
        queries: torch.Tensor,
        documents: torch.Tensor,
        info: Optional[TrainerContext] = None,
    ) -> torch.Tensor:
        """Score aligned query/document pairs (``info`` is unused)."""
        return self._max_sim(queries, documents, all_pairs=False)

    def forward(
        self, inputs: BaseItems, info: Optional[TrainerContext] = None, **kwargs
    ):
        """Score the items of ``inputs`` and return a flat tensor of scores.

        Unique queries and documents are encoded once; a
        :class:`ProductItems` input is scored as a full product, any other
        input through its ``pairs()`` index lists.
        """
        enc_queries = self.encode_queries(list(inputs.unique_queries))
        enc_documents = self.encode_documents(list(inputs.unique_documents))

        if isinstance(inputs, ProductItems):
            return self.score_product(enc_queries, enc_documents, info).flatten()

        # Gather the encoded queries/documents of each requested pair.
        q_ix, d_ix = inputs.pairs()
        return self.score_pairs(
            enc_queries[q_ix],
            enc_documents[d_ix],
            info,
        ).flatten()

    # ------------------------------------------------------ (de)serialisation

    def save_model(self, path: Path):
        """Create ``path`` and ask the PyLate model to save itself under
        ``path / "model.pth"``."""
        path.mkdir(parents=True, exist_ok=True)
        self.pl_model.save(path / "model.pth")

    def load_model(self, path: Path):
        """Reload the PyLate model from ``path / "model.pth"`` if it exists."""
        if (path / "model.pth").exists():
            # NOTE(review): confirm that ``load`` restores the weights in
            # place; its return value is not used here.
            self.pl_model.load(path / "model.pth")
class InitPylateColBERT(LightweightTask):
    """Lightweight task that initializes the scorer it is given.

    Running the task calls ``initialize()`` on :attr:`model`.
    """

    model: Param[AbstractModuleScorer]
    """The scorer to initialize."""

    def execute(self):
        """Trigger the initialization of the wrapped scorer."""
        scorer = self.model
        scorer.initialize()
def pylate_colbert(
    model_id: str, document_length: int, query_length: int, embedding_size: int
) -> Tuple[PylateColBERT, List[LightweightTask]]:
    """Build a PylateColBERT scorer together with its initialization task.

    The scorer configuration is tagged with ``model_type=colbert``.

    :param model_id: The HuggingFace model ID
    :param document_length: The maximum length of documents
    :param query_length: The maximum length of queries
    :param embedding_size: The size of the embedding vectors
    :returns: (PylateColBERT, init_tasks)
    """
    # Map the factory arguments onto the scorer's parameter names.
    settings = {
        "model_id": model_id,
        "doc_maxlen": document_length,
        "query_maxlen": query_length,
        "dim": embedding_size,
    }
    scorer = PylateColBERT.C(**settings).tag("model_type", "colbert")
    init_tasks = [InitPylateColBERT.C(model=scorer)]
    return scorer, init_tasks