Skip to content

Commit

Permalink
Added custom autoconfig support
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Jul 25, 2024
1 parent 7ee2819 commit 63b5c16
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions src/xpmir/text/huggingface/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import InitVar
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import InitVar
from pathlib import Path
from typing import Type

import torch.nn as nn
from experimaestro import Config, Param

Expand All @@ -30,7 +31,12 @@ class HFModelConfigFromId(HFModelConfig):
model_id: Param[str]
"""HuggingFace Model ID"""

def get_config(self, options: ModuleInitOptions, automodel: Type[AutoModel]):
def get_config(
self,
options: ModuleInitOptions,
autoconfig: Type[AutoModel],
automodel: Type[AutoConfig],
):
model_id_or_path = self.model_id

# Use saved models
Expand All @@ -49,19 +55,29 @@ def get_config(self, options: ModuleInitOptions, automodel: Type[AutoModel]):
)

# Load the model configuration
config = AutoConfig.from_pretrained(model_id_or_path)
config = autoconfig.from_pretrained(model_id_or_path)

# Return it
return config, model_id_or_path

def __call__(self, options: ModuleInitOptions, automodel: Type[AutoModel]):
config, model_id_or_path = self.get_config(options, automodel)
def __call__(
self,
options: ModuleInitOptions,
autoconfig: Type[AutoConfig],
automodel: Type[AutoModel],
):
config, model_id_or_path = self.get_config(options, autoconfig, automodel)

if options.mode == ModuleInitMode.NONE or options.mode == ModuleInitMode.RANDOM:
logging.info("Random initialization of HF model")
return config, automodel.from_config(config)

logging.info("Loading model from HF (%s)", self.model_id)
logging.info(
"Loading model from HF (%s) with model %s.%s",
self.model_id,
automodel.__module__,
automodel.__name__,
)
return config, automodel.from_pretrained(model_id_or_path, config=config)


Expand All @@ -81,6 +97,10 @@ class HFModel(Module):
def from_pretrained_id(cls, model_id: str):
return cls(config=HFModelConfigFromId(model_id=model_id))

@property
def autoconfig(self):
return AutoConfig

@property
def automodel(self):
return AutoModel
Expand All @@ -93,7 +113,9 @@ def __initialize__(self, options: ModuleInitOptions):
"""
super().__initialize__(options)

self.hf_config, self.model = self.config(options, self.automodel)
self.hf_config, self.model = self.config(
options, self.autoconfig, self.automodel
)

@property
def contextual_model(self) -> nn.Module:
Expand Down

0 comments on commit 63b5c16

Please sign in to comment.