from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Any
from attrs import evolve
import torch
from experimaestro import DataPath, OptionalDataPath, field, Param
from datamaestro_ir.data import IDTextRecord
from xpmir.neural import DualRepresentationScorer, QueriesRep, DocsRep
from xpmir.text.encoders import TextEncoderBase
from xpm_torch.learner import TrainerContext
from xpm_torch.losses import Loss
from xpm_torch.module import ModuleLoader
from xpm_torch.trainers import TrainingHook
from xpm_torch.metrics import ScalarMetric
import logging
logger = logging.getLogger(__name__)
class DualVectorListener(TrainingHook):
    """Training hook that receives the query and document representations.

    Listeners of this type are notified once the representations of a batch
    are available. They can be used for logging and, more importantly, for
    adding regularization losses such as :class:`FlopsRegularizer`.
    """

    def __call__(
        self, context: TrainerContext, queries: torch.Tensor, documents: torch.Tensor
    ):
        """Handle one batch of representations.

        Args:
            context (TrainerContext): The training context
            queries (torch.Tensor): The query vectors
            documents (torch.Tensor): The document vectors

        Raises:
            NotImplementedError: Always; subclasses must override this method.

        NOTE(review): in this file, :class:`Dense` forwards its ``queries`` /
        ``documents`` arguments unchanged and :class:`FlopsRegularizer` reads
        ``.value`` from them, so the ``torch.Tensor`` annotations may be
        outdated -- confirm.
        """
        message = f"__call__ in {self.__class__}"
        raise NotImplementedError(message)
class DualVectorScorerListener(TrainingHook, ABC):
    """Training hook that receives the representations and the scores.

    Same purpose as a plain dual-vector listener (logging, extra
    regularization losses), but the handler is additionally given the scores
    computed from the query and document representations.
    """

    @abstractmethod
    def __call__(
        self,
        context: TrainerContext,
        queries: torch.Tensor,
        documents: torch.Tensor,
        scorer: torch.Tensor,
    ):
        """Handle one batch of representations together with its scores.

        Args:
            context (TrainerContext): The training context
            queries (torch.Tensor): The query vectors
            documents (torch.Tensor): The document vectors
            scorer (torch.Tensor): A vector or matrix of scores (depending on
                how the scores were computed). The parameter is named
                ``scorer`` for historical reasons; callers in this file pass
                the score tensor positionally.

        Raises:
            NotImplementedError: When a subclass delegates to this base
                implementation.
        """
        message = f"__call__ in {self.__class__}"
        raise NotImplementedError(message)
class DualModuleLoader(ModuleLoader):
    """ModuleLoader for dual encoder models.

    Has distinct ``encoder_path`` and ``query_encoder_path`` DataPaths so each
    encoder is serialized independently. This enables proper
    sentence-transformers format on HF Hub export (symmetric vs
    router/asymmetric).
    """

    encoder_path: Param[DataPath]
    """Path to the document encoder checkpoint"""

    query_encoder_path: OptionalDataPath = None
    """Path to the query encoder checkpoint (if separate from doc encoder)"""

    def execute(self):
        """Initialize the wrapped scorer and restore its encoder checkpoints.

        The document encoder is always loaded from ``encoder_path``. The query
        encoder is loaded only when a ``query_encoder_path`` is set and the
        scorer declares a ``query_encoder``.
        """
        scorer = self.value
        scorer.initialize()
        scorer.encoder.load_model(Path(self.encoder_path))

        # Nothing more to restore without a dedicated query checkpoint/encoder
        if not self.query_encoder_path or scorer.query_encoder is None:
            return
        scorer._query_encoder.load_model(Path(self.query_encoder_path))
