Source code for xpmir.index.sparse

"""Index for sparse models"""

import asyncio
from functools import cached_property
import logging
import threading
import heapq
import torch
from queue import Empty
import torch.multiprocessing as mp
import numpy as np
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Generic, Iterator, Union
from attrs import define
from experimaestro import (
    Annotated,
    Config,
    Task,
    Param,
    Meta,
    pathgenerator,
    tqdm,
    Constant,
)
from datamaestro_text.data.ir import DocumentRecord, DocumentStore, TextItem
from xpmir.learning import ModuleInitMode
from xpmir.learning.batchers import Batcher
from xpmir.utils.utils import batchiter, easylog
from xpmir.letor import Device, DeviceInformation, DEFAULT_DEVICE
from xpmir.text.encoders import TextEncoderBase, TextsRepresentationOutput, InputType
from xpmir.rankers import Retriever, TopicRecord, ScoredDocument
from xpmir.utils.iter import MultiprocessIterator
from xpmir.utils.multiprocessing import StoppableQueue, available_cpus
import xpmir_rust

logger = easylog()

# --- Index and retriever


class SparseRetrieverIndex(Config):
    index_path: Meta[Path]
    documents: Param[DocumentStore]

    index: xpmir_rust.index.SparseBuilderIndex
    ordered = False

    def initialize(self, in_memory: bool):
        self.index = xpmir_rust.index.SparseBuilderIndex.load(
            str(self.index_path.absolute()), in_memory
        )

    def retrieve(self, query: Dict[int, float], top_k: int) -> List[ScoredDocument]:
        results = []
        for sd in self.index.search_maxscore(query, top_k):
            results.append(
                ScoredDocument(
                    self.documents.document_int(sd.docid),
                    sd.score,
                )
            )
        return results

    async def aio_retrieve(
        self, query: Dict[int, float], top_k: int
    ) -> List[ScoredDocument]:
        results = []
        for sd in await self.index.aio_search_maxscore(query, top_k):
            results.append(
                ScoredDocument(
                    self.documents.document_int(sd.docid),
                    sd.score,
                )
            )
        return results
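
# Illustrative sketch (not part of the library): querying an initialized index
# directly with a hand-built sparse query vector. The term identifiers and
# weights below are made up for illustration; in practice they come from a
# sparse encoder.
def _example_query_index(index: SparseRetrieverIndex) -> List[ScoredDocument]:
    index.initialize(in_memory=True)
    # A sparse query is a {term id: weight} mapping
    query: Dict[int, float] = {4215: 1.25, 9801: 0.4}
    return index.retrieve(query, top_k=10)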
class SparseRetriever(Retriever, Generic[InputType]):
    index: Param[SparseRetrieverIndex]
    encoder: Param[TextEncoderBase[InputType, torch.Tensor]]

    topk: Param[int]
    """Number of documents to retrieve"""

    device: Meta[Device] = DEFAULT_DEVICE
    """The device used by the query encoder"""

    batcher: Meta[Batcher] = Batcher()
    """How to prepare batches of queries (when using retrieve_all)"""

    batchsize: Meta[int]
    """Size of batches (when using retrieve_all)"""

    in_memory: Meta[bool] = False
    """Whether the index should be fully loaded in memory (otherwise, uses
    virtual memory)"""

    def initialize(self):
        super().initialize()
        logging.info("Initializing the encoder")
        self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))
        self.encoder.to(self.device.value)
        logging.info("Initializing the index")
        self.index.initialize(self.in_memory)

    def retrieve_all(
        self, queries: Dict[str, InputType]
    ) -> Dict[str, List[ScoredDocument]]:
        """Retrieve documents for a set of queries, given as {id: text}"""

        async def aio_search_worker(progress, results: Dict, queue: asyncio.Queue):
            try:
                while True:
                    key, query, topk = await queue.get()
                    results[key] = await self.index.aio_retrieve(query, topk)
                    progress.update(1)
                    queue.task_done()
            except asyncio.CancelledError:
                # The worker was stopped
                pass
            except Exception:
                logging.exception("Error in worker thread")

        async def reducer(
            batch: List[Tuple[str, InputType]],
            queue: asyncio.Queue,
        ):
            for (key, _), vector in zip(
                batch,
                self.encoder([text for _, text in batch]).value.cpu().detach().numpy(),
            ):
                # Keep only the non-zero components of the query vector
                (ix,) = vector.nonzero()
                query = {int(i): float(v) for i, v in zip(ix, vector[ix])}
                logging.debug("Adding topic %s to the queue", key)
                await queue.put((key, query, self.topk))
                logging.debug("[done] Adding topic %s to the queue", key)

        async def aio_process():
            workers = []
            results = {}
            try:
                queue = asyncio.Queue(available_cpus())
                items = list(queries.items())
                with tqdm(
                    desc="Retrieve documents", total=len(items), unit="queries"
                ) as progress:
                    for _ in range(available_cpus()):
                        worker = asyncio.create_task(
                            aio_search_worker(progress, results, queue)
                        )
                        workers.append(worker)

                    self.encoder.eval()
                    batcher = self.batcher.initialize(self.batchsize)
                    with torch.no_grad():
                        for batch in batchiter(self.batchsize, items):
                            await batcher.aio_reduce(batch, reducer, queue)

                    await queue.join()
            finally:
                for worker in workers:
                    worker.cancel()

            return results

        return asyncio.run(aio_process())

    def retrieve(self, query: TopicRecord, top_k=None) -> List[ScoredDocument]:
        """Search with a document-at-a-time (DAAT) strategy

        :param top_k: Overrides the default top-k value when given
        """
        # Encode the query and keep only the non-zero components
        vector = self.encoder([query]).value[0].cpu().detach().numpy()
        (ix,) = vector.nonzero()

        # Build the sparse query as a {term position: weight} mapping
        query = {int(i): float(v) for i, v in zip(ix, vector[ix])}
        return self.index.retrieve(query, top_k or self.topk)
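
# Illustrative sketch (not part of the library): batch retrieval with a
# configured SparseRetriever. The `retriever` instance and the query texts are
# assumptions; in practice the retriever comes from an experimaestro
# configuration with its index, encoder, topk and batchsize set.
def _example_retrieve_all(retriever: SparseRetriever) -> None:
    retriever.initialize()
    # One sparse search per query, run concurrently over the CPU cores
    results = retriever.retrieve_all({"q1": "sparse retrieval", "q2": "neural IR"})
    for qid, scored in results.items():
        print(qid, scored[:3])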
@define(frozen=True)
class EncodedDocument:
    docid: int
    value: torch.Tensor


@define(frozen=True)
class DocumentRange:
    rank: int
    start: int
    end: int

    def __lt__(self, other: "DocumentRange"):
        return self.start < other.start


class DocumentIterator:
    """Iterates over batches of (document index, document) pairs"""

    def __init__(self, documents, max_docs, batch_size):
        self.documents = documents
        self.max_docs = max_docs
        self.batch_size = batch_size

    @cached_property
    def iterator(self):
        return batchiter(
            self.batch_size,
            zip(
                range(sys.maxsize if self.max_docs == 0 else self.max_docs),
                self.documents.iter_documents(),
            ),
        )

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iterator)
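
# Illustrative sketch (not part of the library): DocumentRange orders by its
# start document, so a heap of ranges always exposes the next contiguous block
# of documents to index, whichever encoder process produced it. This is what
# lets the indexer below merge the per-process streams in document order.
def _example_range_ordering() -> None:
    ranges = [
        DocumentRange(rank=1, start=64, end=127),
        DocumentRange(rank=0, start=0, end=63),
    ]
    heapq.heapify(ranges)
    assert heapq.heappop(ranges).start == 0  # the earliest range comes first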
class SparseRetrieverIndexBuilder(Task, Generic[InputType]):
    """Builds an index from a sparse representation

    Assumes that documents and queries have the same dimension, and that the
    score is computed through an inner product
    """

    documents: Param[DocumentStore]
    """Set of documents to index"""

    encoder: Param[TextEncoderBase[InputType, TextsRepresentationOutput]]
    """The encoder"""

    batcher: Meta[Batcher] = Batcher()
    """Batcher used when computing representations"""

    batch_size: Param[int]
    """Size of batches"""

    ordered_index: Param[bool]
    """Ordered index: if not ordered, use a DAAT strategy (WAND); otherwise,
    use fast top-k strategies"""

    device: Meta[Device] = DEFAULT_DEVICE
    """The device for building the index"""

    max_postings: Meta[int] = 16384
    """Maximum number of postings (per term) before flushing to disk"""

    index_path: Annotated[Path, pathgenerator("index")]

    in_memory: Meta[bool] = False
    """Whether the index should be fully loaded in memory (otherwise, uses
    virtual memory)"""

    version: Constant[int] = 3
    """Version 3 of the index"""

    max_docs: Param[int] = 0
    """Maximum number of indexed documents"""

    def task_outputs(self, dep):
        """Returns a sparse retriever index that can be used by a
        SparseRetriever to search efficiently for documents"""
        return dep(
            SparseRetrieverIndex(index_path=self.index_path, documents=self.documents)
        )

    def execute(self):
        if mp.get_start_method(allow_none=True) is None:
            mp.set_start_method("spawn")

        max_docs = (
            self.documents.documentcount
            if self.max_docs == 0
            else min(self.max_docs, self.documents.documentcount)
        )

        iter_batches = MultiprocessIterator(
            DocumentIterator(self.documents, max_docs, self.batch_size)
        ).detach()

        self.encoder.initialize(ModuleInitMode.DEFAULT.to_options(None))

        closed = mp.Event()
        queues = [
            StoppableQueue(2 * self.batch_size + 1, closed)
            for _ in range(self.device.n_processes)
        ]

        # Clean up the index before starting
        # ENHANCE: recover index build when possible
        from shutil import rmtree

        if self.index_path.is_dir():
            rmtree(self.index_path)
        self.index_path.mkdir(parents=True)

        # Start the indexing thread
        index_thread = threading.Thread(
            target=self.index,
            name="index",
            args=(queues, max_docs),
        )
        index_thread.start()

        # Run the encoder processes and wait for them to end
        logger.info(f"Starting to index {max_docs} documents")

        try:
            self.device.execute(
                self.device_execute,
                iter_batches,
                self.encoder,
                self.batcher,
                self.batch_size,
                queues,
            )
        except Exception:
            logger.exception("Got an exception while running encoders")
        finally:
            logger.info("Waiting for the index process to stop")
            index_thread.join()

    def index(
        self,
        queues: List[StoppableQueue[Union[DocumentRange, EncodedDocument]]],
        max_docs: int,
    ):
        """Index encoded documents

        :param queues: Queues used to receive document ranges and encoded
            documents from each encoder process
        """
        with tqdm(
            total=max_docs,
            unit="documents",
            desc="Building the index",
        ) as pb:
            try:
                # Get the first range from each queue
                logger.info(
                    "Starting the indexing process (%d queues) in %s",
                    len(queues),
                    self.index_path,
                )
                indexer = xpmir_rust.index.SparseIndexer(str(self.index_path))
                heap = [queue.get() for queue in queues]
                heapq.heapify(heap)

                # Loop over the ranges, always processing the one that starts
                # with the lowest document ID
                while heap:
                    # Process current range
                    current = heap[0]
                    logger.debug("Handling range: %s", current)
                    for docid in range(current.start, current.end + 1):
                        encoded = queues[current.rank].get()
                        assert (
                            encoded.docid == docid
                        ), f"Mismatch in document IDs ({encoded.docid} vs {docid})"

                        (nonzero_ix,) = encoded.value.nonzero()
                        indexer.add(
                            docid,
                            nonzero_ix.astype(np.uint64),
                            encoded.value[nonzero_ix],
                        )
                        pb.update()

                    # Get the next range from the same encoder (None when that
                    # encoder is done)
                    next_range = queues[current.rank].get()  # type: DocumentRange
                    if next_range:
                        logger.debug("Got next range: %s", next_range)
                        heapq.heappushpop(heap, next_range)
                    else:
                        logger.info("Iterator %d is over", current.rank)
                        heapq.heappop(heap)

                logger.info("Building the index")
                indexer.build(self.in_memory)
            except Empty:
                logger.warning("One encoder got a problem... stopping")
                raise
            except Exception:
                # Close all the queues
                logger.exception(
                    "Got an exception in the indexing process, closing the queues"
                )
                queues[0].stop()
                raise

    @staticmethod
    def device_execute(
        device_information: DeviceInformation,
        iter_batches: Iterator[List[Tuple[int, DocumentRecord]]],
        encoder,
        batcher,
        batch_size,
        queues: List[StoppableQueue],
    ):
        # Bind the queue first so the except clause below can always reach it
        queue = queues[device_information.rank]
        try:
            # Encode all documents
            logger.info(
                "Load the encoder and "
                f"transfer to the target device {device_information.device}"
            )
            encoder = encoder.to(device_information.device).eval()
            batcher = batcher.initialize(batch_size)

            # Index
            with torch.no_grad():
                for batch in iter_batches:
                    # Signal the output range
                    document_range = DocumentRange(
                        device_information.rank, batch[0][0], batch[-1][0]
                    )
                    logger.debug(
                        "Starting range [%d] %s",
                        device_information.rank,
                        document_range,
                    )
                    queue.put(document_range)

                    # Output the encoded documents
                    batcher.process(
                        batch,
                        SparseRetrieverIndexBuilder.encode_documents,
                        encoder,
                        queue,
                    )

            # Signal the end of this encoder's stream
            logger.info("Closing queue %d", device_information.rank)
            queue.put(None)
        except Exception:
            queue.stop()
            raise

    @staticmethod
    def encode_documents(
        batch: List[Tuple[int, DocumentRecord]],
        encoder: TextEncoderBase[InputType, TextsRepresentationOutput],
        queue: "mp.Queue[EncodedDocument]",
    ):
        # Assumes dense output vectors for now
        vectors = (
            encoder([d[TextItem].text for _, d in batch]).value.cpu().numpy()
        )  # bs * vocab
        for vector, (docid, _) in zip(vectors, batch):
            queue.put(EncodedDocument(docid, vector))
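
# Illustrative sketch (not part of the library): wiring the builder to a
# retriever. `my_documents`, `my_encoder` and `builder.submit()` are
# assumptions about the surrounding experimaestro setup (tasks are normally
# submitted within an experiment context); `task_outputs` above shows that
# running the task yields the SparseRetrieverIndex that a SparseRetriever
# consumes.
def _example_pipeline(my_documents: DocumentStore, my_encoder):
    builder = SparseRetrieverIndexBuilder(
        documents=my_documents,
        encoder=my_encoder,
        batch_size=64,
        ordered_index=False,
    )
    index = builder.submit()  # assumed experimaestro call
    # The query encoder may differ from the document encoder
    return SparseRetriever(
        index=index, encoder=my_encoder, topk=100, batchsize=16
    )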