from dataclasses import dataclass
from enum import Enum
import threading
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
from pathlib import Path
import numpy as np
import torch
import logging
import re
from experimaestro import (
Config,
Param,
tagspath,
Task,
PathSerializationLWTask,
)
from experimaestro.scheduler import Job, Listener
from experimaestro.utils import cleanupdir
from experimaestro.scheduler.services import WebService
from xpmir.context import Hook, Context
from xpmir.utils.utils import easylog, Initializable, foreach
from xpmir.learning.metrics import ScalarMetric
from .schedulers import Scheduler
if TYPE_CHECKING:
from xpmir.learning.context import TrainerContext
logger = easylog()
class Optimizer(Config):
    """Base configuration for a PyTorch optimizer factory

    Subclasses build a concrete :class:`torch.optim.Optimizer` when called.
    """

    def __call__(self, parameters) -> torch.optim.Optimizer:
        """Build the optimizer over the given parameters (abstract)"""
        raise NotImplementedError()
class SGD(Optimizer):
    """Wrapper for SGD optimizer in Pytorch"""

    lr: Param[float] = 1e-5
    """Learning rate"""

    weight_decay: Param[float] = 0.0
    """Weight decay (L2)"""

    def __call__(self, parameters):
        # Imported lazily so that torch.optim is only needed at build time
        import torch.optim

        return torch.optim.SGD(
            parameters, lr=self.lr, weight_decay=self.weight_decay
        )
class Adafactor(Optimizer):
    """Wrapper for Adafactor optimizer in Transformers library

    See :class:`transformers.optimization.Adafactor` for full documentation
    """

    lr: Param[Optional[float]] = None
    """Learning rate"""

    weight_decay: Param[float] = 0.0
    """Weight decay (L2)"""

    relative_step: Param[bool] = True
    """If true, time-dependent learning rate is computed instead of external
    learning rate"""

    def __call__(self, parameters):
        # Aliased to avoid shadowing this configuration class
        from transformers.optimization import Adafactor as TrAdafactor

        return TrAdafactor(
            parameters,
            lr=self.lr,
            weight_decay=self.weight_decay,
            relative_step=self.relative_step,
        )
class Adam(Optimizer):
    """Wrapper for Adam optimizer in PyTorch"""

    lr: Param[float] = 1e-3
    """Learning rate"""

    weight_decay: Param[float] = 0.0
    """Weight decay (L2)"""

    eps: Param[float] = 1e-8
    """Term added to the denominator for numerical stability"""

    def __call__(self, parameters):
        import torch.optim

        return torch.optim.Adam(
            parameters, lr=self.lr, weight_decay=self.weight_decay, eps=self.eps
        )
class AdamW(Optimizer):
    """Adam optimizer that takes into account the regularization

    See the `PyTorch documentation
    <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html>`_
    """

    lr: Param[float] = 1e-3
    """Learning rate"""

    weight_decay: Param[float] = 1e-2
    """Decoupled weight decay"""

    eps: Param[float] = 1e-8
    """Term added to the denominator for numerical stability"""

    def __call__(self, parameters):
        import torch.optim

        return torch.optim.AdamW(
            parameters, lr=self.lr, weight_decay=self.weight_decay, eps=self.eps
        )
class ModuleInitMode(Enum):
    """Initialization mode"""

    #: Default initialization (i.e. can load default parameters or initialize randomly)
    DEFAULT = 0

    #: No parameter initialization (just initialize the structure of the model)
    NONE = 1

    #: Random initialization (initialize the structure, then use the random
    #: number generator to initialize the values)
    RANDOM = 2

    def to_options(self, random: Optional[np.random.RandomState] = None):
        """Wraps this mode (and an optional RNG) into ModuleInitOptions"""
        return ModuleInitOptions(self, random)
@dataclass
class ModuleInitOptions:
    """Options controlling how a module's parameters are initialized"""

    #: Initialization mode
    mode: ModuleInitMode

    #: Random generator (only defined when mode is RANDOM)
    random: Optional[np.random.RandomState] = None
class Module(Config, Initializable, torch.nn.Module):
    """A module contains parameters"""

    def __init__(self):
        # Bases are initialized explicitly (and in this order); note that
        # Config.__init__ is deliberately not called here
        Initializable.__init__(self)
        torch.nn.Module.__init__(self)

    def __initialize__(self, options: ModuleInitOptions):
        """Initialize a module

        :param options: The initialization options
        """
        pass

    def __call__(self, *args, **kwargs):
        # Route calls through torch.nn.Module rather than the Config MRO
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def to(self, *args, **kwargs):
        # Delegate device/dtype moves to torch
        return torch.nn.Module.to(self, *args, **kwargs)
