from enum import Enum
import logging
import torch
from pathlib import Path
from typing import Dict, Iterator, List, NamedTuple, Any
from experimaestro import (
Task,
Config,
Param,
pathgenerator,
Annotated,
tqdm,
Meta,
)
import numpy as np
from xpmir.context import Hook, InitializationHook
from xpmir.utils.utils import EasyLogger, easylog, foreach
from xpmir.learning.devices import DEFAULT_DEVICE, Device, DeviceInformation
from xpmir.learning import Random, ModuleInitMode
from xpmir.learning.trainers import Trainer
from xpmir.learning.context import (
StepTrainingHook,
TrainState,
TrainerContext,
)
from xpmir.learning.metrics import Metrics
from .optim import (
Module,
ModuleLoader,
ParameterOptimizer,
ScheduledOptimizer,
OptimizationHook,
)
from .batchers import RecoverableOOMError
# Module-level logger; Learner.iter_train uses it to report recoverable OOM
# retries.
# NOTE(review): easylog() presumably names the logger after this module —
# confirm against xpmir.utils.utils.
logger = easylog()
class LearnerListenerStatus(Enum):
    """Decision returned by a learner listener after an epoch

    The numeric values define a precedence used by :meth:`update`:
    ``NO_DECISION < STOP < DONT_STOP``, i.e. a single ``DONT_STOP`` overrides
    any number of ``STOP`` votes.
    """

    # The listener has no opinion
    NO_DECISION = 0
    # The listener asks for the training to stop
    STOP = 1
    # The listener asks for the training to continue
    DONT_STOP = 2

    def update(self, other: "LearnerListenerStatus") -> "LearnerListenerStatus":
        """Combine two decisions, keeping the one with the highest value

        :param other: The decision to merge with this one
        :returns: The member whose value is the greater of both values
        """
        strongest = self.value if self.value >= other.value else other.value
        return LearnerListenerStatus(strongest)
class LearnerListener(Config):
    """Hook for learner

    Performs some operations after a learning epoch. The learner calls
    :meth:`initialize` once before training, then calls the listener itself
    with the training state after each epoch and combines the returned
    statuses to decide whether training should stop.
    """

    id: Meta[str]
    """Unique ID to identify the listener (ignored for signature)"""

    def initialize(self, learner: "Learner", context: TrainerContext) -> None:
        """Attach this listener to a learner and its training context

        :param learner: The learner that owns this listener
        :param context: The trainer context of the running learner
        """
        self.learner = learner
        self.context = context

    def __call__(self, state: TrainState) -> LearnerListenerStatus:
        """Process and returns whether the training process should stop

        :param state: The training state reached at the end of the epoch
        :returns: The decision of this listener; the base implementation
            never takes one (``NO_DECISION``)
        """
        return LearnerListenerStatus.NO_DECISION

    def update_metrics(self, metrics: Dict[str, float]) -> None:
        """Add metrics

        The base implementation adds nothing.

        :param metrics: The dictionary to be updated in place
        """

    def task_outputs(self, learner: "Learner", dep):
        """Outputs from this listeners (deprecated)

        :param learner: The learner object
        :param dep: The function that adds a dependency
        :raises DeprecationWarning: always, in this base implementation
        """
        raise DeprecationWarning("task_outputs has been deprecated, use init_task")

    def init_task(self, learner: "Learner", dep):
        """Returns the initialization task that loads the associated checkpoint

        :param learner: The learner object
        :param dep: The function that adds a dependency
        :returns: ``None`` in this base implementation
        """
        return None
class LearnerOutput(NamedTuple):
    """The data structure for the output of a learner

    It bundles the per-listener outputs with a loader for the learned model,
    as built by ``Learner.task_outputs``.
    """

    # Maps the ID of each listener to the output of that listener
    listeners: Dict[str, Any]

    # Loader for the model saved in the last checkpoint of the learner
    learned_model: ModuleLoader
