diff --git a/src/xpmir/text/huggingface/base.py b/src/xpmir/text/huggingface/base.py index 9c99cdd..d2f776a 100644 --- a/src/xpmir/text/huggingface/base.py +++ b/src/xpmir/text/huggingface/base.py @@ -1,9 +1,10 @@ -from abc import ABC, abstractmethod -from dataclasses import InitVar import logging import os +from abc import ABC, abstractmethod +from dataclasses import InitVar from pathlib import Path from typing import Type + import torch.nn as nn from experimaestro import Config, Param @@ -30,7 +31,12 @@ class HFModelConfigFromId(HFModelConfig): model_id: Param[str] """HuggingFace Model ID""" - def get_config(self, options: ModuleInitOptions, automodel: Type[AutoModel]): + def get_config( + self, + options: ModuleInitOptions, + autoconfig: Type[AutoModel], + automodel: Type[AutoConfig], + ): model_id_or_path = self.model_id # Use saved models @@ -49,19 +55,29 @@ def get_config(self, options: ModuleInitOptions, automodel: Type[AutoModel]): ) # Load the model configuration - config = AutoConfig.from_pretrained(model_id_or_path) + config = autoconfig.from_pretrained(model_id_or_path) # Return it return config, model_id_or_path - def __call__(self, options: ModuleInitOptions, automodel: Type[AutoModel]): - config, model_id_or_path = self.get_config(options, automodel) + def __call__( + self, + options: ModuleInitOptions, + autoconfig: Type[AutoConfig], + automodel: Type[AutoModel], + ): + config, model_id_or_path = self.get_config(options, autoconfig, automodel) if options.mode == ModuleInitMode.NONE or options.mode == ModuleInitMode.RANDOM: logging.info("Random initialization of HF model") return config, automodel.from_config(config) - logging.info("Loading model from HF (%s)", self.model_id) + logging.info( + "Loading model from HF (%s) with model %s.%s", + self.model_id, + automodel.__module__, + automodel.__name__, + ) return config, automodel.from_pretrained(model_id_or_path, config=config) @@ -81,6 +97,10 @@ class HFModel(Module): def from_pretrained_id(cls, model_id: str): return cls(config=HFModelConfigFromId(model_id=model_id)) + @property + def autoconfig(self): + return AutoConfig + @property def automodel(self): return AutoModel @@ -93,7 +113,9 @@ def __initialize__(self, options: ModuleInitOptions): """ super().__initialize__(options) - self.hf_config, self.model = self.config(options, self.automodel) + self.hf_config, self.model = self.config( + options, self.autoconfig, self.automodel + ) @property def contextual_model(self) -> nn.Module: