Source code for xpmir.index.faiss

"""Interface to the Facebook FAISS library

https://github.com/facebookresearch/faiss
"""

from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
from experimaestro import Config, initializer
import torch
import numpy as np
from experimaestro import Annotated, Meta, Task, pathgenerator, Param, tqdm
import logging
from datamaestro_text.data.ir import DocumentStore
from xpmir.rankers import Retriever, ScoredDocument
from xpmir.learning.batchers import Batcher
from xpmir.learning import ModuleInitMode
from xpmir.text.encoders import TextEncoder
from xpmir.letor import (
    Device,
    DEFAULT_DEVICE,
    DeviceInformation,
)
from xpmir.utils.utils import batchiter, easylog, foreach
from xpmir.documents.samplers import DocumentSampler
from xpmir.context import Context, Hook, InitializationHook

logger = easylog()

try:
    import faiss
except ModuleNotFoundError:
    logging.error("FAISS library is not available (install faiss-cpu or faiss)")
    raise


class FaissIndex(Config):
    """FAISS Index"""

    normalize: Param[bool]
    """Whether vectors should be normalized (L2)"""

    faiss_index: Annotated[Path, pathgenerator("faiss.dat")]
    """Path to the file containing the index"""

    documents: Param[DocumentStore]
    """The set of documents"""
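
# Illustrative sketch (not part of the module): with ``normalize=True``,
# vectors are L2-normalized before being added to or searched in the index.
# Since the index is built with METRIC_INNER_PRODUCT (see below), inner
# products of unit vectors are cosine similarities:
#
#     import numpy as np
#     x = np.random.rand(4, 128).astype(np.float32)
#     x /= np.linalg.norm(x, axis=1, keepdims=True)  # L2-normalize each row
#     # x @ x.T now contains pairwise cosine similarities
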
class IndexBackedFaiss(FaissIndex, Task):
    """Constructs a FAISS index backed by a document store

    During execution, initialization hooks are called (before and after
    building the index)
    """

    encoder: Param[TextEncoder]
    """Encoder for document texts"""

    batchsize: Meta[int] = 1
    """The batch size used when computing representations of documents"""

    device: Meta[Device] = DEFAULT_DEVICE
    """The device used by the encoder"""

    batcher: Meta[Batcher] = Batcher()
    """The way to prepare batches of documents"""

    hooks: Param[List[Hook]] = []
    """An optional list of hooks"""

    indexspec: Param[str]
    """The index type as a factory string

    See https://github.com/facebookresearch/faiss/wiki/Faiss-indexes for
    the full list of indices, and
    https://github.com/facebookresearch/faiss/wiki/The-index-factory for
    how to combine them in a factory string
    """

    sampler: Param[Optional[DocumentSampler]]
    """Optional document sampler used when training the index; by default,
    all the documents from the collection are used"""

    def full_sampler(self) -> Tuple[int, Iterator[str]]:
        """Returns an iterator over the full set of documents"""
        iter = (d.text for d in self.documents.iter_documents())
        return self.documents.documentcount or 0, iter

    def train(
        self,
        index: faiss.Index,
        batch_encoder: Callable[[Iterator[str]], Iterator[torch.Tensor]],
    ):
        """Train the index

        :param index: the FAISS index to train
        :param batch_encoder: maps an iterator over document texts to an
            iterator over their encoded representations (tensors of shape
            batch size x dimension)
        """
        logger.info("Building index")
        count, iter = (
            self.sampler() if self.sampler is not None else self.full_sampler()
        )
        doc_iter = tqdm(
            iter,
            total=count,
            desc="Collecting the representation of documents (train)",
        )

        # Collect the representations (in memory)
        logger.info("Collecting the representation of %d documents", count)
        sample = np.empty((count, self.encoder.dimension), dtype=np.float32)
        ix = 0
        for batch in batch_encoder(doc_iter):
            sample[ix : (ix + len(batch))] = batch.cpu().numpy()
            ix += len(batch)

        logger.info("Training index (%d samples)", count)
        # A sampler may restrict training to a subset of the documents
        index.train(sample)

    def execute(self):
        self.device.execute(self._execute)

    def _execute(self, device_information: DeviceInformation):
        # Initialization hooks (before)
        context = Context(device_information, hooks=self.hooks)
        foreach(context.hooks(InitializationHook), lambda hook: hook.before(context))

        step_iter = tqdm(total=2, desc="Building the FAISS index")

        # Initializations
        self.encoder.initialize(ModuleInitMode.DEFAULT.to_options())
        index = faiss.index_factory(
            self.encoder.dimension, self.indexspec, faiss.METRIC_INNER_PRODUCT
        )
        batcher = self.batcher.initialize(self.batchsize)

        # Move the encoder to the device and set it in evaluation mode
        self.encoder.to(device_information.device).eval()

        # Train the index (only needed for some index types)
        if not index.is_trained:
            with torch.no_grad():
                logging.info("Training FAISS index")

                def batch_encoder(doc_iter: Iterator[str]):
                    for batch in batchiter(self.batchsize, doc_iter):
                        data = []
                        batcher.process(batch, self.encode, data)
                        yield torch.cat(data)

                self.train(index, batch_encoder)
        step_iter.update()

        # Iterate over the collection (progress bar on the main process only)
        doc_iter = (
            tqdm(
                self.documents.iter_documents(),
                total=self.documents.documentcount,
                desc="Indexing the collection",
            )
            if device_information.main
            else self.documents.iter_documents()
        )

        # Initialization hooks (after)
        foreach(context.hooks(InitializationHook), lambda hook: hook.after(context))

        # Let's index! Add all the documents to the index
        with torch.no_grad():
            for batch in batchiter(self.batchsize, doc_iter):
                batcher.process(
                    [document.text for document in batch], self.index_documents, index
                )

        logging.info("Writing FAISS index (%d documents)", index.ntotal)
        faiss.write_index(index, str(self.faiss_index))
        step_iter.update()

    def encode(self, batch: List[str], data: List):
        # Remove the empty strings from the batch (training only)
        batch = [text for text in batch if text != ""]
        x = self.encoder(batch)
        if self.normalize:
            x /= x.norm(2, keepdim=True, dim=1)
        data.append(x)

    def index_documents(self, batch: List[str], index: faiss.Index):
        x = self.encoder(batch)
        if self.normalize:
            x /= x.norm(2, keepdim=True, dim=1)
        index.add(np.ascontiguousarray(x.cpu().numpy()))
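
# Illustrative sketch (not part of the module): configuring and submitting
# the indexing task inside an experimaestro experiment. ``my_encoder`` and
# ``my_documents`` are hypothetical, pre-configured TextEncoder and
# DocumentStore instances; "IVF100,Flat" is one example of a FAISS factory
# string (an inverted file with 100 lists over flat vectors).
#
#     task = IndexBackedFaiss(
#         encoder=my_encoder,
#         documents=my_documents,
#         normalize=True,
#         indexspec="IVF100,Flat",
#     )
#     faiss_index = task.submit()  # assuming the usual experimaestro submit()
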
class FaissRetriever(Retriever):
    """Retriever based on a FAISS index"""

    encoder: Param[TextEncoder]
    """The query encoder"""

    index: Param[FaissIndex]
    """The FAISS index"""

    topk: Param[int]
    """The number of documents to retrieve"""

    @initializer
    def initialize(self):
        logger.info("FAISS retriever (1/2): initializing the encoder")
        self.encoder.initialize()
        logger.info("FAISS retriever (2/2): reading the index")
        self._index = faiss.read_index(str(self.index.faiss_index))
        logger.info("FAISS retriever: initialized")

    def retrieve(self, query: str) -> List[ScoredDocument]:
        """Retrieve documents, returning a list sorted by decreasing score"""
        with torch.no_grad():
            self.encoder.eval()  # switch the encoder to evaluation mode
            encoded_query = self.encoder([query])
            if self.index.normalize:
                # Single query, so the global norm is the row norm
                encoded_query /= encoded_query.norm(2)

            values, indices = self._index.search(
                encoded_query.cpu().numpy(), self.topk
            )
            return [
                ScoredDocument(
                    self.index.documents.document_int(int(ix)), float(value)
                )
                for ix, value in zip(indices[0], values[0])
                if ix >= 0  # FAISS returns -1 when fewer than topk documents match
            ]
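
# Illustrative sketch (not part of the module): retrieving with an index
# built as above. ``faiss_index`` is the FaissIndex produced by
# IndexBackedFaiss, and ``my_encoder`` a hypothetical query TextEncoder
# with the same dimension as the document encoder.
#
#     retriever = FaissRetriever(encoder=my_encoder, index=faiss_index, topk=10)
#     retriever.initialize()
#     for scored in retriever.retrieve("what is dense retrieval?"):
#         print(scored.document, scored.score)  # assuming ScoredDocument fields
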