Source code for xpmir.letor.validation

import logging
import json
from pathlib import Path
import numpy as np
from typing import Dict, Iterator, List
from collections import defaultdict
from numpy import mean
from experimaestro import Config, Param, pathgenerator, Annotated, field
from datamaestro_ir.data import Adhoc

from xpm_torch.trainers.context import ValidationHook
from xpm_torch.learner import (
    TrainState,
    TrainerContext,
    LearnerListener,
    Learner,
    LearnerListenerStatus,
)

from xpmir.evaluation import evaluate
from xpmir.rankers import Retriever

logger = logging.getLogger(__name__)


class ValidationSettings(Config):
    """Settings identifying a ModuleLoader built for validation checkpoints.

    An instance is passed as the loader's ``settings`` so that a loader
    pointing at a validation checkpoint gets an identity distinct from other
    loaders that share the same model and the same path.
    """

    listener: Param["LearnerListener"] = field(ignore_generated=True)
    """The listener (kept to change the loader identifier based on the learner listener configuration)"""

    key: Param[str]
    """The metric key for this validation checkpoint"""
class ValidationListener(LearnerListener):
    """Learning validation early-stopping

    Computes validation metrics every ``validation_interval`` epochs, keeps
    track of the best value (and the epoch it was reached) for each metric,
    and copies the corresponding checkpoint when a metric flagged ``True`` in
    ``metrics`` improves. If ``early_stop`` is set (> 0), it signals to the
    learner that the learning process can stop once no monitored metric has
    improved for ``early_stop`` epochs.
    """

    metrics: Param[Dict[str, bool]] = field(default={"map": True}, ignore_default=True)
    """Dictionary whose keys are the metrics to record, and boolean values whether the best performance checkpoint should be kept for the associated metric ([parseable by ir-measures](https://ir-measur.es/))"""

    dataset: Param[Adhoc]
    """The dataset to use"""

    retriever: Param[Retriever]
    """The retriever for validation"""

    warmup: Param[int] = field(default=-1, ignore_default=True)
    """Epoch from which the best metric values are tracked (metrics are still computed and logged before that epoch)"""

    bestpath: Annotated[Path, pathgenerator("best")]
    """Path to the best checkpoints"""

    info: Annotated[Path, pathgenerator("info.json")]
    """Path to the JSON file that contains the metric values at each epoch"""

    validation_interval: Param[int] = field(default=1, ignore_default=True)
    """Epochs between each validation"""

    early_stop: Param[int] = field(default=0, ignore_default=True)
    """Number of epochs without improvement after which we stop learning. Should be a multiple of validation_interval or 0 (no early stopping)"""

    hooks: Param[List[ValidationHook]] = field(default=[], ignore_default=True)
    """The list of the hooks during the validation"""

    def __validate__(self):
        assert self.early_stop % self.validation_interval == 0, (
            "Early stop should be a multiple of the validation interval"
        )

    def get_best_metrics(self) -> Dict[str, Dict[str, float]]:
        """Returns the best metrics from the info.json file.

        Can be called after the training task has completed.
        Returns a dict like {"RR@10": {"value": 0.325, "epoch": 42}, ...}
        """
        with self.info.open("rt") as fp:
            return json.load(fp)

    def initialize(self, learner: Learner, context: TrainerContext):
        super().initialize(learner, context)

        self.retriever.initialize()
        self.bestpath.mkdir(exist_ok=True, parents=True)

        # Checkpoint start: reload the best values recorded by a previous run
        self.top: Dict[str, Dict[str, float]] = self._load_top()

    def _load_top(self) -> Dict[str, Dict[str, float]]:
        """Read the best values stored in ``info``, or ``{}`` when unavailable."""
        try:
            with self.info.open("rt") as fp:
                return json.load(fp)
        except FileNotFoundError:
            # Fresh start: nothing recorded yet
            return {}
        except (OSError, ValueError):
            # Unreadable or corrupted file: start over, but leave a trace
            logger.warning(
                "Could not read %s, starting without best values",
                self.info,
                exc_info=True,
            )
            return {}

    def _save_top(self):
        """Write the best values to ``info`` via a temporary file and a rename,
        so that an interrupted write cannot corrupt the resume state."""
        tmp_path = self.info.with_name(self.info.name + ".tmp")
        with tmp_path.open("wt") as fp:
            json.dump(self.top, fp)
        tmp_path.replace(self.info)

    def update_metrics(self, metrics: Dict[str, float]):
        """Add ``{id}/final/{metric}`` entries with the best recorded values.

        Metrics without a recorded best value are skipped.
        """
        for metric in self.metrics:
            if metric in self.top:
                metrics[f"{self.id}/final/{metric}"] = self.top[metric]["value"]

    def monitored(self) -> Iterator[str]:
        """Return the metrics whose best checkpoint is kept"""
        return [key for key, monitored in self.metrics.items() if monitored]

    def init_task(self, learner: "Learner", dep, add_action):
        """Declare one loader (and export action) per kept metric checkpoint.

        :return: a dictionary mapping each kept metric to its loader
        """
        result = {}
        for key, store in self.metrics.items():
            if not store:
                continue

            loader = dep(
                learner.model.loader_config(
                    self.bestpath / key / TrainState.MODEL_DIR,
                    settings=ValidationSettings.C(listener=self, key=key),
                )
            )
            add_action(
                learner.model.export_action(loader, default_name=f"{self.id}/{key}")
            )
            result[key] = loader
        return result

    def should_stop(self, epoch=0):
        """Early-stopping decision.

        :param epoch: the current epoch; a falsy value (default) means
            ``self.context.epoch``
        :return: ``STOP`` when ``early_stop`` > 0 and no monitored metric
            (flagged ``True`` in ``metrics``) improved during the last
            ``early_stop`` epochs, ``DONT_STOP`` otherwise (including when no
            monitored metric has been recorded yet)
        """
        if self.early_stop <= 0:
            return LearnerListenerStatus.DONT_STOP

        best_epochs = [
            info["epoch"]
            for key, info in self.top.items()
            if self.metrics.get(key, False)
        ]
        if not best_epochs:
            # No monitored metric recorded: no criterion to stop on
            return LearnerListenerStatus.DONT_STOP

        epochs_since_imp = (epoch or self.context.epoch) - max(best_epochs)
        if epochs_since_imp >= self.early_stop:
            return LearnerListenerStatus.STOP

        return LearnerListenerStatus.DONT_STOP

    def _log_and_track_ir_metrics(self, means, details, state: TrainState):
        """Log IR metrics to tensorboard and update best checkpoints."""
        for metric, keep in self.metrics.items():
            # Plain float so that ``self.top`` stays JSON-serializable
            value = float(means[metric])

            self.context.writer.add_scalar(
                f"{self.id}/{metric}/mean", value, state.step
            )

            self.context.writer.add_histogram(
                f"{self.id}/{metric}",
                np.array(list(details[metric].values()), dtype=np.float32),
                state.step,
            )

            # Update the top validation (only once the warmup is over)
            if state.epoch >= self.warmup:
                topstate = self.top.get(metric, None)
                if topstate is None or value > topstate["value"]:
                    # Save the new top JSON
                    self.top[metric] = {"value": value, "epoch": self.context.epoch}

                    # Copy in corresponding directory
                    if keep:
                        logger.info(
                            "Saving the checkpoint %s for metric %s",
                            state.epoch,
                            metric,
                        )
                        self.context.copy(self.bestpath / metric)

    def _on_validation(self, means, details, state: TrainState):
        """Called after IR metrics are computed. Override to add custom metrics.

        :param means: Aggregated IR metric values (e.g. {"RR@10": 0.35})
        :param details: Per-query IR metric values
        :param state: Current training state
        """
        pass

    def __call__(self, state: TrainState):
        # Check that we did not stop earlier (when loading from checkpoint / if other
        # listeners have not stopped yet)
        if self.should_stop(state.epoch - 1) == LearnerListenerStatus.STOP:
            return LearnerListenerStatus.STOP

        for hook in self.hooks:
            hook.before(self.context)

        # ``after`` hooks must run even if the evaluation fails
        try:
            if state.epoch % self.validation_interval == 0:
                # Compute validation metrics
                means, details = evaluate(
                    self.retriever, self.dataset, list(self.metrics.keys()), True
                )
                self._log_and_track_ir_metrics(means, details, state)
                self._on_validation(means, details, state)

            # Update information
            self._save_top()
        finally:
            for hook in self.hooks:
                hook.after(self.context)

        # Early stopping?
        return self.should_stop()
class AggregatorValidationListener(LearnerListener):
    """Aggregates multiple validation listeners

    Every ``validation_interval`` epochs, averages for each metric the best
    values held by the aggregated ``listeners`` (their ``top`` attribute),
    tracks the best average, and copies the current checkpoint when a metric
    flagged ``True`` in ``metrics`` improves. Its own stopping decision uses
    ``early_stop`` on these averaged values.

    NOTE(review): the averaged values are each listener's best-so-far values,
    not the values measured at the current epoch, and they are only up to
    date if the listeners are called before the aggregator for the epoch —
    confirm this is the intended aggregation.
    """

    listeners: Param[List[ValidationListener]]
    """The list of validation listeners to aggregate"""

    metrics: Param[Dict[str, bool]] = field(default={"map": True}, ignore_default=True)
    """Dictionary whose keys are the metrics to record, and boolean values whether the best performance checkpoint should be kept for the associated metric ([parseable by ir-measures](https://ir-measur.es/))"""

    warmup: Param[int] = field(default=-1, ignore_default=True)
    """Epoch from which the best averaged values are tracked (they are still logged before that epoch)"""

    bestpath: Annotated[Path, pathgenerator("best")]
    """Path to the best checkpoints"""

    info: Annotated[Path, pathgenerator("info.json")]
    """Path to the JSON file that contains the metric values at each epoch"""

    validation_interval: Param[int] = field(default=1, ignore_default=True)
    """Epochs between each validation"""

    early_stop: Param[int] = field(default=0, ignore_default=True)
    """Number of epochs without improvement after which we stop learning. Should be a multiple of validation_interval or 0 (no early stopping)"""

    hooks: Param[List[ValidationHook]] = field(default=[], ignore_default=True)
    """The list of the hooks during the validation"""

    def __validate__(self):
        assert self.early_stop % self.validation_interval == 0, (
            "Early stop should be a multiple of the validation interval"
        )

        if not self.listeners:
            raise ValueError("At least one validation listener is required")

        # Check that all listeners have the same validation interval
        intervals = {listener.validation_interval for listener in self.listeners}
        if len(intervals) != 1:
            raise ValueError(
                f"Listeners have different validation intervals: {intervals}"
            )
        assert self.validation_interval == intervals.pop(), (
            "The validation interval of the aggregator should be the same as "
            "the listeners"
        )

        # Check that the metrics are the same
        key_sets = [set(listener.metrics.keys()) for listener in self.listeners]
        # Use the first as reference
        ref_keys = key_sets[0]
        for i, ks in enumerate(key_sets[1:], start=1):
            if ks != ref_keys:
                raise ValueError(
                    f"Metric key mismatch between listeners:\n"
                    f"Reference: {ref_keys}\n"
                    f"Listener {i}: {ks}"
                )
        assert self.metrics.keys() == ref_keys, (
            "The metrics of the aggregator should be the same as the listeners"
        )

    def initialize(self, learner: Learner, context: TrainerContext):
        super().initialize(learner, context)

        self.bestpath.mkdir(exist_ok=True, parents=True)

        # Checkpoint start: reload the best values recorded by a previous run
        self.top: Dict[str, Dict[str, float]] = self._load_top()

    def _load_top(self) -> Dict[str, Dict[str, float]]:
        """Read the best values stored in ``info``, or ``{}`` when unavailable."""
        try:
            with self.info.open("rt") as fp:
                return json.load(fp)
        except FileNotFoundError:
            # Fresh start: nothing recorded yet
            return {}
        except (OSError, ValueError):
            # Unreadable or corrupted file: start over, but leave a trace
            logger.warning(
                "Could not read %s, starting without best values",
                self.info,
                exc_info=True,
            )
            return {}

    def _save_top(self):
        """Write the best values to ``info`` via a temporary file and a rename,
        so that an interrupted write cannot corrupt the resume state."""
        tmp_path = self.info.with_name(self.info.name + ".tmp")
        with tmp_path.open("wt") as fp:
            json.dump(self.top, fp)
        tmp_path.replace(self.info)

    def update_metrics(self, metrics: Dict[str, float]):
        """Add ``{id}/final/{metric}`` entries with the best recorded values.

        Metrics without a recorded best value are skipped.
        """
        for metric in self.metrics:
            if metric in self.top:
                metrics[f"{self.id}/final/{metric}"] = self.top[metric]["value"]

    def monitored(self) -> Iterator[str]:
        """Return the metrics whose best checkpoint is kept"""
        return [key for key, monitored in self.metrics.items() if monitored]

    def init_task(self, learner: "Learner", dep, add_action):
        """Declare one loader (and export action) per kept metric checkpoint.

        :return: a dictionary mapping each kept metric to its loader
        """
        result = {}
        for key, store in self.metrics.items():
            if not store:
                continue

            loader = dep(
                learner.model.loader_config(
                    self.bestpath / key / TrainState.MODEL_DIR,
                    settings=ValidationSettings.C(listener=self, key=key),
                )
            )
            add_action(
                learner.model.export_action(loader, default_name=f"{self.id}/{key}")
            )
            result[key] = loader
        return result

    def should_stop(self, epoch=0):
        """Early-stopping decision.

        :param epoch: the current epoch; a falsy value (default) means
            ``self.context.epoch``
        :return: ``STOP`` when ``early_stop`` > 0 and no monitored metric
            (flagged ``True`` in ``metrics``) improved during the last
            ``early_stop`` epochs, ``DONT_STOP`` otherwise (including when no
            monitored metric has been recorded yet)
        """
        if self.early_stop <= 0:
            return LearnerListenerStatus.DONT_STOP

        best_epochs = [
            info["epoch"]
            for key, info in self.top.items()
            if self.metrics.get(key, False)
        ]
        if not best_epochs:
            # No monitored metric recorded: no criterion to stop on
            return LearnerListenerStatus.DONT_STOP

        epochs_since_imp = (epoch or self.context.epoch) - max(best_epochs)
        if epochs_since_imp >= self.early_stop:
            return LearnerListenerStatus.STOP

        return LearnerListenerStatus.DONT_STOP

    def _mean_listener_values(self) -> Dict[str, float]:
        """Average, per metric, the best values recorded by the listeners.

        Only metrics recorded by *every* listener are returned, so that a
        listener still in warmup neither causes a missing key nor biases the
        average toward a subset of listeners.
        """
        values = defaultdict(list)
        for listener in self.listeners:
            for key, info in listener.top.items():
                values[key].append(info["value"])

        n_listeners = len(self.listeners)
        return {
            key: float(mean(key_values))
            for key, key_values in values.items()
            if len(key_values) == n_listeners
        }

    def _track_mean_metrics(self, state: TrainState):
        """Log the averaged metrics and update the best checkpoints."""
        mean_metrics = self._mean_listener_values()
        for metric, keep in self.metrics.items():
            if metric not in mean_metrics:
                # Not all listeners have recorded this metric yet
                continue

            value = mean_metrics[metric]
            self.context.writer.add_scalar(
                f"{self.id}/{metric}/mean", value, state.step
            )

            # Update the top validation (only once the warmup is over)
            if state.epoch >= self.warmup:
                topstate = self.top.get(metric, None)
                if topstate is None or value > topstate["value"]:
                    # Save the new top JSON
                    self.top[metric] = {"value": value, "epoch": self.context.epoch}

                    # Copy in corresponding directory
                    if keep:
                        logger.info(
                            "Saving the checkpoint %s for metric %s",
                            state.epoch,
                            metric,
                        )
                        self.context.copy(self.bestpath / metric)

    def __call__(self, state: TrainState):
        # Check that we did not stop earlier (when loading from checkpoint / if other
        # listeners have not stopped yet)
        if self.should_stop(state.epoch - 1) == LearnerListenerStatus.STOP:
            return LearnerListenerStatus.STOP

        for hook in self.hooks:
            hook.before(self.context)

        # ``after`` hooks must run even if the aggregation fails
        try:
            if state.epoch % self.validation_interval == 0:
                self._track_mean_metrics(state)

            # Update information
            self._save_top()
        finally:
            for hook in self.hooks:
                hook.after(self.context)

        # Early stopping?
        return self.should_stop()