Source code for xpmir.letor.samplers

import json
from pathlib import Path
from typing import Iterator, List, Tuple, Dict, Any
import numpy as np
from datamaestro.record import Record
from datamaestro_text.data.ir import (
    Adhoc,
    TrainingTriplets,
    PairwiseSampleDataset,
    PairwiseSample,
    ScoredItem,
    DocumentStore,
    TextItem,
    SimpleTextItem,
    create_record,
    DocumentRecord,
    IDItem,
)
from experimaestro import Param, tqdm, Task, Annotated, pathgenerator
from experimaestro.annotations import cache
from experimaestro.compat import cached_property
import torch
from xpmir.rankers import ScoredDocument
from xpmir.datasets.adapters import TextStore
from xpmir.letor.records import (
    BatchwiseRecords,
    PairwiseRecords,
    ProductRecords,
    PairwiseRecord,
    PointwiseRecord,
    TopicRecord,
)
from xpmir.rankers import Retriever, Scorer
from xpmir.learning import Sampler
from xpmir.utils.utils import easylog
from xpmir.utils.iter import (
    RandomSerializableIterator,
    SerializableIterator,
    SerializableIteratorAdapter,
    SkippingIterator,
    RandomStateSerializableAdaptor,
    InfiniteSkippingIterator,
    iterable_of,
)
from datamaestro_text.interfaces.plaintext import read_tsv

# Module-level logger obtained through the project's `easylog` helper
logger = easylog()


# --- Base classes for samplers


class PointwiseSampler(Sampler):
    """Base class for samplers that iterate over pointwise records"""

    def pointwise_iter(self) -> SerializableIterator[PointwiseRecord, Any]:
        """Iterable over pointwise records

        Raises:
            NotImplementedError: always; subclasses have to override this
                method
        """
        # The message names the method to override (it used to mention the
        # record type, which is not something a subclass can "implement")
        raise NotImplementedError(f"{self.__class__} should implement pointwise_iter")
class PairwiseSampler(Sampler):
    """Abstract class for pairwise samplers which output a set of (query,
    positive, negative) triples"""

    def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
        """Iterate over pairwise records, one triple at a time

        Raises:
            NotImplementedError: always; subclasses have to override this
                method
        """
        raise NotImplementedError(f"{self.__class__} should implement pairwise_iter")

    def pairwise_batch_iter(self, size) -> SerializableIterator[PairwiseRecords, Any]:
        """Batchwise iterator

        Can be subclassed by some classes to be more efficient

        Args:
            size: Maximum number of records put in each batch

        Returns:
            An iterator over batches built from :meth:`pairwise_iter`; its
            state is the state of the underlying pairwise iterator
        """

        class BatchIterator(SerializableIterator):
            def __init__(self, sampler: PairwiseSampler):
                self.iter = sampler.pairwise_iter()

            def state_dict(self):
                return self.iter.state_dict()

            def load_state_dict(self, state):
                self.iter.load_state_dict(state)

            def __next__(self):
                batch = PairwiseRecords()
                count = 0
                # range() comes first in zip so that no record is consumed
                # (and lost) once the batch is full
                for _, record in zip(range(size), self.iter):
                    batch.add(record)
                    count += 1

                # When the underlying iterator is exhausted, stop instead of
                # returning empty batches forever
                if count == 0 and size > 0:
                    raise StopIteration
                return batch

        return BatchIterator(self)
class BatchwiseSampler(Sampler):
    """Base class for batchwise samplers, that provide for each question a list
    of documents"""

    def batchwise_iter(
        self, batch_size: int
    ) -> SerializableIterator[BatchwiseRecords, Any]:
        """Iterate over batches of size (# of queries) batch_size

        Args:
            batch_size: Number of queries per batch

        Raises:
            NotImplementedError: always; subclasses have to override this
                method
        """
        # The message names the method that subclasses actually override
        raise NotImplementedError(f"{self.__class__} should implement batchwise_iter")