class Learner(Task, EasyLogger):
    """Model Learner

    The learner task is generic, and takes two main arguments: (1) the model
    defines the model (e.g. DRMM), and (2) the trainer defines how the model
    should be trained (e.g. pointwise, pairwise, etc.)

    When submitted, it returns a :class:`LearnerOutput` holding the outputs
    of the `listeners` and a loader for the last checkpoint of the model
    """

    # Training
    random: Param[Random]
    """The random generator"""

    trainer: Param[Trainer]
    """Specifies how to train the model"""

    model: Param[Module]
    """Defines the model to be learned. If multiple models are used, one can use
    MultipleModel.
    """

    max_epochs: Param[int] = 1000
    """Maximum number of epochs"""

    steps_per_epoch: Param[int] = 128
    """Number of steps for one epoch (after each epoch results are logged)"""

    use_fp16: Param[bool] = False
    """Use mixed precision when training"""

    optimizers: Param[List[ParameterOptimizer]]
    """The list of parameter optimizers"""

    listeners: Param[List[LearnerListener]]
    """Listeners are in charge of handling the validation of the model, and
    saving the relevant checkpoints"""

    checkpoint_interval: Param[int] = 1
    """Number of epochs between each checkpoint"""

    logpath: Annotated[Path, pathgenerator("runs")]
    """The path to the tensorboard logs"""

    checkpointspath: Annotated[Path, pathgenerator("checkpoints")]
    """The path to the checkpoints"""

    device: Meta[Device] = DEFAULT_DEVICE
    """The device(s) to be used for the model"""

    hooks: Param[List[Hook]] = []
    """Global learning hooks

    :class:`Initialization hooks <xpmir.context.InitializationHook>` are called
    before and after the initialization of the trainer and listeners.
    """

    use_pretasks: Meta[bool] = False
    """Use deprecated pre-tasks as the output"""

    def __validate__(self):
        """Check that at least one optimizer is set and listener IDs are unique"""
        assert self.optimizers, "At least one optimizer should be defined"
        assert len({listener.id for listener in self.listeners}) == len(
            self.listeners
        ), "IDs of listeners should be unique"
        return super().__validate__()

    def task_outputs(self, dep) -> LearnerOutput:
        """Object returned when submitting the task

        :param dep: The function that adds a dependency
        :returns: The listener outputs (keyed by listener ID) and a loader
            for the model stored in the last checkpoint
        """
        model_path = self.last_checkpoint_path / TrainState.MODEL_PATH

        if self.use_pretasks:
            # logging.warn is a deprecated alias (removed in Python 3.13)
            logging.warning("Using deprecated pre-tasks in Learner")
            return LearnerOutput(
                listeners={
                    listener.id: listener.task_outputs(self, dep)
                    for listener in self.listeners
                },
                learned_model=ModuleLoader.construct(self.model, model_path, dep),
            )

        return LearnerOutput(
            listeners={
                listener.id: listener.init_task(self, dep)
                for listener in self.listeners
            },
            learned_model=dep(ModuleLoader(value=self.model, path=model_path)),
        )

    @property
    def last_checkpoint_path(self) -> Path:
        """Directory (under `checkpointspath`) holding the last checkpoint"""
        return self.checkpointspath / "last"

    def execute(self) -> None:
        """Run the training by delegating to the configured device"""
        self.device.execute(self.device_execute)

    def device_execute(self, device_information: DeviceInformation) -> None:
        """Initialize everything and run the training loop

        Sets up logging, the trainer context, hooks, random seeds, the model,
        the trainer, the listeners and the optimizer, then consumes
        :meth:`iter_train`, saving checkpoints and calling the listeners
        after each epoch until a stop condition is met.

        :param device_information: Information about the device on which the
            model and trainer are placed
        """
        # Force INFO level on the root logger and all its handlers
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        for handler in root_logger.handlers:
            handler.setLevel(logging.INFO)

        self.optimizer = ScheduledOptimizer()
        self.only_cached = False
        self.context = TrainerContext(
            device_information,
            self.logpath,
            self.checkpointspath,
            self.max_epochs,
            self.steps_per_epoch,
            self.trainer,
            self.model,
            self.optimizer,
        )

        for hook in self.hooks:
            self.context.add_hook(hook)

        # Call hooks
        foreach(
            self.context.hooks(InitializationHook),
            lambda hook: hook.before(self.context),
        )

        # Sets the random seed
        seed = self.random.state.randint((2**32) - 1)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Initialize the scorer and trainer
        self.logger.info("model initialization")
        self.model.initialize(ModuleInitMode.DEFAULT.to_options(self.random.state))

        # Initialize the context and the listeners
        self.trainer.initialize(self.random.state, self.context)
        for listener in self.listeners:
            listener.initialize(self, self.context)

        self.logger.info("Moving model to device %s", device_information.device)
        self.model.to(device_information.device)
        self.trainer.to(device_information.device)

        num_training_steps = self.max_epochs * self.steps_per_epoch
        self.optimizer.initialize(
            self.optimizers,
            num_training_steps,
            self.model,
            self.use_fp16,
            hooks=[hook for hook in self.hooks if isinstance(hook, OptimizationHook)],
            trainer_context=self.context,
        )

        foreach(
            self.context.hooks(InitializationHook),
            lambda hook: hook.after(self.context),
        )

        self.logger.info("Starting to train")

        current = 0
        state = None

        with tqdm(
            total=self.max_epochs, desc=f"Training ({self.max_epochs} epochs)"
        ) as tqdm_epochs:
            for state in self.iter_train(device_information):
                # Report progress
                tqdm_epochs.update(state.epoch - current)
                current = state.epoch

                if state.epoch >= 0 and not self.only_cached:
                    message = f"epoch {state.epoch}"
                    if state.cached:
                        self.logger.debug(f"[train] [cached] {message}")
                    else:
                        self.logger.debug(f"[train] {message}")

                if state.epoch == -1:
                    continue

                if not state.cached and state.epoch % self.checkpoint_interval == 0:
                    # Save checkpoint if needed
                    self.context.save_checkpoint()
                    self.context.copy(self.last_checkpoint_path)

                # Call listeners and merge their decisions (the highest
                # status value wins, see LearnerListenerStatus.update)
                decision = LearnerListenerStatus.NO_DECISION
                for listener in self.listeners:
                    decision = decision.update(listener(state))

                if decision == LearnerListenerStatus.STOP:
                    self.logger.warning(
                        "stopping after epoch %d since all listeners asked for it",
                        state.epoch,
                    )
                    break

                # Stop if max epoch is reached
                if self.context.epoch >= self.max_epochs:
                    self.logger.warning(
                        "stopping after epoch %d (max_epoch)", self.max_epochs
                    )
                    break

            # End of the learning process
            if state is not None and not state.cached:
                # Set the hyper-parameters
                metrics = {}
                for listener in self.listeners:
                    listener.update_metrics(metrics)
                self.context.writer.add_hparams(getattr(self, "__tags__", {}), metrics)

    def iter_train(self, device_information) -> Iterator[TrainState]:
        """Train iteration

        Yields the restored state first when a checkpoint could be loaded,
        then the context state after each trained epoch. The generator never
        terminates by itself: the consumer decides when to stop iterating.

        :param device_information: Information about the device, used to
            select the autocast device type
        """
        # Try to load a checkpoint
        if self.context.load_bestcheckpoint(self.max_epochs):
            yield self.context.state

        # Get an iterator over batches
        batches = self.trainer.iter_batches()

        while True:
            # Step to the next epoch
            self.context.nextepoch()

            # Train for an epoch
            with tqdm(
                leave=False,
                total=self.steps_per_epoch,
                ncols=100,
                desc=f"Train - epoch #{self.context.epoch}",
            ) as pbar:
                # Put the model into training mode (just in case)
                self.context.state.model.train()

                # Epoch: loop over batches
                metrics = Metrics()
                for _ in range(self.steps_per_epoch):
                    # Get the next batch
                    batch = next(batches)
                    self.context.nextbatch()

                    # Retry the same batch until the step succeeds
                    while True:
                        try:
                            # Computes the gradient, takes a step and collect
                            # metrics
                            with self.context.step(metrics):
                                # Call the step hooks
                                foreach(
                                    self.context.hooks(StepTrainingHook),
                                    lambda hook: hook.before(self.context),
                                )

                                # Computes the gradient
                                with torch.autocast(
                                    device_information.device.type,
                                    enabled=self.use_fp16,
                                ):
                                    self.trainer.process_batch(batch)

                                # Update the progress counter
                                pbar.update(1)
                                break
                        except RecoverableOOMError:
                            # Loop again to re-run the step on the same batch
                            logger.warning(
                                "Recoverable OOM detected"
                                " - re-running the training step"
                            )

                    foreach(
                        self.context.hooks(StepTrainingHook),
                        lambda hook: hook.after(self.context),
                    )

                # Report metrics over the epoch *before* yielding: the
                # consumer may stop iterating at the yield below, in which
                # case nothing placed after it would run for the last epoch
                metrics.report(
                    self.context.state.step,
                    self.context.writer,
                    "train",
                )

                # Yields the current state (after one epoch)
                yield self.context.state