class ModuleList(Module, Initializable):
    """Groups different models together, to be used within the Learner"""

    sub_modules: Param[List[Module]]
    """The modules that are grouped together"""

    def __post_init__(self):
        # Register each sub-module with torch so its parameters are visible
        for index, sub_module in enumerate(self.sub_modules):
            self.add_module(str(index), sub_module)

    def __initialize__(self, options: ModuleInitOptions):
        # Propagate initialization to every sub-module
        for sub_module in self.sub_modules:
            sub_module.initialize(options)

    def __call__(self, *args, **kwargs):
        raise AssertionError("This module cannot be used as such")

    def to(self, *args, **kwargs):
        return torch.nn.Module.to(self, *args, **kwargs)
class ModuleLoader(PathSerializationLWTask):
    def execute(self):
        """Loads the model from disk using the given serialization path"""
        logging.info("Loading model from disk: %s", self.path)
        # Build the module structure without touching its parameter values;
        # the saved state dict then provides the actual values
        self.value.initialize(ModuleInitMode.NONE.to_options())
        # NOTE(review): torch.load without map_location assumes the saving
        # device is available — confirm whether a CPU fallback is needed
        state_dict = torch.load(self.path)
        self.value.load_state_dict(state_dict)
class ParameterFilter(Config):
    """Base parameter filter: keeps every parameter (no filtering)"""

    def __call__(self, name, params) -> bool:
        """Returns true if the parameters should be optimized with the
        associated optimizer"""
        return True
class RegexParameterFilter(ParameterFilter):
    """Selects parameters by matching their names against regular expressions

    Precondition: exactly one of `includes` and `excludes` is None"""

    includes: Param[Optional[List[str]]] = None
    """The str of params to be included from the model"""

    excludes: Param[Optional[List[str]]] = None
    """The str of params to be excludes from the model"""

    def __init__(self):
        self.name = set()

    def __validate__(self):
        # At least one of the two lists must be provided
        return self.includes or self.excludes

    def __repr__(self) -> str:
        return f"RegexParameterFilter({self.includes}, {self.excludes})"

    def __call__(self, name, params) -> bool:
        # Include-list mode: keep the parameter on the first match
        if self.includes:
            if any(re.search(pattern, name) for pattern in self.includes):
                return True

        # No exclude list: nothing more can let this parameter through
        if not self.excludes:
            return False

        # Exclude-list mode: keep the parameter unless a pattern matches
        return not any(re.search(pattern, name) for pattern in self.excludes)
class ParameterOptimizer(Config):
    """Associates an optimizer with a list of parameters to optimize"""

    optimizer: Param[Optimizer]
    """The optimizer"""

    scheduler: Param[Optional[Scheduler]]
    """The optional scheduler"""

    module: Param[Optional[Module]]
    """The module from which parameters should be extracted"""

    filter: Param[Optional[ParameterFilter]] = ParameterFilter()
    """How parameters should be selected for this (by default, use them all)"""

    def create_optimizer(
        self, module: Module, filter: Callable[[str, Any], bool]
    ) -> torch.optim.Optimizer:
        """Returns a (pytorch) optimizer

        :param module: default module to take parameters from (ignored when
            `self.module` is set)
        :param filter: extra name/parameter predicate, applied after this
            configuration's own filter
        :raises RuntimeError: if no parameter passes both filters
        """
        target = self.module or module

        selected = []
        for name, param in target.named_parameters():
            # Own filter first, then the externally supplied one
            if (self.filter is None or self.filter(name, param)) and filter(
                name, param
            ):
                selected.append(param)

        if not selected:
            logging.warning(
                "Parameter list: %s", [name for name, _ in target.named_parameters()]
            )
            raise RuntimeError(f"Parameter list is empty with {self.filter}")

        return self.optimizer(selected)
class DuplicateParameterFilter:
    """Filters out already optimized parameters

    Stateful predicate: the first time a parameter is seen it is accepted
    (and remembered); any later occurrence is rejected.
    """

    def __init__(self):
        # Set of parameters already assigned to an optimizer
        self.parameters = set()

    def __call__(self, name, params):
        if params in self.parameters:
            return False
        self.parameters.add(params)
        return True
class OptimizationHook(Hook):
    """Base class for all optimization hooks"""
class GradientHook(OptimizationHook):
    """Hooks that are called when the gradient is computed

    The gradient is guaranteed to be unscaled in this case.
    """
class GradientClippingHook(GradientHook):
    """Gradient clipping"""

    max_norm: Param[float]
    """Maximum norm for gradient clipping"""

    def __call__(self, main: "ScheduledOptimizer"):
        """Clips (in place) the gradients of all the module's parameters"""
        torch.nn.utils.clip_grad_norm_(main.module.parameters(), self.max_norm)