class DualVectorScorer(DualRepresentationScorer[QueriesRep, DocsRep]):
    """A scorer based on dual vectorial representations"""

    CONFIG_LOADER = DualModuleLoader.C

    encoder: Param[TextEncoderBase[Any, Any]]
    """The document (and potentially query) encoder"""

    query_encoder: Param[Optional[TextEncoderBase[Any, Any]]]
    """The query encoder (optional, if not defined uses the encoder)"""

    def __initialize__(self):
        """Initialize the document encoder and, when set, the query encoder."""
        self.encoder.initialize()
        # Explicit None test: an encoder object must never be skipped because
        # of its truth value (consistent with _has_separate_query_model)
        if self.query_encoder is not None:
            self.query_encoder.initialize()

    @property
    def _query_encoder(self):
        """The encoder used for queries: ``query_encoder`` when set, otherwise
        the shared ``encoder``."""
        if self.query_encoder is not None:
            return self.query_encoder
        return self.encoder

    def __validate__(self):
        """Validate the configuration: the encoder must not be static."""
        super().__validate__()
        assert not self.encoder.static(), "The vocabulary should be learnable"

    def _has_separate_query_model(self) -> bool:
        """Check if query and doc encoders are separate models.

        Default: object identity check. Override for more specific logic.
        """
        return self.query_encoder is not None and self.query_encoder is not self.encoder

    def export_action(self, loader, **kwargs):
        """Build the export action configuration for this model.

        :param loader: The loader forwarded to ``XPMIRExportAction.C``
        :param kwargs: Extra arguments for the export action; ``doc`` and
            ``bibtex`` default to this model's values when those are truthy
        :returns: The ``XPMIRExportAction.C`` configuration
        """
        # Local import, as in the original code
        from xpmir.models import XPMIRExportAction

        if self.doc:
            kwargs.setdefault("doc", self.doc)
        if self.bibtex:
            kwargs.setdefault("bibtex", self.bibtex)
        return XPMIRExportAction.C(loader=loader, **kwargs)

    def loader_config(self, path: Path, *, settings=None) -> DualModuleLoader:
        """Build the loader configuration pointing at the saved encoders.

        :param path: Directory containing the ``encoder`` (and optionally
            ``query_encoder``) sub-directories
        :param settings: Forwarded unchanged to the loader configuration
        :returns: A ``CONFIG_LOADER`` configuration whose
            ``query_encoder_path`` is only set when the query encoder is a
            separate model
        """
        has_separate_query = self._has_separate_query_model()
        return self.CONFIG_LOADER(
            value=self,
            encoder_path=path / "encoder",
            query_encoder_path=(path / "query_encoder" if has_separate_query else None),
            settings=settings,
        )
class Dense(DualVectorScorer[QueriesRep, DocsRep]):
    """A scorer based on a pair of (query, document) dense vectors"""

    def _notify_listeners(self, info, queries, documents, scores):
        """Call the dual-vector hooks registered in ``info`` (no-op without a
        context): first the representation listeners, then the listeners that
        also receive the scores."""
        if info is None:
            return
        for hook in info.hooks(DualVectorListener):
            hook(info, queries, documents)
        for hook in info.hooks(DualVectorScorerListener):
            hook(info, queries, documents, scores)

    def score_product(self, queries, documents, info: Optional[TrainerContext] = None):
        """Score every query against every document (inner products).

        :param queries: Query representations (the vectors are read from
            ``.value``)
        :param documents: Document representations (the vectors are read from
            ``.value``)
        :param info: Optional training context; when it carries a fabric, the
            representations are gathered across processes before scoring
        :returns: The matrix of scores
        """
        fabric = None if info is None else info.fabric
        if fabric is None:
            # No fabric at hand: score the local representations only
            q_all, d_all = queries.value, documents.value
        else:
            # Gather the representations across all processes
            # NOTE(review): if ``all_gather`` stacks the per-process tensors
            # along a new leading dimension (as Lightning Fabric's does), the
            # product below would need a flatten first -- confirm.
            q_all = fabric.all_gather(queries.value, sync_grads=True)
            d_all = fabric.all_gather(documents.value, sync_grads=True)

        scores = q_all @ d_all.T

        # The hooks receive the local (non-gathered) representations
        self._notify_listeners(info, queries, documents, scores)
        return scores

    def score_pairs(self, queries, documents, info: Optional[TrainerContext] = None):
        """Score the i-th query against the i-th document.

        :param queries: Query representations (the vectors are read from
            ``.value``)
        :param documents: Document representations (the vectors are read from
            ``.value``)
        :param info: Optional training context providing the hooks to call
        :returns: The scores, one per (query, document) pair
        """
        # Batched (row vector) x (column vector) product, then drop the two
        # singleton dimensions it leaves behind
        query_rows = queries.value.unsqueeze(1)
        document_columns = documents.value.unsqueeze(2)
        scores = (query_rows @ document_columns).squeeze(-1).squeeze(-1)

        # Apply the dual vector hooks
        self._notify_listeners(info, queries, documents, scores)
        return scores

    @classmethod
    def from_pretrained_hf(cls, model_id: str, **kwargs):
        """Creates a Dense model from a HuggingFace encoder model

        :param model_id: The HuggingFace model ID
        :param kwargs: Additional keyword arguments passed to the
            ``HFCLSEncoder`` configuration
        :returns: (model, init_tasks) tuple
        """
        from xpmir.text.huggingface.base import (
            HFConfigID,
            HFModel,
            HFModelInitFromID,
        )
        from xpmir.text.huggingface.encoders import HFCLSEncoder

        hf_config = HFConfigID.C(hf_id=model_id)
        hf_model = HFModel.C(config=hf_config)
        init_tasks = [HFModelInitFromID.C(model=hf_model)]
        encoder = HFCLSEncoder.C(model=hf_model, **kwargs)
        return cls.C(encoder=encoder), init_tasks