# --- Real instances
class ModelBasedSampler(Sampler):
    """Base class for retriever-based sampler"""

    dataset: Param[Adhoc]
    """The IR adhoc dataset"""

    retriever: Param[Retriever]
    """A retriever to sample negative documents"""

    _store: DocumentStore

    def __validate__(self) -> None:
        super().__validate__()

        assert self.retriever.get_store() is not None or isinstance(
            self.dataset.documents, DocumentStore
        ), "The retriever has no associated document store (to get document text)"

    def initialize(self, random):
        """Selects the document store: the retriever's one if it has one,
        otherwise the documents of the dataset"""
        super().initialize(random)
        self._store = self.retriever.get_store() or self.dataset.documents
        assert self._store is not None, "No document store found"

    def document(self, doc_id):
        """Returns the document textual content"""
        return self._store.document_ext(doc_id)

    def document_text(self, doc_id):
        """Returns the `text` attribute of the document"""
        return self.document(doc_id).text

    @cache("run")
    def _itertopics(
        self, runpath: Path
    ) -> Iterator[
        Tuple[str, List[Tuple[str, int, float]], List[Tuple[str, int, float]]]
    ]:
        """Iterates over topics, returning retrieved positives and negatives
        documents

        When `runpath` does not exist, documents are retrieved for each topic
        and written to a temporary file that is moved to `runpath` once all the
        topics have been processed. Otherwise, the records are read back from
        `runpath`.

        Each line of the file is `title<TAB>docno<TAB>score<TAB>rel`; the title
        is only written on the first line of a topic and is empty on the
        following ones.

        Yields:
            Tuples (query text, positives, negatives) where positives and
            negatives are lists of (document ID, relevance, score)
        """
        self.logger.info("Reading topics and retrieving documents")

        if not runpath.is_file():
            tmprunpath = runpath.with_suffix(".tmp")

            with tmprunpath.open("wt") as fp:
                # Read the assessments
                self.logger.info("Reading assessments")
                assessments: Dict[str, Dict[str, float]] = {}
                for qrels in self.dataset.assessments.iter():
                    doc2rel = {}
                    assessments[qrels.topic_id] = doc2rel
                    for qrel in qrels.assessments:
                        doc2rel[qrel.doc_id] = qrel.rel
                self.logger.info("Read assessments for %d topics", len(assessments))

                self.logger.info("Retrieving documents for each topic")
                queries = list(self.dataset.topics.iter())

                # Retrieve documents
                skipped = 0
                for query in tqdm(queries):
                    qassessments = assessments.get(query[IDItem].id, None)
                    if not qassessments:
                        skipped += 1
                        self.logger.warning(
                            "Skipping topic %s (no assessments)", query[IDItem].id
                        )
                        continue

                    # Lines are buffered and only written when the topic is
                    # kept: a skipped topic must not leave (partial) records
                    # in the cache, otherwise reading the cache would return
                    # topics that were skipped here
                    lines = []

                    # Collect all the positive documents
                    positives = []
                    for docno, rel in qassessments.items():
                        if rel > 0:
                            lines.append(
                                f"{query.text if not positives else ''}"
                                f"\t{docno}\t0.\t{rel}\n"
                            )
                            positives.append((docno, rel, 0))

                    if not positives:
                        self.logger.warning(
                            "Skipping topic %s (no relevant documents)",
                            query[IDItem].id,
                        )
                        skipped += 1
                        continue

                    scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(
                        query.text
                    )

                    negatives = []
                    for sd in scoreddocuments:
                        # Get the assessment (assumes not relevant)
                        docno = sd.document[IDItem].id
                        rel = qassessments.get(docno, 0)
                        if rel > 0:
                            continue

                        negatives.append((docno, rel, sd.score))
                        lines.append(f"\t{docno}\t{sd.score}\t{rel}\n")

                    if not negatives:
                        self.logger.warning(
                            "Skipping topic %s (no negatives documents)",
                            query[IDItem].id,
                        )
                        skipped += 1
                        continue

                    assert len(positives) > 0 and len(negatives) > 0
                    fp.writelines(lines)
                    yield query.text, positives, negatives

                self.logger.info(
                    "Processed %d topics (%d skipped)", len(queries), skipped
                )

            # Finally, move the cache file in place (once it is closed)
            tmprunpath.rename(runpath)
        else:
            # Read from cache
            self.logger.info("Reading records from file %s", runpath)
            with runpath.open("rt") as fp:
                positives = []
                negatives = []
                oldtitle = ""

                for line in fp:
                    title, docno, score, rel = line.rstrip().split("\t")
                    if title:
                        # A non-empty title starts a new topic: output the
                        # previous one (if complete)
                        if oldtitle and positives and negatives:
                            yield oldtitle, positives, negatives
                        positives = []
                        negatives = []
                    else:
                        title = oldtitle

                    rel = int(rel)
                    (positives if rel > 0 else negatives).append(
                        (docno, rel, float(score))
                    )
                    oldtitle = title

                # Last topic: nothing to output for an empty file or for an
                # incomplete topic
                if oldtitle and positives and negatives:
                    yield oldtitle, positives, negatives
