Source code for xpmir.index.faiss

"""Interface to the Facebook FAISS library

https://github.com/facebookresearch/faiss
"""

from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
from experimaestro import Config, initializer, PathGenerator
import torch
import numpy as np
from experimaestro import Meta, Task, Param, tqdm, field
import logging
from datamaestro_ir.data import DocumentStore, TextRecord
from xpmir.rankers import Retriever, ScoredDocument
from xpm_torch.batchers import Batcher
from xpmir.text.encoders import TextEncoder

from xpmir.utils.utils import batchiter, foreach
from xpmir.documents.samplers import DocumentSampler
from xpmir.context import Context, Hook, InitializationHook

logger = logging.getLogger(__name__)

try:
    import faiss
except ModuleNotFoundError:
    logging.error("FAISS library is not available (install faiss-cpu or faiss)")
    raise


class FaissIndex(Config):
    """Configuration of a FAISS index serialized in a file"""

    normalize: Param[bool]
    """Whether the vectors should be L2-normalized"""

    faiss_index: Meta[Path] = field(
        default_factory=PathGenerator("faiss.dat"),
        ignore_generated=True,
    )
    """Path of the file that contains the index"""

    documents: Param[DocumentStore]
    """The set of documents associated with the index"""
class IndexBackedFaiss(FaissIndex, Task):
    """Constructs a FAISS index backed up by an index

    During executions, InitializationHooks are used (pre/post): their
    ``before`` method is called before the encoder is initialized, and their
    ``after`` method is called just before the documents are added to the
    index.
    """

    encoder: Param[TextEncoder]
    """Encoder for document texts"""

    batchsize: Meta[int] = field(default=1, ignore_default=True)
    """The batch size used when computing representations of documents"""

    batcher: Meta[Batcher] = field(default=Batcher.C(), ignore_default=True)
    """The way to prepare batches of documents"""

    hooks: Param[List[Hook]] = field(default=[], ignore_default=True)
    """An optional list of hooks"""

    indexspec: Param[str]
    """The index type as a factory string

    See https://github.com/facebookresearch/faiss/wiki/Faiss-indexes for the
    full list of indices and
    https://github.com/facebookresearch/faiss/wiki/The-index-factory for the
    combination of the index factory
    """

    sampler: Param[Optional[DocumentSampler]]
    """Optional document sampler when training the index -- by default, all the
    documents from the collection are used"""

    def full_sampler(self) -> Tuple[int, Iterator[str]]:
        """Returns the document count and an iterator over the full collection

        The count is ``self.documents.documentcount`` (0 when it is falsy).
        """
        # NOTE(review): the iterator yields whatever `iter_documents()` returns
        # (document records), although the annotation says `str` -- confirm
        # what the encoder expects on the training path.
        documents = iter(self.documents.iter_documents())
        return self.documents.documentcount or 0, documents

    def train(
        self,
        index: faiss.Index,
        batch_encoder: Callable[[Iterator[str]], Iterator[torch.Tensor]],
    ):
        """Trains the index on the representation of sampled documents

        :param index: The FAISS index to train
        :param batch_encoder: A function that takes an iterator over documents
            and yields one tensor of encoded documents per batch (one row per
            encoded document)
        """
        logger.info("Building index")
        count, documents = (
            self.sampler() if self.sampler is not None else self.full_sampler()
        )
        doc_iter = tqdm(
            documents,
            total=count,
            desc="Collecting the representation of documents (train)",
        )

        # Collect batches (in memory); the buffer is *not* initialized
        logger.info("Collecting the representation of %d documents", count)
        sample = np.empty((count, self.encoder.dimension), dtype=np.float32)
        filled = 0
        for batch in batch_encoder(doc_iter):
            sample[filled : (filled + len(batch))] = batch.cpu().numpy()
            filled += len(batch)

        # `encode` drops empty texts, so fewer than `count` rows may have been
        # written: only the rows actually filled are used for training (the
        # remaining ones contain uninitialized memory)
        logger.info("Training index (%d samples)", filled)
        index.train(sample[:filled])

    def execute(self):
        """Builds the index: trains it if needed, adds all the documents of
        the collection, and writes it to ``self.faiss_index``"""
        # Initialization hooks (before)
        context = Context(hooks=self.hooks)
        foreach(context.hooks(InitializationHook), lambda hook: hook.before(context))

        # Two steps are reported: training and indexing
        step_iter = tqdm(total=2, desc="Building the FAISS index")

        # Initializations
        self.encoder.initialize()
        index = faiss.index_factory(
            self.encoder.dimension, self.indexspec, faiss.METRIC_INNER_PRODUCT
        )
        batcher = self.batcher.initialize(self.batchsize)
        self.encoder.eval()

        # Train the index (only needed by some index types)
        if not index.is_trained:
            with torch.no_grad():
                # NOTE(review): nothing has been added to the index yet at
                # this point, so `index.ntotal` is presumably 0 here
                logger.info("Training FAISS index (%d documents)", index.ntotal)

                def batch_encoder(doc_iter: Iterator[str]):
                    for batch in batchiter(self.batchsize, doc_iter):
                        data = []
                        batcher.process(batch, self.encode, data)
                        yield torch.cat(data)

                self.train(index, batch_encoder)

        step_iter.update()

        # Index the collection
        doc_iter = tqdm(
            self.documents.iter_documents(),
            total=self.documents.documentcount,
            desc="Indexing the collection",
        )

        # Initialization hooks (after)
        foreach(context.hooks(InitializationHook), lambda hook: hook.after(context))

        # Let's index!
        # We add the representation of all the documents to the index
        with torch.no_grad():
            for batch in batchiter(self.batchsize, doc_iter):
                batcher.process(
                    [document["text_item"].text for document in batch],
                    self.index_documents,
                    index,
                )

        logger.info("Writing FAISS index (%d documents)", index.ntotal)
        faiss.write_index(index, str(self.faiss_index))
        step_iter.update()

    def encode(self, batch: List[str], data: List):
        """Encodes a batch (training only) and appends the result to ``data``

        :param batch: The texts to encode; empty strings are discarded
        :param data: The list receiving the tensor of encoded texts
        """
        # Remove the empty strings in the dataset (training only)
        batch = [text for text in batch if text != ""]

        # The tensor is held by the `value` attribute of the encoder output,
        # as in `index_documents`
        x = self.encoder(batch).value
        if self.normalize:
            x /= x.norm(2, keepdim=True, dim=1)
        data.append(x)

    def index_documents(self, batch: List[str], index):
        """Encodes a batch of texts and adds the vectors to the index

        :param batch: The texts to encode
        :param index: The FAISS index receiving the vectors
        """
        x = self.encoder(batch).value
        if self.normalize:
            x /= x.norm(2, keepdim=True, dim=1)
        index.add(np.ascontiguousarray(x.cpu().numpy()))
