import logging
import json
from pathlib import Path
from typing import Dict, Iterator, List

# NOTE(review): the module path of this import was missing from this copy of
# the file (`from import Adhoc`); restored to the datamaestro-text IR data
# module -- confirm against the upstream source.
from datamaestro_text.data.ir import Adhoc
from experimaestro import Param, pathgenerator, Annotated
import numpy as np
from xpmir.utils.utils import easylog, foreach
from xpmir.evaluation import evaluate
from xpmir.learning.context import TrainState, TrainerContext, ValidationHook

# NOTE(review): the parenthesised name list of this import was lost; `Retriever`
# is the only otherwise-unimported name used below -- confirm nothing else was
# listed here.
from xpmir.rankers import (
    Retriever,
)
from xpmir.learning.learner import LearnerListener, Learner, LearnerListenerStatus
from xpmir.learning.optim import ModuleLoader

logger = easylog()

# NOTE(review): this class was reconstructed from a copy in which several
# dotted expressions had been stripped (leaving e.g. `with"rt") as fp:` and
# `json.dump(, fp)`). The restored names -- `self.top` (best value/epoch per
# metric), `self.id` (prefix of the reported metric keys, presumably provided
# by `LearnerListener`), the `logging.info(...)` call and the exact nesting of
# the final statements of `__call__` -- should be confirmed against upstream.
class ValidationListener(LearnerListener):
    """Learning validation early-stopping

    Computes a validation metric and stores the best result. If early_stop is
    set (> 0), then it signals to the learner that the learning process can
    stop.
    """

    metrics: Param[Dict[str, bool]] = {"map": True}
    """Dictionary whose keys are the metrics to record, and boolean
    values whether the best performance checkpoint should be kept for
    the associated metric ([parseable by ir-measures](https://ir-measur.es/))"""

    dataset: Param[Adhoc]
    """The dataset to use"""

    retriever: Param[Retriever]
    """The retriever for validation"""

    warmup: Param[int] = -1
    """How many epochs before actually computing the metric"""

    bestpath: Annotated[Path, pathgenerator("best")]
    """Path to the best checkpoints"""

    info: Annotated[Path, pathgenerator("info.json")]
    """Path to the JSON file that contains the metric values at each epoch"""

    validation_interval: Param[int] = 1
    """Epochs between each validation"""

    early_stop: Param[int] = 0
    """Number of epochs without improvement after which we stop learning.
    Should be a multiple of validation_interval or 0 (no early stopping)"""

    hooks: Param[List[ValidationHook]] = []
    """The list of the hooks during the validation"""

    def __validate__(self):
        """Check that `early_stop` is a multiple of `validation_interval`."""
        assert (
            self.early_stop % self.validation_interval == 0
        ), "Early stop should be a multiple of the validation interval"

    def initialize(self, learner: Learner, context: TrainerContext):
        """Initialize the retriever, create the checkpoint directory and
        reload the best results recorded so far (if any)."""
        super().initialize(learner, context)

        self.retriever.initialize()
        self.bestpath.mkdir(exist_ok=True, parents=True)

        # Checkpoint start
        try:
            with self.info.open("rt") as fp:
                self.top = json.load(fp)  # type: Dict[str, Dict[str, float]]
        except Exception:
            # Best effort: a missing or unreadable info file means that we
            # start without any recorded best value
            self.top = {}

    def update_metrics(self, metrics: Dict[str, float]):
        """Add the best recorded value of each metric to `metrics` (in place),
        under the key `<id>/final/<metric>`; does nothing when no best value
        has been recorded yet."""
        if self.top:
            # Just use another key
            for metric in self.metrics.keys():
                metrics[f"{self.id}/final/{metric}"] = self.top[metric]["value"]

    def monitored(self) -> Iterator[str]:
        """Return the metric names whose flag is true in `metrics`."""
        return [key for key, monitored in self.metrics.items() if monitored]

    def task_outputs(self, learner: "Learner", dep):
        """Experimaestro outputs: returns the best checkpoints for each
        metric"""
        res = {
            key: ModuleLoader.construct(
                learner.model, self.bestpath / key / TrainState.MODEL_PATH, dep
            )
            for key, store in self.metrics.items()
            if store
        }

        return res

    def init_task(self, learner: "Learner", dep):
        """Return, for each metric whose flag is true, `dep` applied to a
        `ModuleLoader` pointing at the model file stored under
        `bestpath / <metric>`."""
        return {
            key: dep(
                ModuleLoader(
                    value=learner.model,
                    path=self.bestpath / key / TrainState.MODEL_PATH,
                )
            )
            for key, store in self.metrics.items()
            if store
        }

    def should_stop(self, epoch=0):
        """Return `LearnerListenerStatus.STOP` when early stopping is enabled
        and no monitored metric improved during the last `early_stop` epochs,
        and `LearnerListenerStatus.DONT_STOP` otherwise.

        :param epoch: the epoch to compare with the epoch of the last
            improvement; when 0 (the default), the current epoch of the
            context is used instead
        """
        if self.early_stop > 0 and self.top:
            # Only the metrics whose flag is true are taken into account
            epochs_since_imp = (epoch or self.context.epoch) - max(
                info["epoch"] for key, info in self.top.items() if self.metrics[key]
            )
            if epochs_since_imp >= self.early_stop:
                return LearnerListenerStatus.STOP

        return LearnerListenerStatus.DONT_STOP

    def __call__(self, state: TrainState):
        """Run the validation for the given training state and return the
        early-stopping status.

        The metrics are only computed when the epoch is a multiple of
        `validation_interval`; hooks are called before and after that step.
        """
        # Check that we did not stop earlier (when loading from checkpoint / if other
        # listeners have not stopped yet)
        if self.should_stop(state.epoch - 1) == LearnerListenerStatus.STOP:
            return LearnerListenerStatus.STOP

        foreach(
            self.hooks,
            lambda hook: hook.before(self.context),
        )

        if state.epoch % self.validation_interval == 0:
            # Compute validation metrics
            means, details = evaluate(
                self.retriever, self.dataset, list(self.metrics.keys()), True
            )

            for metric, keep in self.metrics.items():
                value = means[metric]

                self.context.writer.add_scalar(
                    f"{self.id}/{metric}/mean", value, state.step
                )

                self.context.writer.add_histogram(
                    f"{self.id}/{metric}",
                    np.array(list(details[metric].values()), dtype=np.float32),
                    state.step,
                )

                # Update the top validation
                if state.epoch >= self.warmup:
                    topstate = self.top.get(metric, None)
                    if topstate is None or value > topstate["value"]:
                        # Save the new top JSON
                        self.top[metric] = {
                            "value": value,
                            "epoch": self.context.epoch,
                        }

                        # Copy in corresponding directory
                        if keep:
                            logging.info(
                                f"Saving the checkpoint {state.epoch}"
                                f" for metric {metric}"
                            )
                            self.context.copy(self.bestpath / metric)

            # Update information
            with self.info.open("wt") as fp:
                json.dump(self.top, fp)

        foreach(
            self.hooks,
            lambda hook: hook.after(self.context),
        )

        # Early stopping?
        return self.should_stop()