Source code for xpmir.text.huggingface.base

import os
from abc import ABC
from contextlib import nullcontext
from dataclasses import InitVar
from pathlib import Path
from typing import Generic, Optional, Type, TypeVar

import torch.nn as nn
from experimaestro import field, Config, Param, LightweightTask

from xpm_torch import Module
from xpm_torch.configuration import FabricConfiguration
from xpmir.text import TokenizedTexts
from functools import lru_cache

import logging

logger = logging.getLogger(__name__)


try:
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForMaskedLM,
        AutoModelForSequenceClassification,
    )
except Exception:
    logging.error("Install huggingface transformers to use these configurations")
    raise


@lru_cache
def is_local_files_only():
    """Tell whether HuggingFace loading must stay offline.

    Reads the ``HF_HUB_OFFLINE`` environment variable and returns True when
    its value, compared case-insensitively, is ``1``, ``true`` or ``on``.
    The answer is memoized by ``lru_cache``: later changes to the environment
    are not seen unless the cache is cleared.
    """
    offline_flag = os.environ.get("HF_HUB_OFFLINE", "")
    normalized = offline_flag.lower()
    return normalized in ["1", "true", "on"]


def _resolve_model_path(model_id: str, automodel: Type[AutoModel]):
    """Resolves the model ID or local path, checking XPMIR_TRANSFORMERS_CACHE"""
    model_id_or_path = model_id

    if model_path := os.environ.get("XPMIR_TRANSFORMERS_CACHE", None):
        path = (
            Path(model_path)
            / Path(f"{automodel.__module__}.{automodel.__qualname__}")
            / Path(model_id)
        )
        if path.is_dir():
            logging.warning("Using saved model from %s", path)
            model_id_or_path = path
        else:
            logging.warning("Could not find saved model in %s, using HF loading", path)

    return model_id_or_path


class HFConfig(Config):
    """Base configuration for HuggingFace models

    Declares no parameter of its own; it is the common ancestor used to
    bound the configuration type accepted by HuggingFace model wrappers.
    """
class HFConfigID(HFConfig):
    """Configuration identified by a HuggingFace model ID

    The single parameter ``hf_id`` names the model whose configuration
    (and possibly weights) should be used.
    """

    hf_id: Param[str]
    """HuggingFace model ID (e.g. ``distilbert-base-uncased``)"""
# Type variable for the configuration type of a generic HF model wrapper;
# it is bounded by HFConfig, so only HFConfig subclasses are accepted.
ConfigT = TypeVar("ConfigT", bound=HFConfig)
class HFModel(Module, Generic[ConfigT]):
    """Base transformer class from Huggingface

    Model structure is created during ``__initialize__`` from the
    :attr:`config` when available. Pretrained weights can be loaded
    via init tasks such as :class:`HFModelInitFromID` or
    :class:`HFFromCheckpoint`.
    """

    config: Param[ConfigT]
    """HuggingFace model configuration"""

    model: InitVar[AutoModel]
    """The HF model"""

    @property
    def autoconfig(self):
        """Class used to load the HuggingFace configuration"""
        return AutoConfig

    @property
    def automodel(self):
        """Class used to build the HuggingFace model (overridden by subclasses)"""
        return AutoModel

    def __initialize__(self):
        """Creates the model structure from config.hf_id (no pretrained weights)

        Nothing is created when the configuration is not an
        :class:`HFConfigID`.
        """
        if not isinstance(self.config, HFConfigID):
            return

        hf_id = self.config.hf_id
        model_id_or_path = _resolve_model_path(hf_id, self.automodel)

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to be executed -- confirm this is intended for
        # every model ID that can reach this point.
        hf_config = self.autoconfig.from_pretrained(
            model_id_or_path,
            trust_remote_code=True,
            local_files_only=is_local_files_only(),
        )
        self.hf_config = hf_config

        logging.info(
            "Creating model structure from config (%s) with %s.%s",
            hf_id,
            self.automodel.__module__,
            self.automodel.__name__,
        )
        # from_config only builds the architecture: weights are not loaded
        self.model = self.automodel.from_config(hf_config)

    @property
    def contextual_model(self) -> nn.Module:
        """Returns the model that only outputs base representations"""
        return self.model

    def forward(self, tokenized: TokenizedTexts):
        """Runs the wrapped HF model on tokenized texts

        :param tokenized: the tokenized texts; they are moved to the device
            of the wrapped model before the call
        :return: whatever the wrapped HF model returns
        """
        tokenized = tokenized.to(self.model.device)

        # Optional inputs are only forwarded when the tokenizer produced them
        kwargs = {}
        if tokenized.token_type_ids is not None:
            kwargs["token_type_ids"] = tokenized.token_type_ids

        # Fix: the optional inputs collected above were never passed on, so
        # token type ids were silently ignored by the model.
        return self.model(
            input_ids=tokenized.ids,
            attention_mask=tokenized.mask,
            **kwargs,
        )
class HFMaskedLanguageModel(HFModel):
    """HuggingFace model built with ``AutoModelForMaskedLM``"""

    model: InitVar[AutoModelForMaskedLM]

    @property
    def automodel(self):
        """Class used to build the HuggingFace model"""
        return AutoModelForMaskedLM

    def decompose(self):
        """Decompose into (backbone, transform, decoder).

        See :func:`~xpmir.text.huggingface.decompose.decompose_mlm_model`
        for details.
        """
        # Imported at call time, as in the rest of this method's history
        from xpmir.text.huggingface.decompose import decompose_mlm_model

        parts = decompose_mlm_model(self.model)
        return parts