class GradientLogHook(GradientHook):
    """Log the gradient norm"""

    name: Param[str] = "gradient_norm"
    """Name of the reported metric"""

    def __call__(self, main: "ScheduledOptimizer"):
        """Writes the size-weighted mean of squared gradient norms

        :param main: the optimizer wrapper giving access to the module and
            to the trainer context (writer + current step)
        """
        sum_norms = 0.0
        n_params = 0
        with torch.no_grad():
            for param in main.module.parameters():
                if param.grad is not None:
                    n_params += param.grad.numel()
                    # NOTE(review): this accumulates numel * ||grad||^2, so the
                    # reported value is a weighted mean of squared norms, not a
                    # norm — confirm this is the intended metric
                    sum_norms += param.grad.numel() * param.grad.norm() ** 2

        # Fix: guard against a division by zero when no parameter has a
        # gradient (e.g. hook called before any backward pass)
        if n_params > 0:
            main.trainer_context.writer.add_scalar(
                self.name, sum_norms / n_params, main.trainer_context.state.step
            )
class ScheduledOptimizer:
    """Drives a set of optimizers, their schedulers and an optional AMP scaler

    Wraps (1) one PyTorch optimizer per :class:`ParameterOptimizer`, (2) the
    matching learning-rate schedulers, (3) an optional gradient scaler for
    mixed precision, and (4) gradient hooks applied before each step.
    """

    def initialize(
        self,
        param_optimizers: List[ParameterOptimizer],
        num_training_steps: int,
        module: Module,
        use_scaler: bool,
        hooks: Optional[List[OptimizationHook]] = None,
        trainer_context: Optional["TrainerContext"] = None,
    ):
        """Creates the optimizers and schedulers for the given module

        :param param_optimizers: parameter/optimizer associations, applied in
            order (each parameter goes to the first optimizer that accepts it)
        :param num_training_steps: total number of training steps, forwarded
            to the scheduler factories
        :param module: the module holding the parameters to optimize
        :param use_scaler: whether to use a CUDA AMP gradient scaler
        :param hooks: optimization hooks (default: none). The default is None
            rather than a mutable ``[]`` to avoid the shared-default pitfall.
        :param trainer_context: the trainer context, made available to hooks
        :raises RuntimeError: if the module has no parameters, or if some
            parameter optimizer selects no parameter
        """
        self.schedulers = []
        self.scheduler_factories = []
        self.optimizers = []
        self.scheduler_steps = -1  # Number of scheduler steps
        self.num_training_steps = num_training_steps
        self.module = module
        self.context = Context([] if hooks is None else hooks)
        self.trainer_context = trainer_context

        # Fail fast when there is nothing to optimize
        try:
            next(module.parameters())
        except StopIteration:
            raise RuntimeError(f"No parameters to optimize in the module {module}")

        # Ensures each parameter is handled by at most one optimizer
        filter = DuplicateParameterFilter()
        for param_optimizer in param_optimizers:
            optimizer = param_optimizer.create_optimizer(module, filter)
            self.optimizers.append(optimizer)
            self.scheduler_factories.append(param_optimizer.scheduler)

        self.reset_schedulers()
        assert len(self.schedulers) == len(self.optimizers)

        if use_scaler:
            logger.info("Using GradScaler when optimizing")
        self.scaler = torch.cuda.amp.GradScaler() if use_scaler else None

    def load_state_dict(self, state):
        """Restores the optimizers, the scaler and the scheduler step"""
        for optimizer, optimizer_state in zip(self.optimizers, state["optimizers"]):
            optimizer.load_state_dict(optimizer_state)

        if self.scaler is not None:
            self.scaler.load_state_dict(state["scaler"])

        # Re-create schedulers so they resume at the saved step
        self.scheduler_steps = state["scheduler_steps"]
        self.reset_schedulers()

    def reset_schedulers(self):
        """(Re)builds the schedulers, resuming at ``self.scheduler_steps``"""
        self.schedulers = []
        for optimizer, scheduler_factory in zip(
            self.optimizers, self.scheduler_factories
        ):
            if scheduler_factory is None:
                # No scheduler for this optimizer (placeholder keeps lists aligned)
                self.schedulers.append(None)
            else:
                self.schedulers.append(
                    scheduler_factory(
                        optimizer,
                        self.num_training_steps,
                        last_epoch=self.scheduler_steps,
                    )
                )

    def state_dict(self):
        """Returns a serializable snapshot (optimizers, scaler, scheduler step)"""
        return {
            "optimizers": [optimizer.state_dict() for optimizer in self.optimizers],
            "scaler": None if self.scaler is None else self.scaler.state_dict(),
            "scheduler_steps": self.scheduler_steps,
        }

    def scale(self, loss: torch.Tensor):
        """Scales the loss for AMP; identity when no scaler is configured"""
        if self.scaler is None:
            return loss
        return self.scaler.scale(loss)

    def zero_grad(self):
        """Zero-grad for all optimizers"""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def optimizer_step(self, context: "TrainerContext"):
        """Performs an optimizer step (using the scaler if defined)"""
        if self.scaler is None:
            # Apply gradient hooks (gradients are already unscaled)
            foreach(
                self.context.hooks(GradientHook),
                lambda hook: hook(self),
            )
            for optimizer in self.optimizers:
                optimizer.step()
        else:
            # Unscale first so hooks observe true gradient values
            for optimizer in self.optimizers:
                self.scaler.unscale_(optimizer)

            # Apply gradient hooks
            foreach(
                self.context.hooks(GradientHook),
                lambda hook: hook(self),
            )

            # Step (the scaler may skip steps whose gradients overflowed)
            for optimizer in self.optimizers:
                self.scaler.step(optimizer)
            context.add_metric(
                ScalarMetric("gradient/scaler", self.scaler.get_scale(), 1)
            )
            self.scaler.update()

    def scheduler_step(self, context: "TrainerContext"):
        """Performs a step for all the schedulers"""
        for ix, scheduler in enumerate(self.schedulers):
            if scheduler is not None:
                # Report the learning rate of each parameter group
                for p_ix, lr in enumerate(scheduler.get_last_lr()):
                    context.add_metric(
                        ScalarMetric(f"gradient/scheduler/{ix+1}/{p_ix+1}", lr, 1)
                    )
                scheduler.step()
        self.scheduler_steps += 1