class PointwiseModelBasedSampler(PointwiseSampler, ModelBasedSampler):
    """A pointwise sampler based on a retrieval model"""

    relevant_ratio: Param[float] = 0.5
    """The target relevance ratio"""

    def initialize(self, random):
        """Initializes the retriever and loads the positive/negative records"""
        super().initialize(random)

        self.retriever.initialize()
        self.pos_records, self.neg_records = self.readrecords()
        self.logger.info(
            "Loaded %d/%d pos/neg records", len(self.pos_records), len(self.neg_records)
        )

    def prepare(self, sample: Tuple[str, str, float, int]):
        """Builds a pointwise record

        Args:
            sample: A (title, document ID, score, relevance) tuple, as built by
                :meth:`readrecords`
        """
        # Looks up the document text once (it is needed for the record)
        document = self.document_text(sample[1])
        assert document is not None
        return PointwiseRecord(
            topic=TopicRecord(SimpleTextItem(sample[0])),
            document=DocumentRecord(document=document),
            relevance=sample[3],
        )

    def readrecords(self, runpath=None):
        """Reads all the topics and returns two lists (positives, negatives)
        of (title, document ID, score, relevance) tuples

        Args:
            runpath: Not used
        """
        pos_records, neg_records = [], []
        for title, positives, negatives in self._itertopics():
            for docno, rel, score in positives:
                pos_records.append((title, docno, score, rel))
            for docno, rel, score in negatives:
                neg_records.append((title, docno, score, rel))

        return pos_records, neg_records

    def record_iter(self) -> Iterator[PointwiseRecord]:
        """Infinite iterator over records: a positive record is drawn with
        probability `relevant_ratio`, a negative one otherwise"""
        npos = len(self.pos_records)
        nneg = len(self.neg_records)
        while True:
            if self.random.random() < self.relevant_ratio:
                yield self.prepare(self.pos_records[self.random.randint(0, npos)])
            else:
                yield self.prepare(self.neg_records[self.random.randint(0, nneg)])

    def pointwise_iter(self) -> SerializableIterator[PointwiseRecord, Any]:
        """Same sampling as :meth:`record_iter`, wrapped in a
        `RandomSerializableIterator`"""
        npos = len(self.pos_records)
        nneg = len(self.neg_records)

        def generate(random):
            # Draws from the random state handed over by the iterator (and
            # not from self.random), as the other samplers of this module do
            while True:
                if random.random() < self.relevant_ratio:
                    yield self.prepare(self.pos_records[random.randint(0, npos)])
                else:
                    yield self.prepare(self.neg_records[random.randint(0, nneg)])

        return RandomSerializableIterator(self.random, generate)