class HFSequenceClassification(HFModel):
    """HuggingFace model for sequence classification"""

    model: InitVar[AutoModelForSequenceClassification]

    # Number of labels written into the HF configuration (``num_labels``)
    # before the model structure is created
    n_labels: Param[int] = field(default=1, ignore_default=True)

    # override
    def __initialize__(self):
        """Creates the model structure from config.hf_id (no pretrained weights)

        Checks and modifies the config to match "n_labels". Nothing is
        created when the configuration is not an :class:`HFConfigID`.
        """
        if not isinstance(self.config, HFConfigID):
            return

        hf_id = self.config.hf_id
        model_id_or_path = _resolve_model_path(hf_id, self.automodel)
        hf_config = self.autoconfig.from_pretrained(
            model_id_or_path,
            trust_remote_code=True,
            local_files_only=is_local_files_only(),
        )

        # Make the number of labels of the HF config agree with ``n_labels``
        if hasattr(hf_config, "num_labels"):
            if hf_config.num_labels != self.n_labels:
                logger.debug(
                    "hf config 'n_labels' was %s, setting it to %s",
                    hf_config.num_labels,
                    self.n_labels,
                )
                hf_config.num_labels = self.n_labels
        else:
            # Fix: use the module-level logger like the rest of this method
            # (``self.logger`` is not defined in this module)
            logger.warning(
                "no 'num_labels param found in config, check that classifier outputs one label"
            )

        self.hf_config = hf_config

        logger.info(
            "Creating model structure from config (%s) with %s.%s",
            hf_id,
            self.automodel.__module__,
            self.automodel.__name__,
        )
        # from_config only builds the architecture: weights are not loaded
        self.model = self.automodel.from_config(hf_config)

    @property
    def automodel(self):
        """Class used to build the HuggingFace model"""
        return AutoModelForSequenceClassification
class HFModelInitBase(LightweightTask, ABC):
    """Base class for initializing HF models"""

    model: Param[HFModel[HFConfigID]]

    fabric: Param[Optional[FabricConfiguration]]
    """The fabric configuration to use for initialization.

    When set, model creation runs inside ``fabric.init_module()`` so that
    parameters are allocated directly on the target device and dtype.
    See https://lightning.ai/docs/fabric/stable/advanced/model_init.html
    """

    def __validate__(self):
        """Checks that the model to initialize is configured by an HF ID"""
        model_config = self.model.config
        assert isinstance(model_config, HFConfigID), (
            f"model.config must be an HFConfigID, got {type(model_config)}"
        )

    def _init_context(self, empty_init: bool):
        """Returns a context manager for model initialization.

        Without a fabric configuration this is a no-op context; otherwise it
        is ``fabric.init_module(empty_init)`` of the configured fabric.

        :param empty_init: If True, parameters are created on the meta device
            (no memory allocated). Use True when loading weights from a
            checkpoint (pretrained / saved), False when random init is needed.
        """
        if self.fabric is None:
            return nullcontext()

        fabric = self.fabric.get_fabric()
        return fabric.init_module(empty_init=empty_init)
class HFModelInitFromID(HFModelInitBase):
    """Load pretrained weights from a HuggingFace Hub model ID.

    Uses ``model.config.hf_id`` to resolve the model.
    """

    def execute(self):
        """Loads the HF configuration and pretrained model into ``self.model``

        Sets ``hf_config`` and ``model`` on the wrapped model, then marks it
        as initialized.
        """
        wrapper = self.model
        hf_id = wrapper.config.hf_id
        source = _resolve_model_path(hf_id, wrapper.automodel)

        # Options shared by the configuration and the model loading calls
        load_options = {
            "trust_remote_code": True,
            "local_files_only": is_local_files_only(),
        }

        config = wrapper.autoconfig.from_pretrained(source, **load_options)
        wrapper.hf_config = config

        logging.info(
            "Loading pretrained model from HF (%s) with %s.%s",
            hf_id,
            wrapper.automodel.__module__,
            wrapper.automodel.__name__,
        )

        with self._init_context(empty_init=True):
            wrapper.model = wrapper.automodel.from_pretrained(
                source, config=config, **load_options
            )

        wrapper._initialized = True
class HFFromCheckpoint(HFModelInitBase):
    """Load from a local checkpoint.

    Uses ``model.config.hf_id`` for the architecture config, then loads
    weights from ``checkpoint``.
    """

    checkpoint: Param[Path]
    """The checkpoint path to load weights from"""

    def execute(self):
        """Loads the HF configuration, then the model from ``checkpoint``

        Sets ``hf_config`` and ``model`` on the wrapped model, then marks it
        as initialized.
        """
        wrapper = self.model
        hf_id = wrapper.config.hf_id
        config_source = _resolve_model_path(hf_id, wrapper.automodel)

        # Options shared by the configuration and the model loading calls
        load_options = {
            "trust_remote_code": True,
            "local_files_only": is_local_files_only(),
        }

        # The architecture configuration comes from the HF ID ...
        config = wrapper.autoconfig.from_pretrained(config_source, **load_options)
        wrapper.hf_config = config

        logging.info(
            "Loading model from checkpoint %s (config from %s)",
            self.checkpoint,
            hf_id,
        )

        # ... while the model itself is loaded from the checkpoint path
        with self._init_context(empty_init=True):
            wrapper.model = wrapper.automodel.from_pretrained(
                self.checkpoint, config=config, **load_options
            )

        wrapper._initialized = True