Source code for xpmir.letor.samplers

import io
import json
from pathlib import Path
from typing import Iterator, List, Tuple, Dict
import numpy as np
from datamaestro_ir.data import (
    Adhoc,
    TrainingTriplets,
    PairwiseSampleDataset,
    PairwiseSample,
    DocumentStore,
    SimpleTextItem,
)
from experimaestro import field, Param, tqdm, Task, Annotated, pathgenerator
from experimaestro.annotations import cache
from functools import cached_property
from xpmir.rankers import ScoredDocument
from xpmir.datasets.adapters import TextStore
from xpmir.letor.records import (
    BatchwiseItems,
    PairwiseItems,
    PairwiseItem,
    PointwiseItem,
)
from xpmir.rankers import Retriever, Scorer
from xpm_torch import Sampler

from xpm_torch.datasets import ShardedIterableDataset, LineFileDataset, InfiniteDataset
from datamaestro_ir.interfaces.plaintext import read_tsv

import logging

logger = logging.getLogger(__name__)


# --- Base classes for samplers

# Generic sampler specialized by the type of item it produces
PointwiseSampler = Sampler[PointwiseItem]  # items made of a topic, a document and a relevance
PairwiseSampler = Sampler[PairwiseItem]  # items made of a topic, a positive and a negative
BatchwiseSampler = Sampler[BatchwiseItems]  # produces BatchwiseItems


# --- Real instances


class ModelBasedSampler(Sampler):
    """Base class for retriever-based sampler"""

    dataset: Param[Adhoc]
    """The IR adhoc dataset"""

    retriever: Param[Retriever]
    """A retriever to sample negative documents"""

    _store: DocumentStore

    def __validate__(self) -> None:
        super().__validate__()
        # Document texts come either from the retriever's store or from the
        # dataset documents
        assert self.retriever.get_store() is not None or isinstance(
            self.dataset.documents, DocumentStore
        ), "The retriever has no associated document store (to get document text)"

    def initialize(self, random):
        super().initialize(random)
        # Prefer the retriever's own store, fall back to the dataset documents
        self._store = self.retriever.get_store() or self.dataset.documents
        assert self._store is not None, "No document store found"

    def document(self, doc_id):
        """Returns the document record (as given by the store's ``document_ext``)"""
        return self._store.document_ext(doc_id)

    def document_text(self, doc_id):
        """Returns the textual content of the document ``doc_id``"""
        return self.document(doc_id)["text_item"].text

    @cache("run")
    def _itertopics(
        self, runpath: Path
    ) -> Iterator[
        Tuple[str, List[Tuple[str, int, float]], List[Tuple[str, int, float]]]
    ]:
        """Iterates over topics, returning retrieved positives and negatives documents

        If ``runpath`` does not exist yet, documents are retrieved and the
        records are cached (see :meth:`_retrieve_topics`); otherwise they are
        read back from ``runpath``.

        Yields ``(query text, positives, negatives)`` where each document is a
        ``(docno, relevance, score)`` tuple.
        """
        self.logger.info("Reading topics and retrieving documents")
        if runpath.is_file():
            yield from self._read_cached_topics(runpath)
        else:
            yield from self._retrieve_topics(runpath)

    def _read_assessments(self) -> Dict[str, Dict[str, float]]:
        """Maps each topic ID to a ``{doc_id: relevance}`` dictionary"""
        self.logger.info("Reading assessments")
        assessments: Dict[str, Dict[str, float]] = {}
        for qrels in self.dataset.assessments.iter():
            assessments[qrels.topic_id] = {
                qrel.doc_id: qrel.rel for qrel in qrels.assessments
            }
        self.logger.info("Read assessments for %d topics", len(assessments))
        return assessments

    def _retrieve_topics(self, runpath: Path):
        """Retrieves documents for every topic and caches them into ``runpath``

        Records are written to a temporary file which is renamed to
        ``runpath`` only once every topic has been processed, so that a
        partially consumed iteration never leaves a truncated cache behind.
        """
        tmprunpath = runpath.with_suffix(".tmp")
        with tmprunpath.open("wt") as fp:
            assessments = self._read_assessments()

            self.logger.info("Retrieving documents for each topic")
            queries = list(self.dataset.topics.iter())

            skipped = 0
            for query in tqdm(queries):
                result = self._retrieve_topic(
                    query, assessments.get(query["id"], None)
                )
                if result is None:
                    skipped += 1
                    continue

                lines, positives, negatives = result
                # Write in cache, and yield
                fp.write("".join(lines))
                yield query["text_item"].text, positives, negatives

        # Finally, move the cache file in place (after it has been closed)
        self.logger.info("Processed %d topics (%d skipped)", len(queries), skipped)
        tmprunpath.rename(runpath)

    def _retrieve_topic(self, query, qassessments):
        """Builds the cache lines and the positives/negatives of one topic

        Returns ``(lines, positives, negatives)``, or ``None`` (after logging
        a warning) when the topic has no assessments, no relevant document or
        no retrieved negative document.
        """
        if not qassessments:
            self.logger.warning("Skipping topic %s (no assessments)", query["id"])
            return None

        text = query["text_item"].text
        lines: List[str] = []

        # All the relevant documents are positives; the query text is only
        # written on the first cached line of the topic
        positives = []
        for docno, rel in qassessments.items():
            if rel > 0:
                lines.append(f"{text if not positives else ''}\t{docno}\t0.\t{rel}\n")
                positives.append((docno, rel, 0))

        if not positives:
            self.logger.warning(
                "Skipping topic %s (no relevant documents)",
                query["id"],
            )
            return None

        # Retrieved documents that are not relevant are negatives (documents
        # without assessment are assumed to be non relevant)
        negatives = []
        scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(text)
        for sd in scoreddocuments:
            doc_id = sd.document["id"]
            rel = qassessments.get(doc_id, 0)
            if rel > 0:
                continue
            negatives.append((doc_id, rel, sd.score))
            lines.append(f"\t{doc_id}\t{sd.score}\t{rel}\n")

        if not negatives:
            self.logger.warning(
                "Skipping topic %s (no negatives documents)",
                query["id"],
            )
            return None

        return lines, positives, negatives

    def _read_cached_topics(self, runpath: Path):
        """Reads back the records written by :meth:`_retrieve_topics`

        Each line is ``title<TAB>docno<TAB>score<TAB>rel``; a non-empty title
        starts a new topic, an empty one continues the current topic.
        """
        self.logger.info("Reading records from file %s", runpath)
        positives = []
        negatives = []
        oldtitle = ""
        with runpath.open("rt") as fp:
            for line in fp:
                line = line.rstrip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line)
                    continue

                title, docno, score, rel = line.split("\t")
                if title:
                    if oldtitle:
                        yield oldtitle, positives, negatives
                        # New lists: the yielded ones may be kept by the caller
                        positives = []
                        negatives = []
                    oldtitle = title

                rel = int(rel)
                (positives if rel > 0 else negatives).append(
                    (docno, rel, float(score))
                )

        # An empty cache file contains no topic at all
        if oldtitle:
            yield oldtitle, positives, negatives