class CosineDense(Dense):
    """Dual model based on cosine similarity.

    Query and document vectors are L2-normalized along their last dimension,
    so that the inner products computed by :class:`Dense` are cosine
    similarities.
    """

    def encode_queries(self, records: List[IDTextRecord]):
        """Encode the queries and L2-normalize the resulting vectors.

        :param records: The query records to encode
        :returns: The encoder output with ``value`` replaced by unit-norm
            vectors
        """
        queries = self._query_encoder(records)
        # normalize() clamps the norm with a small epsilon, so an all-zero
        # vector stays zero instead of becoming NaN (0 / 0)
        return evolve(
            queries, value=torch.nn.functional.normalize(queries.value, dim=-1)
        )

    def encode_documents(self, records: List[IDTextRecord]):
        """Encode the documents and L2-normalize the resulting vectors.

        :param records: The document records to encode
        :returns: The encoder output with ``value`` replaced by unit-norm
            vectors
        """
        documents = self.encoder(records)
        # Same epsilon-clamped normalization as for the queries
        return evolve(
            documents,
            value=torch.nn.functional.normalize(documents.value, dim=-1),
        )
class DotDense(Dense):
    """Dual model based on inner product."""

    def __validate__(self):
        """Validate the configuration: the encoder must not be static."""
        super().__validate__()
        assert not self.encoder.static(), "The vocabulary should be learnable"

    def encode_queries(self, records: List[IDTextRecord]):
        """Encode the queries with the query encoder (or the shared encoder)"""
        query_encoder = self._query_encoder
        return query_encoder(records)

    def encode_documents(self, records: List[IDTextRecord]):
        """Encode the documents with the document encoder"""
        document_encoder = self.encoder
        return document_encoder(records)

    def save_model(self, path: Path):
        """Save sub-encoders independently to subdirectories.

        The document encoder goes to ``path / "encoder"``; the query encoder
        goes to ``path / "query_encoder"`` only when it is set and is not the
        same object as the document encoder.
        """
        path.mkdir(parents=True, exist_ok=True)
        self.encoder.save_model(path / "encoder")

        shares_encoder = self.query_encoder is None or self.query_encoder is self.encoder
        if not shares_encoder:
            self._query_encoder.save_model(path / "query_encoder")

    def load_model(self, path: Path):
        """Load sub-encoders from subdirectories, with backward compat.

        When ``path / "encoder"`` does not exist, loading is delegated to the
        parent class implementation.
        """
        encoder_dir = path / "encoder"
        if not encoder_dir.exists():
            # Backward compat: flat state_dict
            super().load_model(path)
            return

        self.encoder.load_model(encoder_dir)

        query_dir = path / "query_encoder"
        if query_dir.exists() and self.query_encoder is not None:
            self._query_encoder.load_model(query_dir)
def dual_representation_metrics(
    info: TrainerContext, queries: torch.Tensor, documents: torch.Tensor
):
    """Report activation statistics for a batch of dual representations.

    The following scalar metrics are added to ``info.metrics``:

    - ``qdflops_count``: per-dimension product of the numbers of queries and
      documents with a strictly positive weight, averaged over dimensions and
      divided by the number of dimensions;
    - ``sparsity_q`` / ``sparsity_d``: fraction of non-zero entries;
    - ``saturation_{q,d}/top{k}``: median activation rate of the ``k`` most
      frequently active dimensions, for ``k`` in 1, 5, 10, 20 (when enough
      dimensions exist).

    :param info: The training context receiving the metrics
    :param queries: Query vectors (indexed as vectors x dimensions)
    :param documents: Document vectors (indexed as vectors x dimensions)
    """
    with torch.no_grad():
        # Number of active (> 0) vectors for each dimension
        active_queries = (queries > 0).sum(0)
        active_documents = (documents > 0).sum(0)
        qdflops_count = (active_queries * active_documents).float().mean()
        qdflops_count = qdflops_count / queries.shape[1]
        info.metrics.add(ScalarMetric("qdflops_count", qdflops_count.item(), 1))

        # Fraction of non-zero entries, weighted by the number of vectors
        for metric_name, vectors in (("sparsity_q", queries), ("sparsity_d", documents)):
            n_entries = vectors.shape[0] * vectors.shape[1]
            nonzero_ratio = torch.count_nonzero(vectors).item() / n_entries
            info.metrics.add(ScalarMetric(metric_name, nonzero_ratio, len(vectors)))

        # Median activation rate of the top-k most frequent terms
        cutoffs = (1, 5, 10, 20)
        for prefix, vectors in (("q", queries), ("d", documents)):
            activation_rate = (vectors > 0).float().mean(0)
            n_top = min(20, activation_rate.shape[0])
            top_freq = activation_rate.topk(n_top).values
            for k in cutoffs:
                if k > top_freq.shape[0]:
                    continue
                info.metrics.add(
                    ScalarMetric(
                        f"saturation_{prefix}/top{k}",
                        top_freq[:k].median().item(),
                        1,
                    )
                )