class PairwiseModelBasedSampler(PairwiseSampler, ModelBasedSampler):
    """A pairwise sampler based on a retrieval model"""

    def initialize(self, random: np.random.RandomState):
        """Initializes the retriever and loads the topics with their
        positive and negative documents"""
        super().initialize(random)

        self.retriever.initialize()
        self.topics: List[Tuple[str, List, List]] = self._readrecords()

    def _readrecords(self):
        """Returns the list of (title, positives, negatives) tuples"""
        return [
            (title, positives, negatives)
            for title, positives, negatives in self._itertopics()
        ]

    def sample(self, samples: List[Tuple[str, int, float]]):
        """Draws a document among the samples, until one whose text is not
        None is found

        Returns:
            The document, with its score added as a `ScoredItem`
        """
        while True:
            index = self.random.randint(0, len(samples))
            docid, rel, score = samples[index]
            document = self.document(docid).add(ScoredItem(score))
            if document[TextItem].text is not None:
                return document

    def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
        """Infinite iterator: draws a topic, then one positive and one
        negative document for this topic"""

        def generate(random):
            while True:
                index = random.randint(0, len(self.topics))
                title, positives, negatives = self.topics[index]

                query = create_record(text=title)
                positive = self.sample(positives)
                negative = self.sample(negatives)
                yield PairwiseRecord(query, positive, negative)

        return RandomSerializableIterator(self.random, generate)
class PairwiseInBatchNegativesSampler(BatchwiseSampler):
    """An in-batch negative sampler constructured from a pairwise one"""

    sampler: Param[PairwiseSampler]
    """The base pairwise sampler"""

    def initialize(self, random):
        super().initialize(random)
        self.sampler.initialize(random)

    def batchwise_iter(
        self, batch_size: int
    ) -> SerializableIterator[BatchwiseRecords, Any]:
        """Iterates over batches made of `batch_size` pairwise records

        In each batch, the topics are followed by all the positive documents,
        then by all the negative ones.
        """

        def generate(pair_iter):
            # Relevance matrix (query x document), computed once: identity
            # for the positive block, zeros for the negative block
            positive_block = torch.eye(batch_size)
            negative_block = torch.zeros(batch_size, batch_size)
            relevances = torch.cat((positive_block, negative_block), 1)

            while True:
                batch = ProductRecords()
                positive_docs, negative_docs = [], []

                for _, pair in zip(range(batch_size), pair_iter):
                    batch.add_topics(pair.query)
                    positive_docs.append(pair.positive)
                    negative_docs.append(pair.negative)

                batch.add_documents(*positive_docs)
                batch.add_documents(*negative_docs)
                batch.set_relevances(relevances)
                yield batch

        return SerializableIteratorAdapter(self.sampler.pairwise_iter(), generate)
class TripletBasedSampler(PairwiseSampler):
    """Sampler based on a triplet source"""

    source: Param[TrainingTriplets]
    """Triplets"""

    def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
        """Wraps each (topic, positive, negative) triplet of the source into
        a pairwise record"""

        def as_records(triplets):
            for topic, positive, negative in triplets:
                yield PairwiseRecord(topic, positive, negative)

        # The source iterator is created right away, the records lazily
        return SkippingIterator(as_records(iter(self.source.iter())))