class PointwiseModelBasedSampler(ModelBasedSampler, PointwiseSampler):
    """Pointwise sampler over the documents retrieved for each topic"""

    relevant_ratio: Param[float] = field(default=0.5, ignore_default=True)
    """The target relevance ratio"""

    def initialize(self, random):
        super().initialize(random)
        self.retriever.initialize()
        self.pos_records, self.neg_records = self.readrecords()
        self.logger.info(
            "Loaded %d/%d pos/neg records", len(self.pos_records), len(self.neg_records)
        )

    def prepare(self, sample: Tuple[str, str, float, int]) -> PointwiseItem:
        """Builds a pointwise item from a ``(title, docno, score, rel)`` record

        The document text is fetched from the store only once.
        """
        title, docno, rel = sample[0], sample[1], sample[3]
        document = self.document_text(docno)
        assert document is not None
        return PointwiseItem(
            topic={"text_item": SimpleTextItem(title)},
            document={"text_item": SimpleTextItem(document)},
            relevance=rel,
        )

    def readrecords(self, runpath=None):
        """Splits the topics into positive and negative records

        Each record is a ``(title, docno, score, rel)`` tuple. ``runpath`` is
        not used (kept for backward compatibility).
        """
        pos_records, neg_records = [], []
        for title, positives, negatives in self._itertopics():
            pos_records.extend(
                (title, docno, score, rel) for docno, rel, score in positives
            )
            neg_records.extend(
                (title, docno, score, rel) for docno, rel, score in negatives
            )
        return pos_records, neg_records

    def record_iter(self) -> Iterator[PointwiseItem]:
        """Infinite iterator over pointwise items

        A positive record is drawn when a uniform draw falls below
        ``relevant_ratio``, a negative record otherwise.
        """
        npos = len(self.pos_records)
        nneg = len(self.neg_records)
        while True:
            if self.random.random() < self.relevant_ratio:
                yield self.prepare(self.pos_records[self.random.randint(0, npos)])
            else:
                yield self.prepare(self.neg_records[self.random.randint(0, nneg)])

    def as_dataset(self) -> ShardedIterableDataset:
        """Returns a dataset that yields infinite random pointwise records."""

        class _PointwiseDataset(ShardedIterableDataset):
            def __init__(self, sampler):
                super().__init__()
                self.sampler = sampler

            def iter_shard(self, shard_id, num_shards):
                # Every shard draws from the same (random) record stream
                yield from self.sampler.record_iter()

        return InfiniteDataset(_PointwiseDataset(self))
