# Source code for xpmir.learning.trainers

from abc import abstractmethod
from typing import Dict, Iterator, List, Optional
from experimaestro import Config, Param
import torch.nn as nn
import numpy as np
from xpmir.utils.utils import EasyLogger
from xpmir.learning import Module
from xpmir.learning.context import (
    TrainingHook,
    TrainerContext,
)

from xpmir.utils.utils import foreach


class Trainer(Config, EasyLogger):
    """Generic trainer"""

    # NOTE(review): mutable list default on a Param — kept as-is because the
    # declaration is handled by experimaestro; confirm it is not shared/mutated
    # across instances.
    hooks: Param[List[TrainingHook]] = []
    """Hooks for this trainer: this includes the losses, but can be adapted for
    other uses

    The specific list of hooks depends on the specific trainer
    """

    model: Param[Optional[Module]] = None
    """If the model to optimize is different from the model passed to Learn,
    this parameter can be used – initialization is still expected to be done at
    the learner level"""

    def initialize(
        self,
        random: np.random.RandomState,
        context: TrainerContext,
    ):
        """Prepare the trainer before training starts.

        :param random: random state for the trainer, stored as ``self.random``
        :param context: the trainer context; when no ``model`` was configured,
            ``context.state.model`` is used, and every hook in ``self.hooks``
            is registered through ``context.add_hook``
        """
        self.random = random

        # Generic style: fall back to the model held by the training state
        # when no specific model was configured for this trainer
        if self.model is None:
            self.model = context.state.model

        # Old style (to be deprecated): alias of ``self.model``, presumably
        # kept for subclasses that still read ``self.ranker``
        self.ranker = self.model

        self.context = context
        foreach(self.hooks, self.context.add_hook)

    def to(self, device):
        """Change the computing device (if this is needed)

        Calls ``.to(device)`` on each hook returned by
        ``self.context.hooks(nn.Module)`` (presumably the registered hooks that
        are ``nn.Module`` instances). Requires :meth:`initialize` to have been
        called first, since it reads ``self.context``.
        """
        foreach(self.context.hooks(nn.Module), lambda hook: hook.to(device))

    @abstractmethod
    def iter_batches(self) -> Iterator:
        """Returns a (serializable) iterator over batches"""
        ...

    @abstractmethod
    def process_batch(self, batch):
        """Process one batch produced by :meth:`iter_batches`

        Must be implemented by concrete trainers.
        """
        ...

    @abstractmethod
    def load_state_dict(self, state: Dict):
        """Restore the trainer state from ``state``

        ``state`` is presumably a dictionary produced by :meth:`state_dict`
        (TODO confirm with the checkpointing code).
        """
        ...

    @abstractmethod
    def state_dict(self):
        """Return the trainer state, to be restored by :meth:`load_state_dict`"""
        ...