class PairwiseDatasetTripletBasedSampler(PairwiseSampler):
    """Sampler based on a dataset where each query is associated
    with (1) a set of relevant documents (2) negative documents,
    where each negative is sampled with a specific algorithm
    """

    documents: Param[DocumentStore]
    """The document store"""

    dataset: Param[PairwiseSampleDataset]
    """The dataset which contains the generated queries with its positives and
    negatives"""

    negative_algo: Param[str] = "random"
    """The algo to sample the negatives, default value is random"""

    def pairwise_iter(self) -> SkippingIterator[PairwiseRecord]:
        """Iterates over the samples of the dataset, drawing one positive,
        one topic and one negative for each of them"""

        class _Iterator(
            RandomStateSerializableAdaptor[SerializableIterator[PairwiseSample]]
        ):
            def __init__(
                self,
                iterator: SerializableIterator[PairwiseSample],
                random: np.random.RandomState,
                negative_algo: str,
                documents: DocumentStore,
            ):
                super().__init__(iterator)

                self.random = random
                self.negative_algo = negative_algo
                self.documents = documents

            def _random_docid(self):
                """Draws the (external) ID of a document of the store"""
                internal_id = self.random.randint(0, self.documents.documentcount)
                return self.documents.docid_internal2external(internal_id)

            def __next__(self):
                sample = next(self.iterator)  # type: PairwiseSample

                available_algos = sample.negatives.keys()
                use_random = self.negative_algo == "random"
                assert use_random or self.negative_algo in available_algos

                # The positive is drawn before the topic
                pos_index = self.random.randint(len(sample.positives))
                pos = sample.positives[pos_index]
                qry_index = self.random.randint(len(sample.topics))
                qry = sample.topics[qry_index]

                if use_random:
                    # Draws documents until one differs from the positive
                    neg_id = self._random_docid()
                    while neg_id == pos.id:
                        neg_id = self._random_docid()
                    neg = create_record(id=neg_id)
                else:
                    candidates = sample.negatives[self.negative_algo]
                    neg = candidates[self.random.randint(len(candidates))]

                return PairwiseRecord(
                    qry.as_record(), DocumentRecord(pos), DocumentRecord(neg)
                )

        source = InfiniteSkippingIterator(iterable_of(lambda: self.dataset.iter()))

        return _Iterator(source, self.random, self.negative_algo, self.documents)
# --- Dataloader
class TSVPairwiseSampleDataset(PairwiseSampleDataset):
    """Read the pairwise sample dataset from a tsv file"""

    hard_negative_samples_path: Param[Path]
    """The path which stores the existing ids"""

    def iter(self) -> Iterator[PairwiseSample]:
        """return a iterator over a set of pairwise_samples

        For each row, the first column is the (single) topic, the third one
        contains the space-separated positives and the fifth one the
        space-separated negatives.
        """
        for row in read_tsv(self.hard_negative_samples_path):
            topic_column = row[0]
            positive_column = row[2]
            positive_ids = positive_column.split(" ")
            negative_column = row[4]
            negative_ids = negative_column.split(" ")
            # The algorithm used to get the negatives is not stored (yet)
            yield PairwiseSample([topic_column], positive_ids, negative_ids)
class JSONLPairwiseSampleDataset(PairwiseSampleDataset):
    """Transform a jsonl file to a pairwise dataset

    General format:

    {
        queries: [str, str],
        pos_ids: [id, id],
        neg_ids: {
            "bm25": [id, id],
            "random": [id, id]
        }
    }
    """

    path: Param[Path]
    """The path to the Jsonl file"""

    @cached_property
    def count(self):
        """Number of lines of the file"""
        line_count = 0
        with self.path.open("r") as fp:
            for line_count, _ in enumerate(fp, start=1):
                pass
        return line_count

    def iter(self) -> Iterator[PairwiseSample]:
        """Yields one sample per line: queries become text records, while
        positive and negative IDs become ID records (negatives are grouped by
        algorithm)"""
        with self.path.open("r") as fp:
            for line in fp:
                entry = json.loads(line)

                topics = [create_record(text=text) for text in entry["queries"]]
                positives = [create_record(id=doc_id) for doc_id in entry["pos_ids"]]
                negatives = {
                    algo: [
                        create_record(id=doc_id) for doc_id in entry["neg_ids"][algo]
                    ]
                    for algo in entry["neg_ids"].keys()
                }

                yield PairwiseSample(
                    topics=topics, positives=positives, negatives=negatives
                )