class PairwiseModelBasedSampler(ModelBasedSampler, PairwiseSampler):
    """A pairwise sampler based on a retrieval model"""

    def initialize(self, random: np.random.RandomState):
        super().initialize(random)
        self.retriever.initialize()
        self.topics: List[Tuple[str, List, List]] = self._readrecords()

    def _readrecords(self):
        """Materializes the ``(title, positives, negatives)`` topics"""
        return [
            (title, positives, negatives)
            for title, positives, negatives in self._itertopics()
        ]

    def sample(self, samples: List[Tuple[str, int, float]]):
        """Draws one ``(docid, rel, score)`` entry whose document has a text

        Entries whose document text is ``None`` are rejected and drawn again.
        """
        while True:
            docid, _rel, score = samples[self.random.randint(0, len(samples))]
            document = self.document(docid)
            if document["text_item"].text is not None:
                return ScoredDocument(document, score)

    def _record_iter(self) -> Iterator[PairwiseItem]:
        """Infinite iterator over pairwise records."""
        while True:
            # Draw a topic, then a positive, then a negative
            index = self.random.randint(0, len(self.topics))
            title, positives, negatives = self.topics[index]
            topic = {"text_item": SimpleTextItem(title)}
            positive = self.sample(positives)
            negative = self.sample(negatives)
            yield PairwiseItem(topic, positive, negative)

    def as_dataset(self) -> ShardedIterableDataset:
        """Returns a dataset that yields infinite random pairwise records."""

        class _PairwiseDataset(ShardedIterableDataset):
            def __init__(self, sampler):
                super().__init__()
                self.sampler = sampler

            def iter_shard(self, shard_id, num_shards):
                # Shards are not partitioned: each draws from the sampler
                yield from self.sampler._record_iter()

        dataset = _PairwiseDataset(self)
        return InfiniteDataset(dataset)
class PairwiseInBatchNegativesSampler(BatchwiseSampler):
    """An in-batch negative sampler constructed from a pairwise one"""

    sampler: Param[PairwiseSampler]
    """The base pairwise sampler"""

    def initialize(self, random):
        super().initialize(random)
        # The wrapped sampler is initialized with the same random state
        self.sampler.initialize(random)

    def as_dataset(self) -> ShardedIterableDataset:
        """Delegates to the wrapped pairwise sampler's dataset.

        In-batch negative construction moves to batchwise_collate.
        """
        inner = self.sampler
        return inner.as_dataset()
class TripletBasedSampler(PairwiseSampler):
    """Sampler based on a triplet source"""

    source: Param[TrainingTriplets]
    """Triplets"""

    def as_dataset(self) -> ShardedIterableDataset:
        """Returns an infinite dataset streaming the triplet source.

        Each shard goes through the whole source and keeps the triplets whose
        position modulo ``num_shards`` equals its shard index.
        """

        class _TripletIterableDataset(ShardedIterableDataset):
            def __init__(self, source):
                super().__init__()
                self.source = source

            def iter_shard(self, shard_id, num_shards):
                for position, triplet in enumerate(self.source.iter()):
                    topic, pos, neg = triplet
                    if position % num_shards != shard_id:
                        continue
                    yield PairwiseItem(topic, pos, neg)

        wrapped = _TripletIterableDataset(self.source)
        return InfiniteDataset(wrapped)