class FlopsRegularizer(DualVectorListener):
    r"""The FLOPS regularizer computes

    .. math::

        FLOPS(q,d) = \lambda_q FLOPS(q) + \lambda_d FLOPS(d)

    where, for a batch :math:`x` of :math:`N` vectors of dimension :math:`V`
    (weights are assumed to be non-negative),

    .. math::

        FLOPS(x) = \sum_{j=1}^V \left( \frac{1}{N} \sum_{i=1}^N x_{ij} \right)^2
    """

    lambda_q: Param[float]
    """Lambda for queries"""

    lambda_d: Param[float]
    """Lambda for documents"""

    @staticmethod
    def compute(x: torch.Tensor):
        """
        :param x: term vectors (batch size x vocabulary dimension)
        :returns: A couple (per-term mean over the batch, FLOPS regularization)
        """
        # Computes the mean for each term (weights are positive)
        y = x.mean(0)

        # Returns the sum of squared means
        return y, (y * y).sum()

    def __call__(self, info: TrainerContext, queries, documents):
        """Add the FLOPS loss and report the associated metrics.

        :param info: The training context; its ``metrics`` must not be None
        :param queries: Query representations (vectors read from ``.value``)
        :param documents: Document representations (vectors read from
            ``.value``)
        """
        # queries and documents are length x dimension
        # Assumes that all weights are positive
        assert info.metrics is not None
        queries = queries.value
        documents = documents.value

        # Only the scalar regularization terms are needed here
        _, flops_q = FlopsRegularizer.compute(queries)
        _, flops_d = FlopsRegularizer.compute(documents)
        flops = self.lambda_d * flops_d + self.lambda_q * flops_q

        info.add_loss(Loss("flops", flops, 1.0))
        info.metrics.add(ScalarMetric("flops", flops.item(), 1))
        info.metrics.add(ScalarMetric("flops_q", flops_q.item(), 1))
        info.metrics.add(ScalarMetric("flops_d", flops_d.item(), 1))

        dual_representation_metrics(info, queries, documents)
class ScheduledFlopsRegularizer(FlopsRegularizer):
    """
    The FLOPS regularizer where ``lambda_q`` and ``lambda_d`` vary with the
    training step. Each lambda grows quadratically from its minimum value up
    to its configured value during the first ``lambda_warmup_steps`` steps,
    and then remains constant.
    """

    min_lambda_q: Param[float] = field(default=0, ignore_default=True)
    """Min value for the lambda_q before it increase"""

    min_lambda_d: Param[float] = field(default=0, ignore_default=True)
    """Min value for the lambda_d before it increase"""

    lambda_warmup_steps: Param[int] = field(default=0, ignore_default=True)
    """The warmup steps for the lambda"""

    def quadratic_ratio(self, step):
        """Return the warmup ratio for ``step``.

        :param step: The current training step
        :returns: ``(step / lambda_warmup_steps) ** 2`` during the warmup,
            1 once ``step`` has reached ``lambda_warmup_steps``
        """
        # ">=" yields the same value as the quadratic formula at the boundary
        # (ratio 1) and avoids a division by zero when no warmup is configured
        # (lambda_warmup_steps == 0) and step == 0
        if step >= self.lambda_warmup_steps:
            return 1
        return (step / self.lambda_warmup_steps) ** 2

    def __post_init__(self):
        """Remember the configured lambdas, which are the targets reached at
        the end of the warmup (``lambda_q`` / ``lambda_d`` are overwritten at
        every call)."""
        self.initial_lambda_q = self.lambda_q
        self.initial_lambda_d = self.lambda_d

    def __call__(self, info: TrainerContext, queries, documents):
        """Update both lambdas for the current step, then apply the FLOPS
        regularization.

        :param info: The training context (``info.steps`` gives the current
            step)
        :param queries: Query representations, forwarded to the parent class
        :param documents: Document representations, forwarded to the parent
            class
        """
        # Both lambdas follow the same schedule: compute the ratio only once
        ratio = self.quadratic_ratio(info.steps)

        self.lambda_q = (
            self.initial_lambda_q - self.min_lambda_q
        ) * ratio + self.min_lambda_q
        self.lambda_d = (
            self.initial_lambda_d - self.min_lambda_d
        ) * ratio + self.min_lambda_d

        FlopsRegularizer.__call__(self, info, queries, documents)