import json
from pathlib import Path
from typing import Dict, List
from datamaestro_text.data.ir import Adhoc
from experimaestro import Param, pathgenerator, Annotated
import numpy as np
from xpmir.utils.utils import easylog, foreach
from xpmir.evaluation import evaluate
from xpmir.learning.context import TrainState, TrainerContext, ValidationHook
from xpmir.rankers import Retriever
from xpmir.learning.learner import LearnerListener, Learner, LearnerListenerStatus
from xpmir.learning.optim import ModuleLoader

logger = easylog()


class ValidationModuleLoader(ModuleLoader):
"""Specializes the validation module loader"""
listener: Param["ValidationListener"]
"""The listener"""
key: Param[str]
"""The key for this listener"""


class ValidationListener(LearnerListener):
    """Validation-based early stopping

    Computes validation metrics and keeps track of the best results. If
    early_stop is set (> 0), it signals to the learner that the learning
    process can stop.
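
    Example (an illustrative sketch: the dataset, the retriever and the
    parameter values come from the surrounding experiment):

        validation = ValidationListener(
            id="validation",
            dataset=devset,
            retriever=retriever,
            metrics={"map": True, "nDCG@10": False},
            validation_interval=2,
            early_stop=8,
        )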
"""

    metrics: Param[Dict[str, bool]] = {"map": True}
    """Dictionary whose keys are the metrics to record ([parseable by
    ir-measures](https://ir-measur.es/)), and whose boolean values indicate
    whether the checkpoint performing best on that metric should be kept"""

    dataset: Param[Adhoc]
    """The dataset to use"""

    retriever: Param[Retriever]
    """The retriever for validation"""

    warmup: Param[int] = -1
    """Number of warmup epochs during which metrics are computed and logged,
    but the best results are not tracked"""

    bestpath: Annotated[Path, pathgenerator("best")]
    """Path to the best checkpoints"""

    info: Annotated[Path, pathgenerator("info.json")]
    """Path to the JSON file that contains the metric values at each epoch"""

    validation_interval: Param[int] = 1
    """Number of epochs between each validation"""

    early_stop: Param[int] = 0
    """Number of epochs without improvement after which we stop learning.
    Should be a multiple of validation_interval, or 0 (no early stopping)"""

    hooks: Param[List[ValidationHook]] = []
    """Hooks called before and after each validation"""

    def __validate__(self):
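        # Every early-stop check must coincide with a validation step:
        # e.g. validation_interval=2 allows early_stop=6, but not early_stop=5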
assert (
self.early_stop % self.validation_interval == 0
), "Early stop should be a multiple of the validation interval"

    def initialize(self, learner: Learner, context: TrainerContext):
        super().initialize(learner, context)
        self.retriever.initialize()
        self.bestpath.mkdir(exist_ok=True, parents=True)
        # Restore the best validation results when resuming from a checkpoint;
        # self.top maps each metric name to {"value": ..., "epoch": ...}
        try:
            with self.info.open("rt") as fp:
                self.top = json.load(fp)  # type: Dict[str, Dict[str, float]]
        except Exception:
            self.top = {}

    def update_metrics(self, metrics: Dict[str, float]):
        """Report the best value observed so far for each metric"""
        if self.top:
            # Report under a separate "final" key so these values do not
            # clash with the per-epoch ones
            for metric in self.metrics.keys():
                metrics[f"{self.id}/final/{metric}"] = self.top[metric]["value"]

    def monitored(self) -> List[str]:
        """Lists the metrics for which the best checkpoint is kept"""
        return [key for key, monitored in self.metrics.items() if monitored]

    def init_task(self, learner: "Learner", dep):
        """Declare a model loader for each metric whose best checkpoint is
        kept, restoring the weights saved under best/<metric>"""
return {
key: dep(
ValidationModuleLoader(
value=learner.model,
listener=self,
key=key,
path=self.bestpath / key / TrainState.MODEL_PATH,
)
)
for key, store in self.metrics.items()
if store
}

    def should_stop(self, epoch=0):
        """Return STOP when no kept metric has improved for early_stop
        epochs, DONT_STOP otherwise"""
if self.early_stop > 0 and self.top:
epochs_since_imp = (epoch or self.context.epoch) - max(
info["epoch"] for key, info in self.top.items() if self.metrics[key]
)
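            # e.g. with early_stop=4 and the last improvement at epoch 10,
            # we stop once the current epoch reaches 14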
if epochs_since_imp >= self.early_stop:
return LearnerListenerStatus.STOP
return LearnerListenerStatus.DONT_STOP

    def __call__(self, state: TrainState):
        # Check whether we should already have stopped (e.g. when resuming
        # from a checkpoint, or when other listeners kept the learner going)
        if self.should_stop(state.epoch - 1) == LearnerListenerStatus.STOP:
            return LearnerListenerStatus.STOP
foreach(
self.hooks,
lambda hook: hook.before(self.context),
)
if state.epoch % self.validation_interval == 0:
# Compute validation metrics
means, details = evaluate(
self.retriever, self.dataset, list(self.metrics.keys()), True
)
for metric, keep in self.metrics.items():
value = means[metric]
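                # Log the mean and the per-query distribution of the metric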
self.context.writer.add_scalar(
f"{self.id}/{metric}/mean", value, state.step
)
self.context.writer.add_histogram(
f"{self.id}/{metric}",
np.array(list(details[metric].values()), dtype=np.float32),
state.step,
)
                # Track the best results (only after the warmup epochs)
if state.epoch >= self.warmup:
topstate = self.top.get(metric, None)
if topstate is None or value > topstate["value"]:
                        # Record the new best value and its epoch
self.top[metric] = {"value": value, "epoch": self.context.epoch}
                        # Keep the best checkpoint for this metric
                        if keep:
                            logger.info(
                                f"Saving the checkpoint {state.epoch}"
                                f" for metric {metric}"
                            )
                            self.context.copy(self.bestpath / metric)
            # Persist the best results to the JSON file
with self.info.open("wt") as fp:
json.dump(self.top, fp)
foreach(
self.hooks,
lambda hook: hook.after(self.context),
)
# Early stopping?
return self.should_stop()
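
# On disk, the listener produces the following layout (derived from the code
# above; the values shown are illustrative):
#
#   info.json        best value and epoch per metric, e.g.
#                    {"map": {"value": 0.31, "epoch": 12}}
#   best/<metric>/   copy of the best checkpoint for each kept metric
#
# A minimal way to inspect the recorded values (standalone snippet, not part
# of this module):
#
#   import json
#   from pathlib import Path
#
#   top = json.loads(Path("info.json").read_text())
#   for metric, info in top.items():
#       print(f"{metric}: {info['value']:.4f} at epoch {info['epoch']}")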