class PairwiseDatasetTripletBasedSampler(PairwiseSampler):
    """Sampler based on a dataset where each query is associated with (1) a set
    of relevant documents (2) negative documents, where each negative is
    sampled with a specific algorithm
    """

    documents: Param[DocumentStore]
    """The document store"""

    dataset: Param[PairwiseSampleDataset]
    """The dataset which contains the generated queries with its positives and
    negatives"""

    negative_algo: Param[str] = field(default="random", ignore_default=True)
    """The algo to sample the negatives, default value is random"""

    def _sample_record(self, sample: PairwiseSample) -> PairwiseItem:
        """Convert a PairwiseSample to a PairwiseItem by sampling pos/neg.

        A positive, a topic and a negative are drawn (in this order). With
        ``negative_algo == "random"`` the negative is any document of the
        store other than the positive; otherwise it is drawn from
        ``sample.negatives[negative_algo]``.
        """
        # The negatives mapping is only needed for non-random algorithms
        assert (
            self.negative_algo == "random" or self.negative_algo in sample.negatives
        ), f"Unknown negative sampling algorithm {self.negative_algo!r}"

        pos = sample.positives[self.random.randint(len(sample.positives))]
        qry = sample.topics[self.random.randint(len(sample.topics))]

        if self.negative_algo == "random":
            # Rejection sampling: redraw until the document is not the positive
            while True:
                neg_id = self.documents.docid_internal2external(
                    self.random.randint(0, self.documents.documentcount)
                )
                if neg_id != pos["id"]:
                    break
            neg = {"id": neg_id}
        else:
            negatives = sample.negatives[self.negative_algo]
            neg = negatives[self.random.randint(len(negatives))]

        # Topics may be legacy objects exposing ``as_record()`` or plain
        # record dictionaries (as built by JSONLPairwiseSampleDataset)
        topic = qry.as_record() if hasattr(qry, "as_record") else qry
        return PairwiseItem(topic, pos, neg)

    def as_dataset(self) -> ShardedIterableDataset:
        """Returns a dataset that yields infinite random pairwise records."""
        from xpm_torch.datasets import TransformDataset

        class _PairwiseSampleDataset(ShardedIterableDataset):
            def __init__(self, dataset):
                super().__init__()
                self.dataset = dataset

            def iter_shard(self, shard_id, num_shards):
                # Each shard keeps one sample out of num_shards
                for i, sample in enumerate(self.dataset.iter()):
                    if i % num_shards == shard_id:
                        yield sample

        return InfiniteDataset(
            TransformDataset(
                _PairwiseSampleDataset(self.dataset),
                self._sample_record,
            )
        )
# --- Dataloader
class TSVPairwiseSampleDataset(PairwiseSampleDataset):
    """Read the pairwise sample dataset from a tsv file"""

    hard_negative_samples_path: Param[Path]
    """The path which stores the existing ids"""

    def iter(self) -> Iterator[PairwiseSample]:
        """return a iterator over a set of pairwise_samples

        Column 0 is used as the topic; columns 2 and 4 hold space-separated
        positive and negative IDs. Presumably this matches the
        ``qid / positives: / ids / negatives: / ids`` rows written by
        ModelBasedHardNegativeSampler -- TODO confirm.
        """
        for row in read_tsv(self.hard_negative_samples_path):
            # NOTE: negatives are not grouped by sampling algorithm (there is
            # no column storing the algorithm yet)
            yield PairwiseSample(
                topics=[row[0]],
                positives=row[2].split(" "),
                negatives=row[4].split(" "),
            )
class JSONLPairwiseSampleDataset(PairwiseSampleDataset):
    """Transform a JSONL file to a pairwise dataset.

    General format::

        {
            "queries": ["str", "str"],
            "pos_ids": ["id", "id"],
            "neg_ids": {
                "bm25": ["id", "id"],
                "random": ["id", "id"]
            }
        }

    The file is read as UTF-8 and blank lines are ignored.
    """

    path: Param[Path]
    """The path to the Jsonl file"""

    @cached_property
    def count(self):
        """Number of samples, i.e. of non-blank lines in the file"""
        with self.path.open("r", encoding="utf-8") as fp:
            return sum(1 for line in fp if line.strip())

    def iter(self) -> Iterator[PairwiseSample]:
        """Yields one PairwiseSample per non-blank line of the file"""
        with self.path.open("r", encoding="utf-8") as fp:
            for line in fp:
                if not line.strip():
                    # json.loads would fail on an empty line
                    continue

                sample = json.loads(line)
                topics = [
                    {"text_item": SimpleTextItem(topic_text)}
                    for topic_text in sample["queries"]
                ]
                positives = [{"id": pos_id} for pos_id in sample["pos_ids"]]
                # Negatives are grouped by the algorithm that produced them
                negatives = {
                    algo: [{"id": neg_id} for neg_id in neg_ids]
                    for algo, neg_ids in sample["neg_ids"].items()
                }
                yield PairwiseSample(
                    topics=topics, positives=positives, negatives=negatives
                )
