"""Interface to the fast-plaid library.
`fast-plaid <https://github.com/lightonai/fast-plaid>`_ is a Rust-based
implementation of PLAID / ColBERT late-interaction retrieval. This module
wraps it to build and query an index from a
:class:`~xpmir.neural.colbert.ColBERTEncoder`.
Three classes are exposed:
- :class:`PlaidIndex` — the index configuration (paths, metadata). Supports
retrieving per-document token vectors via :meth:`PlaidIndex.get_document_tokens`
using fast-plaid's compressed centroid+residual storage.
- :class:`PlaidIndexBuilder` — a :class:`~experimaestro.Task` that encodes a
:class:`~datamaestro_ir.data.DocumentStore` and builds the fast-plaid index.
- :class:`PlaidRetriever` — a :class:`~xpmir.rankers.Retriever` that searches
the index given a query.
"""
from __future__ import annotations
import json
import logging
import shutil
from pathlib import Path
from typing import List
import torch
from experimaestro import (
Config,
Meta,
Param,
PathGenerator,
Task,
field,
tqdm,
)
from datamaestro_ir.data import DocumentStore, IDTextRecord
from xpm_torch.configuration import FabricConfiguration
from xpmir.rankers import Retriever, ScoredDocument
from xpmir.rankers.scorer import AbstractModuleScorer
from xpmir.text.encoders import TextEncoderBase
from xpmir.utils.utils import batchiter
logger = logging.getLogger(__name__)
def _import_fast_plaid():
    """Import and return the ``fast_plaid.search`` module.

    The import happens at call time rather than at module load, so this
    module can be imported even when fast-plaid is absent.

    :returns: The ``fast_plaid.search`` module.
    :raises ModuleNotFoundError: If the import fails because a module is
        missing; the error carries an installation hint and is chained to
        the original exception.
    """
    try:
        from fast_plaid import search  # noqa: WPS433
    except ModuleNotFoundError as missing:  # pragma: no cover - guard clause
        install_hint = (
            "fast-plaid is not installed. Install it with `pip install fast-plaid`"
            " (see https://github.com/lightonai/fast-plaid)."
        )
        raise ModuleNotFoundError(install_hint) from missing
    else:
        return search
# File layout (under ``index_path``):
#
# plaid/ — fast-plaid index directory
# metadata.json — dim, num_docs, n_bits
#
# Sub-directory of ``index_path`` that holds the fast-plaid index itself.
_PLAID_SUBDIR = "plaid"
# JSON side-car file (directly under ``index_path``) written by the builder.
_METADATA_FILE = "metadata.json"
class PlaidIndex(Config):
    """A ColBERT / PLAID index backed by `fast-plaid`_.

    The index stores per-token document embeddings in fast-plaid's compressed
    centroid + residual format. Per-document token vectors can be
    reconstructed (approximately) via :meth:`get_document_tokens`, which
    delegates to fast-plaid's ``get_embeddings`` method. The reconstruction
    quality is controlled by the number of quantisation bits chosen at build
    time (see :attr:`PlaidIndexBuilder.n_bits`).

    When :attr:`compress_only` is ``True`` the index only contains the
    compressed vectors (centroids + quantised residuals) without the IVF
    search structure. This is cheaper to build and sufficient when only
    :meth:`get_document_tokens` is needed. Attempting to search a
    compress-only index via :class:`PlaidRetriever` will raise an error.

    .. _fast-plaid: https://github.com/lightonai/fast-plaid
    """

    documents: Param[DocumentStore]
    """Set of documents to index."""

    compress_only: Param[bool] = False
    """``True`` when the index was built without the IVF search structure
    (token reconstruction only, no retrieval)."""

    index_path: Meta[Path]
    """Directory containing the fast-plaid index and side-car files."""

    device: Meta[str] = field(default="", ignore_default=True)
    """Device used to load the index for :meth:`get_document_tokens`
    (``""`` = auto). Fixed at first use because the underlying
    ``FastPlaid`` instance is cached."""

    in_memory: Meta[bool] = field(default=False, ignore_default=True)
    """If ``True``, load the index fully into device memory (passes
    ``low_memory=False`` to fast-plaid). Use when the index fits in
    VRAM/RAM and you want faster decompression/search; otherwise the
    document codes and residuals stay memory-mapped from disk."""

    def _plaid_dir(self) -> Path:
        """Return the directory that holds the fast-plaid index files."""
        return self.index_path / _PLAID_SUBDIR

    def _get_fast_plaid(self):
        """Return a cached ``FastPlaid`` instance for this index.

        The instance is constructed lazily on first access and reused
        afterwards so that subsequent calls (e.g. repeated
        :meth:`get_document_tokens`) avoid reloading the index.
        """
        cached = getattr(self, "_fast_plaid", None)
        if cached is not None:
            return cached
        fp_search = _import_fast_plaid()
        fp = fp_search.FastPlaid(
            index=str(self._plaid_dir()),
            device=self.device or None,
            low_memory=not self.in_memory,
        )
        self._fast_plaid = fp
        return fp

    def get_document_tokens(
        self,
        docids: list[int | str],
        device: str = "",
    ) -> torch.Tensor:
        """Return the (approximate) per-token embeddings of some documents.

        The vectors are reconstructed from fast-plaid's compressed
        centroid + residual storage using ``FastPlaid.get_embeddings``.

        :param docids: The document identifiers. The type of the first
            element selects the interpretation for the whole list: integers
            are internal positions in the index; strings are external
            identifiers translated through the external-to-internal map.
            An empty list is forwarded to fast-plaid as an empty subset.
        :param device: Device override for decompression. When empty (or
            equal to :attr:`device`) the cached ``FastPlaid`` instance is
            reused; otherwise a dedicated, non-cached instance is opened on
            the requested device.
        :returns: Whatever ``FastPlaid.get_embeddings`` returns for the
            requested subset (reconstructed token embeddings).
        :raises KeyError: If an external identifier is not in the map.
        """
        if docids and isinstance(docids[0], str):
            # NOTE(review): ``_load_ext2int`` is not defined in this class as
            # shown, and the builder shown here does not write such a map —
            # confirm where the external-to-internal mapping comes from.
            ext2int = self._load_ext2int()
            internal_docids: list = []
            for docid in docids:
                if docid not in ext2int:
                    raise KeyError(
                        f"External document id {docid!r} is unknown to this index"
                    )
                internal_docids.append(int(ext2int[docid]))
        else:
            internal_docids = [int(docid) for docid in docids]

        if device and device != self.device:
            # Explicit override: open a one-off instance on that device, but
            # still honour the configured memory mode.
            fp_search = _import_fast_plaid()
            fp = fp_search.FastPlaid(
                index=str(self._plaid_dir()),
                device=device,
                low_memory=not self.in_memory,
            )
        else:
            # Reuse the cached instance instead of reloading the index on
            # every call.
            fp = self._get_fast_plaid()

        return fp.get_embeddings(subset=internal_docids)