# A class for loading the data; it should be moved somewhere else.
class PairwiseSamplerFromTSV(PairwiseSampler):
    """Pairwise sampler reading (query, positive, negative) triplets with
    their scores from a TSV file"""

    pairwise_samples_path: Param[Path]
    """The path which stores the existing triplets"""

    def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
        """Iterates over the rows of the TSV file

        Each row has five columns: query ID, positive document ID, positive
        score, negative document ID and negative score.
        """

        def records() -> Iterator[PairwiseRecord]:
            for triplet in read_tsv(self.pairwise_samples_path):
                q_id, pos_id, pos_score, neg_id, neg_score = triplet
                # NOTE(review): scores are passed on as read from the file
                # (presumably strings) -- confirm what ScoredItem expects
                yield PairwiseRecord(
                    Record(IDItem(q_id)),
                    Record(IDItem(pos_id), ScoredItem(pos_score)),
                    Record(IDItem(neg_id), ScoredItem(neg_score)),
                )

        # The iterator (and not the generator function) has to be wrapped,
        # as for the other samplers based on SkippingIterator
        return SkippingIterator(records())
# --- Tasks for hard negatives
class ModelBasedHardNegativeSampler(Task, Sampler):
    """Retriever-based hard negative sampler"""

    dataset: Param[Adhoc]
    """The dataset which contains the topics and assessments"""

    retriever: Param[Retriever]
    """The retriever to score of the document wrt the query"""

    hard_negative_samples: Annotated[Path, pathgenerator("hard_negatives.tsv")]
    """Path to store the generated hard negatives"""

    def task_outputs(self, dep) -> PairwiseSampleDataset:
        """return a iterator of PairwiseSample"""
        return dep(
            TSVPairwiseSampleDataset(
                ids=self.dataset.id,
                hard_negative_samples_path=self.hard_negative_samples,
            )
        )

    def execute(self):
        """Retrieve over the dataset and select the positive and negative
        according to the relevance score and their rank

        One line is written for each topic that has at least one positive and
        one negative retrieved document:
        `qid<TAB>positives:<TAB>ids<TAB>negatives:<TAB>ids`, where IDs are
        separated by a space.
        """
        # NOTE(review): this class reads `qid`/`docid` attributes whereas
        # ModelBasedSampler uses `topic_id`/`doc_id` and `[IDItem].id` --
        # confirm which API is the current one
        self.logger.info("Reading topics and retrieving documents")

        # create the file
        self.hard_negative_samples.parent.mkdir(parents=True, exist_ok=True)

        # Read the assessments
        self.logger.info("Reading assessments")
        assessments = {}  # type: Dict[str, Dict[str, float]]
        for qrels in self.dataset.assessments.iter():
            doc2rel = {}
            assessments[qrels.qid] = doc2rel
            for qrel in qrels.assessments:
                doc2rel[qrel.docid] = qrel.rel
        self.logger.info("Assessment loaded")
        self.logger.info("Read assessments for %d topics", len(assessments))

        self.logger.info("Retrieving documents for each topic")
        queries = list(self.dataset.topics.iter())

        with self.hard_negative_samples.open("wt") as fp:
            # Retrieve documents

            # count the number of queries been skipped because of no
            # assessments available
            skipped = 0
            for query in tqdm(queries):
                qassessments = assessments.get(query.qid, None)
                if not qassessments:
                    skipped += 1
                    self.logger.warning("Skipping topic %s (no assessments)", query.qid)
                    continue

                # Split the retrieved documents into positives and negatives
                positives = []
                negatives = []
                scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(
                    query.text
                )

                for sd in scoreddocuments:
                    if qassessments.get(sd.docid, 0) > 0:
                        # It is a positive document
                        positives.append(sd.docid)
                    else:
                        # It is a negative document or
                        # don't exist in assessment
                        negatives.append(sd.docid)

                if not positives:
                    self.logger.debug(
                        "Skipping topic %s (no relevant documents)", query.qid
                    )
                    skipped += 1
                    continue

                if not negatives:
                    self.logger.debug(
                        "Skipping topic %s (no negative documents)", query.qid
                    )
                    skipped += 1
                    continue

                # Write the result to the file: the ID is the one of the
                # current topic (not the last `qrels` of the loop above), and
                # each topic has its own line
                positive_str = " ".join(positives)
                negative_str = " ".join(negatives)
                fp.write(
                    f"{query.qid}\tpositives:\t{positive_str}\t"
                    f"negatives:\t{negative_str}\n"
                )

        self.logger.info("Processed %d topics (%d skipped)", len(queries), skipped)