# A class for loading the data (it should be moved elsewhere).
class PairwiseSamplerFromTSV(PairwiseSampler):
    """Pairwise sampler reading scored triplets from a TSV file"""

    pairwise_samples_path: Param[Path]
    """The path which stores the existing triplets"""

    def _parse_tsv_line(self, line: str) -> PairwiseItem:
        """Parse a ``qid, pos_id, pos_score, neg_id, neg_score`` TSV line."""
        q_id, pos_id, pos_score, neg_id, neg_score = line.split("\t")
        query = {"id": q_id}
        positive = {"id": pos_id, "score": float(pos_score)}
        negative = {"id": neg_id, "score": float(neg_score)}
        return PairwiseItem(query, positive, negative)

    def as_dataset(self) -> ShardedIterableDataset:
        """Returns a LineFileDataset for the TSV pairwise samples."""
        lines = LineFileDataset(self.pairwise_samples_path, self._parse_tsv_line)
        return InfiniteDataset(lines)
# A class for loading the data, need to move the other places. # class ListwiseSamplerFromTSV(ListwiseSampler): # pairwise_samples_path: Param[Path] # """The path which stores the existing triplets""" # def pairwise_iter(self) -> SerializableIterator[PairwiseItem, Any]: # def iter() -> Iterator[PairwiseSample]: # for triplet in read_tsv(self.pairwise_samples_path): # q_id, pos_id, pos_score, neg_id, neg_score = triplet # yield PairwiseItem( # Record(IDItem(q_id)), # Record(IDItem(pos_id), ScoredItem(pos_score)), # Record(IDItem(neg_id), ScoredItem(neg_score)), # ) # return SkippingIterator(iter) # --- Tasks for hard negatives
class ModelBasedHardNegativeSampler(Task, Sampler):
    """Retriever-based hard negative sampler"""

    dataset: Param[Adhoc]
    """The dataset which contains the topics and assessments"""

    retriever: Param[Retriever]
    """The retriever used to score the documents with respect to the query"""

    hard_negative_samples: Annotated[Path, pathgenerator("hard_negatives.tsv")]
    """Path to store the generated hard negatives"""

    def task_outputs(self, dep) -> PairwiseSampleDataset:
        """return a iterator of PairwiseSample"""
        return dep(
            TSVPairwiseSampleDataset(
                ids=self.dataset.id,
                hard_negative_samples_path=self.hard_negative_samples,
            )
        )

    def execute(self):
        """Retrieve over the dataset and select the positive and negative
        according to the relevance score and their rank

        Writes one line per kept topic::

            qid <TAB> positives: <TAB> id id ... <TAB> negatives: <TAB> id id ...
        """
        self.logger.info("Reading topics and retrieving documents")
        # create the file
        self.hard_negative_samples.parent.mkdir(parents=True, exist_ok=True)

        # Read the assessments
        self.logger.info("Reading assessments")
        assessments: Dict[str, Dict[str, float]] = {}
        for qrels in self.dataset.assessments.iter():
            assessments[qrels.topic_id] = {
                qrel.doc_id: qrel.rel for qrel in qrels.assessments
            }
        self.logger.info("Assessment loaded")
        self.logger.info("Read assessments for %d topics", len(assessments))

        self.logger.info("Retrieving documents for each topic")
        queries = list(self.dataset.topics.iter())

        with self.hard_negative_samples.open("wt") as fp:
            # Number of topics skipped (no assessments, positives or negatives)
            skipped = 0
            for query in tqdm(queries):
                qassessments = assessments.get(query["id"], None)
                if not qassessments:
                    skipped += 1
                    self.logger.warning(
                        "Skipping topic %s (no assessments)", query["id"]
                    )
                    continue

                # Retrieved documents with a positive relevance are positives;
                # the others (non relevant or not assessed) are negatives
                positives = []
                negatives = []
                scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(query)
                for sd in scoreddocuments:
                    doc_id = sd.document["id"]
                    if qassessments.get(doc_id, 0) > 0:
                        positives.append(doc_id)
                    else:
                        negatives.append(doc_id)

                if not positives:
                    self.logger.debug(
                        "Skipping topic %s (no relevant documents)", query["id"]
                    )
                    skipped += 1
                    continue

                if not negatives:
                    self.logger.debug(
                        "Skipping topic %s (no negative documents)", query["id"]
                    )
                    skipped += 1
                    continue

                # Write the result to the file (one topic per line)
                positive_str = " ".join(positives)
                negative_str = " ".join(negatives)
                qid = query["id"]
                fp.write(
                    f"{qid}\tpositives:\t{positive_str}\tnegatives:\t{negative_str}\n"
                )

        self.logger.info("Processed %d topics (%d skipped)", len(queries), skipped)