# Anything accepted where a set of optimizers is expected: a single
# (parameter) optimizer or a list of parameter optimizers
Optimizers = Union[ParameterOptimizer, Optimizer, List[ParameterOptimizer]]
"""Defines a set of optimizers"""
def get_optimizers(optimizers: Optimizers):
    """Returns a list of ParameterOptimizer"""
    # Already a list: nothing to normalize
    if isinstance(optimizers, list):
        return optimizers

    # A single parameter optimizer: wrap it in a list
    if isinstance(optimizers, ParameterOptimizer):
        return [optimizers]

    # A bare optimizer: associate it with all the parameters
    return [ParameterOptimizer(optimizer=optimizers)]
class TensorboardServiceListener(Listener):
    """Symlinks a job's tensorboard data once the job has started"""

    def __init__(self, source: Path, target: Path):
        # source: the symlink to create; target: the data it should point to
        self.source = source
        self.target = target

    def job_state(self, job: Job):
        # Nothing to do until the job actually starts
        if job.state.notstarted():
            return

        # Create the link only once
        if self.source.is_symlink():
            return

        try:
            self.source.symlink_to(self.target)
        except Exception:
            logger.exception("Cannot symlink %s to %s", self.source, self.target)
class TensorboardService(WebService):
    """Web service exposing tensorboard over the experiment's log directories"""

    id = "tensorboard"

    def __init__(self, path: Path):
        super().__init__()

        self.path = path
        cleanupdir(self.path)
        self.path.mkdir(exist_ok=True, parents=True)
        logger.info("You can monitor learning with:")
        logger.info("tensorboard --logdir=%s", self.path)
        self.url = None
        # Fix: set in _serve(); must exist so close() works even when the
        # server was never started
        self.server = None

    def add(self, task: Task, path: Path):
        """Registers a task so its tensorboard data appears under its tags

        :param task: the (submitted) task to monitor
        :param path: path to the task's tensorboard data
        """
        # Wait until config has started
        if job := task.__xpm__.job:
            if job.scheduler is not None:
                tag_path = tagspath(task)
                if tag_path:
                    # Link the data lazily, once the job is running
                    job.scheduler.addlistener(
                        TensorboardServiceListener(self.path / tag_path, path)
                    )
                else:
                    logger.error(
                        "The task is not associated with tags: "
                        "cannot link to tensorboard data"
                    )
            else:
                logger.debug("No scheduler: not adding the tensorboard data")
        else:
            logger.error("Task was not started: cannot link to tensorboard job path")

    def description(self):
        return "Tensorboard service"

    def close(self):
        # self.server is None when _serve() never ran or failed early
        if self.server:
            self.server.shutdown()

    def _serve(self, running: threading.Event):
        import tensorboard as tb

        try:
            logger.info("Starting %s service", self.id)
            self.program = tb.program.TensorBoard()
            self.program.configure(
                host="localhost",
                logdir=str(self.path.absolute()),
                path_prefix=f"/services/{self.id}",
                port=0,
            )
            self.server = self.program._make_server()

            self.url = self.server.get_url()
            # Signal that the server is up before blocking
            running.set()
            self.server.serve_forever()
        except Exception:
            logger.exception("Error while starting tensorboard")
            # Unblock waiters even on failure
            running.set()