class FaissRetriever(Retriever):
    """Retriever that searches a FAISS index with an encoded query"""

    encoder: Param[TextEncoder]
    """The query encoder"""

    index: Param[FaissIndex]
    """The faiss index"""

    topk: Param[int]
    """The number of documents to be retrieved"""

    @initializer
    def initialize(self):
        """Initializes the query encoder, then reads the index from its file"""
        logger.info("FAISS retriever (1/2): initializing the encoder")
        self.encoder.initialize()

        logger.info("FAISS retriever (2/2): reading the index")
        index_path = str(self.index.faiss_index)
        self._index = faiss.read_index(index_path)

        logger.info("FAISS retriever: initialized")

    def _encode_query(self, query: TextRecord) -> np.ndarray:
        """Returns the query vector as a contiguous float32 array"""
        with torch.no_grad():
            # Switch the model to evaluation mode
            self.encoder.eval()

            texts = [query["text_item"].text]
            vector = self.encoder(texts).value
            if self.index.normalize:
                vector /= vector.norm(2)

            return np.ascontiguousarray(vector.cpu().numpy(), dtype=np.float32)

    def retrieve(self, query: TextRecord) -> List[ScoredDocument]:
        """Retrieves documents, returning a list sorted by decreasing score"""
        scores, doc_indices = self._index.search(self._encode_query(query), self.topk)

        # Only one query is submitted: its results are in the first row
        store = self.index.documents
        results = []
        for doc_ix, score in zip(doc_indices[0], scores[0]):
            # Negative indices are skipped
            if doc_ix < 0:
                continue
            results.append(ScoredDocument(store.document_int(int(doc_ix)), float(score)))
        return results