Source code for xpmir.index.sparse

"""Index for sparse models"""

import torch
import numpy as np
import sys
from pathlib import Path
from typing import Dict, List, Tuple
from experimaestro import (
    Annotated,
    Config,
    Task,
    Param,
    Meta,
    pathgenerator,
    tqdm,
    Constant,
)
from datamaestro_text.data.ir import Document, DocumentStore
from xpmir.learning import ModuleInitMode
from xpmir.learning.batchers import Batcher
from xpmir.utils.utils import batchiter, easylog
from xpmir.letor import Device, DEFAULT_DEVICE
from xpmir.text.encoders import TextEncoder
from xpmir.rankers import Retriever, ScoredDocument
from xpmir.utils.iter import MultiprocessIterator
import xpmir_rust

logger = easylog()

# --- Index and retriever


class SparseRetrieverIndex(Config):
    index_path: Meta[Path]
    documents: Param[DocumentStore]

    index: xpmir_rust.index.SparseBuilderIndex
    ordered = False

    def initialize(self, in_memory: bool):
        self.index = xpmir_rust.index.SparseBuilderIndex.load(
            str(self.index_path.absolute()), in_memory
        )

    def retrieve(self, query: Dict[int, float], top_k: int) -> List[ScoredDocument]:
        results = []
        for sd in self.index.search_maxscore(query, top_k):
            results.append(
                ScoredDocument(
                    self.documents.document_int(sd.docid),
                    sd.score,
                )
            )
        return results
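

# A minimal usage sketch (hypothetical helper, not part of the module API):
# load an index built by SparseRetrieverIndexBuilder and query it directly
# with a sparse query given as a {term index: weight} dictionary. Assumes
# ScoredDocument exposes the `document` and `score` attributes it is
# constructed with above.
def _example_search(store: DocumentStore, path: Path):
    index = SparseRetrieverIndex(index_path=path, documents=store)
    index.initialize(in_memory=False)
    for sd in index.retrieve({13: 0.5, 4212: 1.25}, top_k=10):
        print(sd.document, sd.score)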


class SparseRetriever(Retriever):
    index: Param[SparseRetrieverIndex]
    encoder: Param[TextEncoder]

    topk: Param[int]
    """Number of documents to retrieve"""

    batcher: Meta[Batcher] = Batcher()
    """The way to prepare batches of queries (when using retrieve_all)"""

    batchsize: Meta[int]
    """Size of batches (when using retrieve_all)"""

    in_memory: Meta[bool] = False
    """Whether the index should be fully loaded in memory (otherwise, uses
    virtual memory)"""

    def initialize(self):
        super().initialize()
        self.encoder.initialize(ModuleInitMode.RANDOM.to_options(None))
        self.index.initialize(self.in_memory)

    def retrieve_all(
        self, queries: Dict[str, str]
    ) -> Dict[str, List[ScoredDocument]]:
        """Retrieves for a set of queries, given as a {query id: text} mapping"""

        def reducer(
            batch: List[Tuple[str, str]],
            results: Dict[str, List[ScoredDocument]],
            progress,
        ):
            for (key, _), vector in zip(
                batch,
                self.encoder([text for _, text in batch]).cpu().detach().numpy(),
            ):
                (ix,) = vector.nonzero()
                query = {int(i): float(v) for i, v in zip(ix, vector[ix])}
                results[key] = self.index.retrieve(query, self.topk)
                progress.update(1)
            return results

        self.encoder.eval()
        batcher = self.batcher.initialize(self.batchsize)
        results = {}
        items = list(queries.items())
        with tqdm(
            desc="Retrieve documents", total=len(items), unit="queries"
        ) as progress:
            with torch.no_grad():
                for batch in batchiter(self.batchsize, items):
                    results = batcher.reduce(batch, reducer, results, progress)

        return results

    def retrieve(self, query: str, top_k=None) -> List[ScoredDocument]:
        """Search with the document-at-a-time (DAAT) strategy

        :param top_k: Overrides the default top-k value
        """
        # Encode the query and keep only the non-zero components
        vector = self.encoder([query])[0].cpu().detach().numpy()
        (ix,) = vector.nonzero()
        # Build the sparse query as a {term position: weight} dictionary
        query = {int(i): float(v) for i, v in zip(ix, vector[ix])}
        return self.index.retrieve(query, top_k or self.topk)
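

# Sketch of end-to-end retrieval (hypothetical helper and configuration): the
# retriever encodes each query into a sparse vector and delegates scoring to
# the index; the encoder should be the same one used at indexing time so that
# query and document vectors live in the same space.
def _example_retrieve_all(index: SparseRetrieverIndex, encoder: TextEncoder):
    retriever = SparseRetriever(
        index=index, encoder=encoder, topk=100, batchsize=16
    )
    retriever.initialize()
    # Maps each query identifier to its list of top-k ScoredDocument objects
    return retriever.retrieve_all({"q1": "what is sparse retrieval?"})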


class SparseRetrieverIndexBuilder(Task):
    """Builds an index from a sparse representation

    Assumes that documents and queries have the same dimension, and that
    the score is computed through an inner product
    """

    documents: Param[DocumentStore]
    """Set of documents to index"""

    encoder: Param[TextEncoder]
    """The encoder"""

    batcher: Meta[Batcher] = Batcher()
    """Batcher used when computing representations"""

    batch_size: Param[int]
    """Size of batches"""

    ordered_index: Param[bool]
    """Ordered index: if False, use the document-at-a-time (DAAT) strategy
    (WAND); otherwise, use fast top-k strategies"""

    device: Meta[Device] = DEFAULT_DEVICE
    """Device used when encoding documents"""

    max_postings: Meta[int] = 16384
    """Maximum number of postings (per term) before flushing to disk"""

    index_path: Annotated[Path, pathgenerator("index")]

    in_memory: Meta[bool] = False
    """Whether the index should be fully loaded in memory (otherwise, uses
    virtual memory)"""

    version: Constant[int] = 3
    """Version 3 of the index"""

    max_docs: Param[int] = 0
    """Maximum number of indexed documents (0 means no limit)"""

    def task_outputs(self, dep):
        """Returns a sparse retriever index that can be used by a
        SparseRetriever to search efficiently for documents"""
        return dep(
            SparseRetrieverIndex(index_path=self.index_path, documents=self.documents)
        )

    def execute(self):
        # Encode all documents
        logger.info(
            f"Loading the encoder and transferring it to the target device "
            f"{self.device.value}"
        )
        self.encoder.initialize(ModuleInitMode.RANDOM.to_options(None))
        self.encoder.to(self.device.value).eval()

        batcher = self.batcher.initialize(self.batch_size)

        doc_iter = tqdm(
            zip(
                range(sys.maxsize if self.max_docs == 0 else self.max_docs),
                MultiprocessIterator(self.documents.iter_documents()),
            ),
            total=self.documents.documentcount
            if self.max_docs == 0
            else min(self.max_docs, self.documents.documentcount),
            desc="Building the index",
        )

        # Create the index builder, removing any stale index first
        from shutil import rmtree

        if self.index_path.is_dir():
            rmtree(self.index_path)
        self.index_path.mkdir(parents=True)

        self.indexer = xpmir_rust.index.SparseIndexer(str(self.index_path))

        # Index the documents
        logger.info(f"Starting to index {self.documents.documentcount} documents")
        with torch.no_grad():
            for batch in batchiter(self.batch_size, doc_iter):
                batcher.process(batch, self.encode_documents)

        # Build the index
        self.indexer.build(self.in_memory)

    def encode_documents(self, batch: List[Tuple[int, Document]]):
        # Assumes dense output vectors for now
        vectors = (
            self.encoder([d.get_text() for _, d in batch]).cpu().numpy()
        )  # shape: batch size x vocabulary
        for vector, (docid, _) in zip(vectors, batch):
            (nonzero_ix,) = vector.nonzero()
            self.indexer.add(docid, nonzero_ix.astype(np.uint64), vector[nonzero_ix])
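

# Illustrative sketch (hypothetical helper; assumes it is called within an
# experimaestro experiment context, where submitting a Task schedules it and
# returns the value of task_outputs): build the index over a document store
# and get back a SparseRetrieverIndex usable by a SparseRetriever.
def _example_build(documents: DocumentStore, encoder: TextEncoder):
    builder = SparseRetrieverIndexBuilder(
        documents=documents,
        encoder=encoder,
        batch_size=64,
        ordered_index=False,
    )
    return builder.submit()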