Source code for datamaestro_ir.data.distillation

"""Distillation dataset types for IR"""

import logging
from dataclasses import dataclass
from typing import (
    Generic,
    Iterable,
    Iterator,
    List,
    Tuple,
    TypeVar,
)

from experimaestro import Config, Meta, Param
from datamaestro.data import File
from datamaestro_ir.data import AdhocAssessments
from datamaestro_ir.data.base import (
    IDRecord,
    TextRecord,
    SimpleTextItem,
    ScoredDocument,
)

DocT = TypeVar("DocT")
DocT2 = TypeVar("DocT2")
QueryT = TypeVar("QueryT")
QueryT2 = TypeVar("QueryT2")


[docs] @dataclass class PairwiseDistillationSample(Generic[DocT, QueryT]): query: QueryT """The query""" documents: Tuple[DocT, DocT] """Positive/negative document with teacher scores""" def get_queries(self) -> List[QueryT]: return [self.query] def with_queries( self, qs: "List[QueryT2]" ) -> "PairwiseDistillationSample[DocT, QueryT2]": return PairwiseDistillationSample(qs[0], self.documents) def get_documents(self) -> List[DocT]: return list(self.documents) def with_documents( self, ds: "List[DocT2]" ) -> "PairwiseDistillationSample[DocT2, QueryT]": return PairwiseDistillationSample(self.query, tuple(ds))
class PairwiseDistillationSamples(Config, Iterable[PairwiseDistillationSample]): """Pairwise distillation file""" def __iter__(self) -> Iterator[PairwiseDistillationSample]: raise NotImplementedError() class PairwiseDistillationSamplesTSV(PairwiseDistillationSamples, File): """A TSV file (Score 1, Score 2, Query, Document 1, Document 2)""" with_docid: Meta[bool] with_queryid: Meta[bool] def _parse_line(self, line: str) -> PairwiseDistillationSample: """Parse a single TSV line into a PairwiseDistillationSample.""" import csv import io reader = csv.reader(io.StringIO(line), delimiter="\t") row = next(reader) if self.with_queryid: query = IDRecord(id=row[2]) else: query = TextRecord(text_item=SimpleTextItem(row[2])) if self.with_docid: documents = ( ScoredDocument(IDRecord(id=row[3]), float(row[0])), ScoredDocument(IDRecord(id=row[4]), float(row[1])), ) else: documents = ( ScoredDocument( TextRecord(text_item=SimpleTextItem(row[3])), float(row[0]) ), ScoredDocument( TextRecord(text_item=SimpleTextItem(row[4])), float(row[1]) ), ) return PairwiseDistillationSample(query, documents) @dataclass class ListwiseDistillationSample(Generic[DocT, QueryT]): query: QueryT """The query""" documents: List[DocT] """List of documents with their ranking position""" def get_queries(self) -> List[QueryT]: return [self.query] def with_queries( self, qs: "List[QueryT2]" ) -> "ListwiseDistillationSample[DocT, QueryT2]": return ListwiseDistillationSample(qs[0], self.documents) def get_documents(self) -> List[DocT]: return self.documents def with_documents( self, ds: "List[DocT2]" ) -> "ListwiseDistillationSample[DocT2, QueryT]": return ListwiseDistillationSample(self.query, list(ds)) class ListwiseDistillationSamples(Config, Iterable[ListwiseDistillationSample]): """Listwise distillation file""" def __iter__(self) -> Iterator[ListwiseDistillationSample]: raise NotImplementedError() class ListwiseDistillationSamplesTSV(ListwiseDistillationSamples, File): """A TSV file ("query_id", "q0", "doc_id", "rank", "score", "system")""" top_k: Meta[int] with_docid: Meta[bool] with_queryid: Meta[bool] @staticmethod def _parse_trec_line(line: str) -> tuple: """Parse a TREC-format line, return (query_key, row_fields).""" parts = line.split("\t") if "\t" in line else line.split() return parts[0], parts def _build_group(self, query_key: str, rows: list) -> ListwiseDistillationSample: """Build a ListwiseDistillationSample from grouped TREC lines.""" if self.with_queryid: query_record = IDRecord(id=query_key) else: query_record = TextRecord(text_item=SimpleTextItem(query_key)) documents = [] for row in rows: if self.with_docid: doc = ScoredDocument(IDRecord(id=row[2]), float(row[4])) else: doc = ScoredDocument( TextRecord(text_item=SimpleTextItem(row[2])), float(row[4]) ) documents.append(doc) return ListwiseDistillationSample(query_record, documents) class ListwiseDistillationSamplesTSVWithAnnotations(ListwiseDistillationSamplesTSV): qrels: Param[AdhocAssessments] def __post_init__(self): self.qrels_dict = {} logging.info("Loading qrels into memory...") for qrel in self.qrels.iter(): self.qrels_dict[qrel.topic_id] = [ assess.doc_id for assess in qrel.assessments if assess.rel > 0 ]