class PlaidIndexBuilder(Task):
    """Builds a fast-plaid index from a document collection.

    The builder encodes every document using the given
    :class:`~xpmir.neural.colbert.ColBERTEncoder`, collects the valid (i.e.
    non-padding) token vectors, and feeds them to ``fast-plaid``.

    The fast-plaid index stores the embeddings in a compressed
    centroid + residual format, so no separate raw-token file is needed.
    Per-document token vectors can be reconstructed later via
    :meth:`PlaidIndex.get_document_tokens`.
    """

    documents: Param[DocumentStore]
    """Set of documents to index."""

    encoder: Param[TextEncoderBase]
    """The ColBERT-style encoder used to produce per-token embeddings."""

    batch_size: Meta[int] = field(default=32, ignore_default=True)
    """Encoder batch size. Warning, different from the batch size used internally by fast-plaid ('fast_plaid_batch_size')"""

    buffer_size: Param[int] = field(default=1000, ignore_default=True)
    """Number of documents to encode and accumulate in RAM before creating/updating the fast-plaid index and fitting the centroids.
    The token embeddings used to initialize the centroids will be sampled randomly from those documents by plaid
    (or they will all be used if n_samples_kmeans is 0)."""

    fast_plaid_batch_size: Meta[int] = field(default=32, ignore_default=True)
    """Fast plaid internal batch size."""

    n_bits: Param[int] = field(default=2, ignore_default=True)
    """Number of bits used by fast-plaid for residual quantisation."""

    kmeans_niters: Param[int] = field(default=4, ignore_default=True)
    """Number of K-means iterations performed by fast-plaid when clustering
    the centroids."""

    n_samples_kmeans: Param[int] = field(default=0, ignore_default=True)
    """Number of token samples used to train the centroids (0 = fast-plaid
    default)."""

    max_points_per_centroid: Param[int] = field(default=256, ignore_default=True)
    """Maximum number of points (documents) per centroid. Controls the creation of new centroids."""

    seed: Param[int] = field(default=42, ignore_default=True)
    """Random seed for reproducibility (passed to fast-plaid's index creation)."""

    compress_only: Param[bool] = field(default=False, ignore_default=True)
    """When ``True``, skip IVF construction. The resulting index supports
    :meth:`PlaidIndex.get_document_tokens` but not search via
    :class:`PlaidRetriever`.
    Requires fast-plaid support for ``compress_only``
    (see `lightonai/fast-plaid#41 <https://github.com/lightonai/fast-plaid/pull/41>`_).
    Falls back to building the full index with a warning if unsupported."""

    low_memory: Param[bool] = field(default=True)
    """https://github.com/lightonai/fast-plaid#-search-speed-tip-low_memoryfalse
    If index fits on VRAM, set to False for faster search. Otherwise, keep True to avoid OOM errors."""

    force_cpu_indexing: Param[bool] = field(default=False)
    """When True, forces the use of CPU for indexing even if a GPU is available.
    This can be useful to avoid GPU OOM errors during indexing, especially for large corpora."""

    fabric_config: Meta[FabricConfiguration] = field(
        default_factory=FabricConfiguration.C
    )
    """Control the device for the model encoding and fast-plaid index."""

    index_path: Meta[Path] = field(default_factory=PathGenerator("plaid-index"))
    """Output directory for the index and its side-car files."""

    def task_outputs(self, dep) -> PlaidIndex:
        """Expose a :class:`PlaidIndex` for downstream tasks."""
        return dep(
            PlaidIndex.C(
                documents=self.documents,
                index_path=self.index_path,
                compress_only=self.compress_only,
            )
        )

    def _clustering_kwargs(self) -> dict:
        """Return the keyword arguments shared by ``create`` and ``update``."""
        kwargs = {
            "kmeans_niters": self.kmeans_niters,
            "batch_size": self.fast_plaid_batch_size,
            "seed": self.seed,
            "max_points_per_centroid": self.max_points_per_centroid,
        }
        # 0 means "let fast-plaid pick", so the argument is simply omitted.
        if self.n_samples_kmeans:
            kwargs["n_samples_kmeans"] = self.n_samples_kmeans
        return kwargs

    def _create_index(self, fast_plaid, doc_buffer: list) -> None:
        """Create the index from ``doc_buffer`` (fits the centroids).

        When :attr:`compress_only` is requested but the installed fast-plaid
        rejects the keyword (``TypeError``), a warning is logged and the full
        index is built instead.
        """
        create_kwargs = {
            "documents_embeddings": doc_buffer,
            "nbits": self.n_bits,
            **self._clustering_kwargs(),
        }
        if not self.compress_only:
            fast_plaid.create(**create_kwargs)
            return
        try:
            fast_plaid.create(compress_only=True, **create_kwargs)
        except TypeError:
            logger.warning(
                "compress_only is not supported by this "
                "version of fast-plaid; building the full "
                "index instead. See "
                "https://github.com/lightonai/fast-plaid/pull/41"
            )
            fast_plaid.create(**create_kwargs)

    def _update_index(self, fast_plaid, doc_buffer: list) -> None:
        """Append the documents of ``doc_buffer`` to an existing index."""
        fast_plaid.update(
            documents_embeddings=doc_buffer, **self._clustering_kwargs()
        )

    def _flush_buffer(self, fast_plaid, doc_buffer: list, index_created: bool) -> None:
        """Send ``doc_buffer`` to fast-plaid, then empty it in place.

        :param fast_plaid: The ``FastPlaid`` instance being built.
        :param doc_buffer: Per-document token-embedding tensors to index.
        :param index_created: ``False`` for the first flush (creates the
            index), ``True`` afterwards (updates it).
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "%s the fast-plaid index with %d documents (%d tokens)",
                "Updating" if index_created else "Creating",
                len(doc_buffer),
                sum(t.shape[0] for t in doc_buffer),
            )
        if index_created:
            self._update_index(fast_plaid, doc_buffer)
        else:
            self._create_index(fast_plaid, doc_buffer)
        doc_buffer.clear()  # free RAM immediately

    def execute(self):
        """Encode the collection and build the fast-plaid index.

        Any existing content of :attr:`index_path` is removed first. Documents
        are encoded in batches of :attr:`batch_size` and buffered; each time
        the buffer reaches :attr:`buffer_size` documents it is flushed to
        fast-plaid (``create`` the first time, ``update`` afterwards). A
        ``metadata.json`` side-car with the document count, embedding
        dimension and number of bits is written at the end.
        """
        if self.index_path.exists():
            shutil.rmtree(self.index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        # 1. Initialize Fabric first
        fabric = self.fabric_config.get_fabric()
        fabric.launch()
        with fabric.init_module():
            self.encoder.initialize()
        self.encoder = fabric.setup(self.encoder)
        self.encoder.eval()

        fp_search = _import_fast_plaid()
        plaid_dir = self.index_path / _PLAID_SUBDIR
        plaid_dir.mkdir(parents=True, exist_ok=True)
        device = fabric.device if not self.force_cpu_indexing else "cpu"
        fast_plaid = fp_search.FastPlaid(
            index=str(plaid_dir), device=device, low_memory=self.low_memory
        )

        total_docs = self.documents.documentcount or 0
        num_docs_seen = 0
        index_created = False
        doc_buffer: list = []

        with torch.no_grad():
            pbar = tqdm(
                total=total_docs or None,
                desc="Encoding documents for fast-plaid",
                unit="doc",
            )
            for batch in batchiter(self.batch_size, self.documents.iter_documents()):
                per_doc = self.encoder.document_token_embeddings(batch)
                per_doc_cpu = [
                    t.detach().to("cpu", dtype=torch.float32) for t in per_doc
                ]
                doc_buffer.extend(per_doc_cpu)

                if len(doc_buffer) >= self.buffer_size:
                    self._flush_buffer(fast_plaid, doc_buffer, index_created)
                    index_created = True

                # Progress accounting is independent of whether we flushed.
                num_docs_seen += len(per_doc_cpu)
                pbar.update(len(per_doc_cpu))
            pbar.close()

            # Leftovers: either the whole corpus was smaller than the buffer
            # (index still to be created) or a final partial buffer remains.
            if doc_buffer:
                self._flush_buffer(fast_plaid, doc_buffer, index_created)
                index_created = True
            else:
                logger.info("No documents left to encode.")

        with (self.index_path / _METADATA_FILE).open("w") as fh:
            json.dump(
                {
                    "num_docs": num_docs_seen,
                    "dim": self.encoder.dimension,
                    "n_bits": self.n_bits,
                },
                fh,
            )
        logger.info(
            "fast-plaid index built: %d documents, dim=%d, n_bits=%d",
            num_docs_seen,
            self.encoder.dimension,
            self.n_bits,
        )
class PlaidRetriever(Retriever):
    """Retriever using a `fast-plaid`_ PLAID index.

    .. _fast-plaid: https://github.com/lightonai/fast-plaid
    """

    encoder: Param[AbstractModuleScorer]
    """The query encoder. Typically the same encoder that was used to build
    :attr:`index`."""

    index: Param[PlaidIndex]
    """The fast-plaid index to search."""

    topk: Param[int]
    """Number of documents to return per query."""

    n_ivf_probe: Meta[int] = field(default=8, ignore_default=True)
    """Number of inverted-list clusters explored by fast-plaid at search
    time."""

    n_full_scores: Meta[int] = field(default=0, ignore_default=True)
    """Number of candidates for which fast-plaid computes full scores
    (0 = fast-plaid default)."""

    fabric_config: Meta[FabricConfiguration] = field(
        default_factory=FabricConfiguration.C
    )
    """Control the device for the model encoding and fast-plaid index."""

    def initialize(self):
        """Set up the query encoder and open the fast-plaid index.

        The index is opened on the Fabric device and honours the index's
        ``in_memory`` setting.

        :raises RuntimeError: If the index was built with
            ``compress_only=True`` (it has no search structure).
        """
        super().initialize()
        if self.index.compress_only:
            raise RuntimeError(
                "Cannot search a compress-only PLAID index. "
                "Rebuild with compress_only=False to enable retrieval."
            )

        logger.info("PLAID retriever (1/2): initializing the encoder")
        # 1. Initialize Fabric first
        fabric = self.fabric_config.get_fabric()
        fabric.launch()
        with fabric.init_module():
            self.encoder.initialize()
        self.encoder = fabric.setup(self.encoder)
        self.encoder.eval()

        logger.info("PLAID retriever (2/2): opening the fast-plaid index")
        fp_search = _import_fast_plaid()
        device = fabric.device or None
        self._fast_plaid = fp_search.FastPlaid(
            index=str(self.index._plaid_dir()),
            device=device,
            # Same memory-mode mapping as PlaidIndex._get_fast_plaid.
            low_memory=not self.index.in_memory,
        )

    def _store(self):
        """Return the document store associated with the index."""
        return self.index.documents

    def retrieve(self, record: IDTextRecord) -> List[ScoredDocument]:
        """Search the index for a single query.

        :param record: The query record, encoded with
            ``encoder.query_token_embeddings``.
        :returns: One :class:`ScoredDocument` per hit returned by fast-plaid
            for this query, in the order fast-plaid returned them; empty if
            fast-plaid returned nothing.
        """
        with torch.no_grad():
            queries_embeddings = self.encoder.query_token_embeddings([record])
            queries_embeddings = queries_embeddings.detach().to(
                "cpu", dtype=torch.float32
            )

        search_kwargs = {"top_k": self.topk, "n_ivf_probe": self.n_ivf_probe}
        # 0 means "use fast-plaid's default", so the argument is omitted.
        if self.n_full_scores:
            search_kwargs["n_full_scores"] = self.n_full_scores

        results = self._fast_plaid.search(
            queries_embeddings=queries_embeddings,
            **search_kwargs,
        )
        # A single query was submitted, so only the first result list matters.
        hits = results[0] if results else []

        documents = self._store()
        return [
            ScoredDocument(documents.document_int(int(doc_index)), float(score))
            for doc_index, score in hits
        ]