from dataclasses import InitVar
import sys
from typing import Iterator
import numpy as np
import torch
import torch.nn.functional as F
from xpmir.learning.context import Loss, TrainerContext
from xpmir.letor.records import DocumentRecords, DocumentRecord
from xpmir.letor.trainers import LossTrainer
from experimaestro import Param, Config
from xpmir.letor.samplers import Sampler
from xpmir.utils.iter import RandomSerializableIterator
class MLMLoss(Config):
    """Base configuration for the losses used by the MLM trainer

    Subclasses set ``NAME`` and implement :meth:`compute`; :meth:`process`
    takes care of registering the computed value in the trainer context.
    """

    # Short identifier, used to build the name under which the loss is reported
    NAME = "?"

    # Weight attached to the loss value when it is registered
    weight: Param[float] = 1.0

    def initialize(self):
        """Hook called by the trainer before training (no-op by default)"""

    def process(self, scores, targets, context: TrainerContext):
        """Compute the loss and add it to the trainer context

        :param scores: the scores, forwarded as is to :meth:`compute`
        :param targets: the targets, forwarded as is to :meth:`compute`
        :param context: the trainer context that collects the losses
        """
        loss_value = self.compute(scores, targets)
        loss_name = f"point-{self.NAME}"
        context.add_loss(Loss(loss_name, loss_value, self.weight))

    def compute(self, rel_scores, target_relscores) -> torch.Tensor:
        """Return the loss value (must be implemented by subclasses)

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class CrossEntropyLoss(MLMLoss):
    """Cross-entropy loss

    The scores and targets are handed directly to
    ``torch.nn.functional.cross_entropy`` with its default arguments.
    """

    NAME = "ce"

    def compute(self, rel_scores, target_relscores):
        """Return the cross-entropy between the scores and the targets"""
        ce_value = F.cross_entropy(rel_scores, target_relscores)
        return ce_value
class MLMTrainer(LossTrainer):
    """Trainer for Masked Language Modeling

    Batches of documents are drawn from the sampler, passed through the model
    as texts, and the resulting logits/labels are given to the loss function.
    """

    # Loss applied to the (flattened) model logits and labels
    lossfn: Param[MLMLoss] = CrossEntropyLoss()

    # Sampler providing the document records
    sampler: Param[Sampler]

    # Record iterator obtained from the sampler; set in ``initialize``
    sampler_iter: InitVar[RandomSerializableIterator[DocumentRecord]]

    def initialize(self, random: np.random.RandomState, context: TrainerContext):
        """Initialize the trainer, its loss and its sampler

        :param random: random state, also used to initialize the sampler
        :param context: the trainer context, forwarded to the parent class
        """
        super().initialize(random, context)
        self.lossfn.initialize()
        self.sampler.initialize(random)
        self.sampler_iter = self.sampler.record_iter()

    def __validate__(self):
        """No validation is performed

        The check on ``grad_acc_batch`` (adaptive batch size not being
        implemented) is currently disabled.
        """

    def iter_batches(self) -> Iterator[DocumentRecords]:
        """Yield batches of at most ``batch_size`` records, indefinitely"""
        while True:
            records = DocumentRecords()
            # ``range`` comes first so that ``zip`` stops without pulling an
            # extra record out of the sampler iterator
            for _position, document in zip(
                range(self.batch_size), self.sampler_iter
            ):
                records.add(document)
            yield records

    def train_batch(self, records: DocumentRecords):
        """Run the model on a batch and register the loss

        The process exits with status 1 if the logits contain a NaN or an
        infinite value.

        :param records: the batch of documents
        """
        output = self.model(records.to_texts(), self.context)
        logits = output.logits

        # A value is finite exactly when it is neither NaN nor infinite
        if not torch.isfinite(logits).all():
            self.logger.error("nan or inf relevance score detected. Aborting.")
            sys.exit(1)

        # One row of vocabulary scores per position, one label per position
        flat_logits = logits.view(-1, self.model.vocab_size)
        flat_labels = output.labels.view(-1).to(logits.device)
        self.lossfn.process(flat_logits, flat_labels, self.context)