class TeacherModelBasedHardNegativesTripletSampler(Task, Sampler):
    """Builds a teacher file for pairwise distillation losses"""

    sampler: Param[PairwiseSampler]
    """The list of existing hard negatives which we can sample from"""

    document_store: Param[DocumentStore]
    """The document store"""

    topic_store: Param[TextStore]
    """The query_document store"""

    teacher_model: Param[Scorer]
    """The teacher model which scores the positive and negative document"""

    hard_negative_triplet: Annotated[Path, pathgenerator("triplet.tsv")]
    """The path to store the generated triplets"""

    batch_size: Param[int]
    """How many pairs of documents are scored in a batch"""

    def task_outputs(self, dep) -> PairwiseSampler:
        return dep(
            PairwiseSamplerFromTSV(pairwise_samples_path=self.hard_negative_triplet)
        )

    def iter_pairs_with_text(self) -> Iterator[PairwiseItem]:
        """Add the information of the text back to the records"""
        # NOTE(review): relies on ``pairwise_iter()`` and attribute-style
        # records (``.query.id``, ``.positive.docid``); the samplers in this
        # module expose ``as_dataset()`` and dict records -- TODO confirm
        for record in self.sampler.pairwise_iter():
            record.query.text = self.topic_store[record.query.id]
            record.positive.text = self.document_store.document_text(
                record.positive.docid
            )
            record.negative.text = self.document_store.document_text(
                record.negative.docid
            )
            yield record

    def iter_batches(self) -> Iterator[PairwiseItems]:
        """Groups the records into batches of at most ``batch_size`` records

        A single record iterator is shared across batches; iteration stops
        once it is exhausted (the last batch may be smaller).
        """
        records = self.iter_pairs_with_text()
        while True:
            batch = PairwiseItems()
            count = 0
            # zip pulls from range first, so no record is lost between batches
            for _, record in zip(range(self.batch_size), records):
                batch.add(record)
                count += 1
            if count == 0:
                return
            yield batch

    def execute(self):
        """Pre-calculate the score for the teacher model, and store them

        Each output line is ``qid, pos_id, pos_score, neg_id, neg_score``
        (tab-separated), the format parsed by PairwiseSamplerFromTSV.
        """
        self.logger.info("Calculating the score for the teacher model")
        # create the file
        self.hard_negative_triplet.parent.mkdir(parents=True, exist_ok=True)

        # Evaluation mode only needs to be set once
        self.teacher_model.eval()

        # TODO: make the tqdm progress count records rather than batches
        with self.hard_negative_triplet.open("wt") as fp:
            for batch in tqdm(self.iter_batches()):
                scores = self.teacher_model(batch)
                # After this reshape, scores has shape [batch_size, 2]; this
                # assumes the scorer outputs all positive scores first, then
                # all negative scores -- TODO confirm
                scores = scores.reshape(2, -1).T

                # write in the file (one record per line)
                for i, record in enumerate(batch):
                    fp.write(
                        f"{record.query.id}\t{record.positive.id}\t{scores[i, 0]}"
                        f"\t{record.negative.id}\t{scores[i, 1]}\n"
                    )
        self.logger.info("Teacher models score generating finish")