Source code for xpmir.interfaces.anserini

import asyncio
from functools import cached_property
from pathlib import Path
import contextlib
import json
import logging
import os
import re
import subprocess
import sys
from typing import Callable, List, Optional
from experimaestro import tqdm as xpmtqdm, Task, Meta

from datamaestro_text.data.ir import (
    DocumentStore,
    TextItem,
    IDItem,
    TopicRecord,
    create_record,
    InternalIDItem,
)
import datamaestro_text.data.ir.csv as ir_csv
from datamaestro_text.data.ir.trec import (
    Documents,
    Topics,
    TipsterCollection,
    TrecTopics,
)
from experimaestro import Param, param, pathoption, progress
from tqdm import tqdm
from xpmir.index.anserini import Index
from xpmir.rankers import Retriever, ScoredDocument, document_cache
from xpmir.rankers.standard import BM25, QLDirichlet, Model
from xpmir.utils.utils import Handler, StreamGenerator, needs_java

pyserini_java = needs_java(11)


def anserini_classpath():
    import pyserini

    base = Path(pyserini.__file__).parent

    paths = list(base.rglob("anserini-*-fatjar.jar"))
    if not paths:
        raise Exception(f"No matching jar file found in {base}")

    latest = max(paths, key=os.path.getctime)
    return latest


def javacommand():
    """Returns the start of the java command including the Anserini class path"""
    from jnius_config import get_classpath

    command = ["{}/bin/java".format(os.environ["JAVA_HOME"]), "-cp"]
    command.append(":".join(get_classpath() + [str(anserini_classpath())]))

    return command
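
# Illustrative sketch (not part of the original module): javacommand()
# yields the prefix ["$JAVA_HOME/bin/java", "-cp", "<classpath>"], where the
# classpath combines the pyjnius classpath with the Anserini fatjar shipped
# by pyserini. Callers append a main class and its arguments, e.g.:
#
#     command = javacommand()
#     command.append("io.anserini.index.IndexCollection")
#     command.extend(["-index", "/tmp/index", "-threads", "8"])
#     subprocess.run([str(s) for s in command])
#
# This assumes JAVA_HOME is set and that pyserini ships an
# anserini-*-fatjar.jar (as checked by anserini_classpath above).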


[docs]@pyserini_java @param("documents", type=Documents) @param("threads", default=8, ignored=True) @pathoption("path", "index") class IndexCollection(Index, Task): """An [Anserini](https://github.com/castorini/anserini) index""" CLASSPATH = "io.anserini.index.IndexCollection" documents: Param[Documents] """The documents to index""" thread: Meta[int] = 8 """Number of threads when indexing""" id: Param[str] = "" """Use an empty ID since identifier is determined by documents""" def execute(self): command = javacommand() command.append(IndexCollection.CLASSPATH) command.extend(["-index", self.path, "-threads", self.threads]) chandler = Handler() @chandler() def trec_collection(documents: TipsterCollection): return contextlib.nullcontext("void"), [ "-collection", "TrecCollection", "-input", documents.path, ] @chandler() def csv_collection(documents: ir_csv.Documents): def _generator(out): counter = 0 size = os.path.getsize(documents.path) with documents.path.open("rt", encoding="utf-8") as fp, tqdm( total=size, unit="B", unit_scale=True ) as pb: for ix, line in enumerate(fp): # Update progress ll = len(line) pb.update(ll) counter += ll progress(counter / size) # Generate document docid, text = line.strip().split(documents.separator, 1) json.dump({"id": docid, "contents": text}, out) out.write("\n") generator = StreamGenerator(_generator, mode="wt") return generator, [ "-collection", "JsonCollection", "-input", generator.filepath.parent, ] @chandler.default() def generic_collection(documents: Documents): """Generic collection handler, supposes that we can iterate documents""" def _generator(out): logging.info( "Starting the iterator over the documents (%d documents)", documents.documentcount, ) for document in xpmtqdm( documents.iter(), unit="documents", total=documents.documentcount ): # Generate document json.dump( { "id": document[IDItem].id, "contents": document[TextItem].text, }, out, ) out.write("\n") generator = StreamGenerator(_generator, mode="wt") return generator, [ "-collection", "JsonCollection", "-input", generator.filepath.parent, ] generator, args = chandler[self.documents] command.extend(args) if self.storePositions: command.append("-storePositions") if self.storeDocvectors: command.append("-storeDocvectors") if self.storeRaw: command.append("-storeRawDocs") if self.storeContents: command.append("-storeContents") # Index and keep track of progress through regular expressions RE_FILES = re.compile( rb""".*index\.IndexCollection \(IndexCollection.java:\d+\) - ([\d,]+) files found""" # noqa: E501 ) RE_FILE = re.compile( rb""".*index\.IndexCollection\$LocalIndexerThread \(IndexCollection.java:\d+\).* docs added.""" # noqa: E501 ) RE_COMPLETE = re.compile( rb""".*IndexCollection\.java.*Indexing Complete.*documents indexed""" ) async def run(command): with generator as _: logging.info("Running with command %s", command) proc = await asyncio.create_subprocess_exec( *command, stderr=None, stdout=asyncio.subprocess.PIPE ) nfiles = -1 indexedfiles = 0 complete = False while True: data = await proc.stdout.readline() if not data: break m = RE_FILES.match(data) complete = complete or (RE_COMPLETE.match(data) is not None) if m: nfiles = int(m.group(1).decode("utf-8").replace(",", "")) elif RE_FILE.match(data): indexedfiles += 1 progress(indexedfiles / nfiles) else: sys.stdout.write( data.decode("utf-8"), ) await proc.wait() if proc.returncode == 0 and not complete: logging.error( "Did not see the indexing complete log message" " -- exiting with error" ) sys.exit(1) sys.exit(proc.returncode) 
asyncio.run(run([str(s) for s in command]))
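
# Illustrative sketch (an assumption, not from the original module): building
# an index by submitting the task inside an experimaestro experiment.
# `my_documents` and the workspace path are placeholders; storeContents and
# threads are the configuration options used above.
#
#     from experimaestro import experiment
#
#     with experiment("/tmp/workdir", "anserini-index"):
#         index = IndexCollection(
#             documents=my_documents, storeContents=True, threads=8
#         ).submit()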
[docs]@pyserini_java @param("index", Index) @param("topics", Topics) @param("model", Model) @pathoption("path", "results.trec") class SearchCollection(Task): def execute(self): command = javacommand() command.append("io.anserini.search.SearchCollection") command.extend(("-index", self.index.path, "-output", self.path)) # Topics topicshandler = Handler() @topicshandler() def trectopics(topics: TrecTopics): return ("-topicreader", "Trec", "-topics", topics.path) @topicshandler() def tsvtopics(topics: ir_csv.Topics): return ("-topicreader", "TsvInt", "-topics", topics.path) command.extend(topicshandler[self.topics]) # Model modelhandler = Handler() @modelhandler() def handle(bm25: BM25): return ("-bm25", "-bm25.k1", str(bm25.k1), "-bm25.b", str(bm25.b)) command.extend(modelhandler[self.model]) # Start logging.info("Starting command %s", command) p = subprocess.run(command) sys.exit(p.returncode)


@pyserini_java
class AnseriniRetriever(Retriever):
    """An Anserini-based retriever

    Attributes:
        index: The Anserini index
        model: the model used to search (only BM25 and query likelihood
            with Dirichlet smoothing are supported so far)
        k: Number of results to retrieve
    """

    index: Param[Index]
    model: Param[Model]
    k: Param[int] = 1500

    @cached_property
    def searcher(self):
        from pyserini.search.lucene import LuceneSearcher

        searcher = LuceneSearcher(str(self.index.path))

        modelhandler = Handler()

        @modelhandler()
        def handle_bm25(bm25: BM25):
            searcher.set_bm25(bm25.k1, bm25.b)

        @modelhandler()
        def handle_qld(qld: QLDirichlet):
            searcher.set_qld(qld.mu)

        # Configure the searcher according to the model
        modelhandler[self.model]
        return searcher

    def _get_store(self) -> Optional[Index]:
        """Returns the associated index (if any)"""
        if self.index.storeContents:
            return self.index
        return None

    def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
        # see
        # https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/search/SimpleSearcher.java
        hits = self.searcher.search(record[TextItem].text, k=self.k)
        store = self.get_store()

        # Batch retrieve documents when a document store is available
        if store is not None:
            return [
                ScoredDocument(doc, hit.score)
                for hit, doc in zip(
                    hits, store.documents_ext([hit.docid for hit in hits])
                )
            ]

        return [
            ScoredDocument(
                create_record(
                    InternalIDItem(hit.lucene_docid), id=hit.docid, text=hit.contents
                ),
                hit.score,
            )
            for hit in hits
        ]
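
# Illustrative sketch (an assumption): interactive retrieval over an already
# built index. create_record builds an ad hoc topic record from free text,
# mirroring the records constructed in retrieve() above; .instance() turns
# the experimaestro configuration into a usable object.
#
#     anserini = AnseriniRetriever(index=index, model=BM25(), k=100).instance()
#     for scored in anserini.retrieve(create_record(text="neural ranking")):
#         print(scored.document[IDItem].id, scored.score)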


@document_cache
def index_builder(
    documents: Documents, *, launcher=None, **index_params
) -> IndexCollection:
    return IndexCollection(documents=documents, **index_params).submit(
        launcher=launcher
    )


def retriever(
    index_builder: Callable[[Documents], IndexCollection],
    documents: Documents,
    *,
    k: Optional[int] = None,
    model: Optional[Model] = None,
    store: Optional[DocumentStore] = None,
):
    """Function to construct an Anserini retriever"""
    index = index_builder(documents)
    return AnseriniRetriever(
        index=index, k=k or AnseriniRetriever.k, model=model, store=store
    )
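
# Illustrative sketch (an assumption): combining the two helpers above.
# document_cache is assumed to memoize the builder per document collection,
# so several retrievers over the same documents reuse one submitted index.
# `my_documents` is a placeholder.
#
#     bm25_retriever = retriever(index_builder, my_documents, k=1000, model=BM25())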