Source code for xpmir.letor.trainers.generative

from torch import nn
import numpy as np
from experimaestro import Param, Config
import logging

from xpmir.letor.samplers import PairwiseSampler
from xpmir.letor.records import BaseRecords, PairwiseRecords
from xpmir.letor.trainers import TrainerContext, LossTrainer
from xpmir.learning.context import Loss
from xpmir.utils.iter import MultiprocessSerializableIterator
from xpmir.utils.utils import foreach, easylog

# Module-level logger; PairwiseGenerativeLoss.process uses it for debug output
# of the computed loss value.
logger = easylog()


class PairwiseGenerativeLoss(Config, nn.Module):
    """Generic loss for generative models

    Subclasses implement :meth:`compute`. :meth:`process` calls it and
    registers the resulting value, together with :attr:`weight`, in the
    trainer context under the key ``pair-{NAME}``.
    """

    # Short identifier used to build the loss key ``pair-{NAME}``;
    # subclasses are expected to override it.
    NAME = "?"

    weight: Param[float] = 1.0
    """The weight :math:`w` with which the loss is multiplied (useful when
    combining with other ones)"""

    def initialize(self):
        """Prepare the loss before training (no-op by default).

        ``GenerativeTrainer.initialize`` calls ``initialize()`` without
        arguments on its main loss and on every ``PairwiseGenerativeLoss``
        hook. Without this default, a loss needing no setup would raise
        AttributeError there. Subclasses needing setup should override it.
        """
        pass

    def compute(self, records, context):
        """Compute the loss value for a batch of records.

        Must be overridden by subclasses.

        :param records: the batch of records to compute the loss on
        :param context: the trainer context
        :raises NotImplementedError: always, in this base class. Failing here
            beats letting ``None`` flow into ``context.add_loss``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute(records, context)"
        )

    def process(self, records: BaseRecords, context: TrainerContext):
        """Compute the loss on ``records`` and register it in ``context``.

        :param records: the batch of records
        :param context: the trainer context receiving the weighted loss
        """
        # Presumably a tensor of shape [bs] (per the original note) — confirm
        # against the concrete subclasses.
        value = self.compute(records, context)
        # Guard avoids formatting the (possibly large) value when DEBUG is off
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Loss: {value}")
        context.add_loss(Loss(f"pair-{self.NAME}", value, self.weight))
class GenerativeTrainer(LossTrainer):
    """Trainer feeding pairwise batches to a generative loss

    Batches are drawn from :attr:`sampler`. Each batch is handed to
    :attr:`loss`, whose ``process`` method registers the loss value in the
    trainer context.
    """

    # The generative loss applied to every training batch
    loss: Param[PairwiseGenerativeLoss]

    sampler: Param[PairwiseSampler]
    """The pairwise sampler"""

    def initialize(self, random: np.random.RandomState, context: TrainerContext):
        """Initialize the losses and the batch iterator.

        :param random: random state, forwarded to the parent trainer and to
            the sampler
        :param context: the trainer context; its ``PairwiseGenerativeLoss``
            hooks are initialized as well
        """
        super().initialize(random, context)

        # Main loss first, then every extra loss registered as a hook
        self.loss.initialize()

        def _initialize_hook(hook_loss):
            hook_loss.initialize()

        foreach(context.hooks(PairwiseGenerativeLoss), _initialize_hook)

        # The hooks above are the place to change the sampling target later
        self.sampler.initialize(random)
        batch_iterator = self.sampler.pairwise_batch_iter(self.batch_size)
        self.sampler_iter = MultiprocessSerializableIterator(batch_iterator)

    def train_batch(self, records: PairwiseRecords):
        """Run the loss on one batch of pairwise records.

        The loss computes its value and adds it to ``self.context``. The
        backward pass is not done here — presumably LossTrainer handles it
        (confirm).

        :param records: the batch produced by the sampler
        """
        generative_loss = self.loss
        generative_loss.process(records, self.context)