class TeacherModelBasedHardNegativesTripletSampler(Task, Sampler):
    """Builds a teacher file for pairwise distillation losses"""

    sampler: Param[PairwiseSampler]
    """The list of exsting hard negatives which we can sample from"""

    document_store: Param[DocumentStore]
    """The document store"""

    topic_store: Param[TextStore]
    """The query_document store"""

    teacher_model: Param[Scorer]
    """The teacher model which scores the positive and negative document"""

    hard_negative_triplet: Annotated[Path, pathgenerator("triplet.tsv")]
    """The path to store the generated triplets"""

    # NOTE(review): not declared as a Param, unlike the other attributes --
    # confirm that this is intended
    batch_size: int
    """How many pairs of documents are been calculate in a batch"""

    def task_outputs(self, dep) -> PairwiseSampler:
        return dep(
            PairwiseSamplerFromTSV(pairwise_samples_path=self.hard_negative_triplet)
        )

    def iter_pairs_with_text(self) -> Iterator[PairwiseRecord]:
        """Add the information of the text back to the records"""
        # NOTE(review): documents are looked up with `docid` here while
        # `execute` writes `id` -- confirm that both attributes exist
        for record in self.sampler.pairwise_iter():
            record.query.text = self.topic_store[record.query.id]
            record.positive.text = self.document_store.document_text(
                record.positive.docid
            )
            record.negative.text = self.document_store.document_text(
                record.negative.docid
            )
            yield record

    def iter_batches(self) -> Iterator[PairwiseRecords]:
        """Return the batch which contains the records

        Batches contain at most `batch_size` records (the last one might be
        smaller); the iteration stops when there is no record left.
        """
        # The record iterator is created once: creating it for every batch
        # would start again from the first record each time
        records = self.iter_pairs_with_text()

        while True:
            batch = PairwiseRecords()
            count = 0
            # range() comes first in zip so that no record is consumed (and
            # lost) once the batch is full
            for _, record in zip(range(self.batch_size), records):
                batch.add(record)
                count += 1

            if count == 0:
                # No more records
                return
            yield batch

    def execute(self):
        """Pre-calculate the score for the teacher model, and store them

        One line is written for each record:
        `query ID<TAB>positive ID<TAB>score<TAB>negative ID<TAB>score`
        """
        self.logger.info("Calculating the score for the teacher model")
        # create the file
        self.hard_negative_triplet.parent.mkdir(parents=True, exist_ok=True)

        # Evaluation mode is set once; no gradient is needed to compute
        # the scores
        self.teacher_model.eval()

        # make the tqdm progressing wrt one record, not a batch of records
        with self.hard_negative_triplet.open("wt") as fp, torch.no_grad():
            for batch in tqdm(self.iter_batches()):
                # scores in shape: [batch_size, 2]
                scores = self.teacher_model(batch)
                scores = scores.reshape(2, -1).T

                # write in the file (one record per line)
                for i, record in enumerate(batch):
                    fp.write(
                        f"{record.query.id}\t{record.positive.id}\t{scores[i,0]}"
                        f"\t{record.negative.id}\t{scores[i,1]}\n"
                    )

        self.logger.info("Teacher models score generating finish")