From 3252e27ac86ceac9485146a324686a39322d4bd2 Mon Sep 17 00:00:00 2001
From: Lucain
Date: Thu, 21 Mar 2024 15:39:00 +0100
Subject: [PATCH] Fix ModelHubMixin when kwargs and config are both passed (#2138)

---
 src/huggingface_hub/hub_mixin.py |  2 +-
 tests/test_hub_mixin_pytorch.py  | 17 ++++++++++++++++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index 1b90fd01cf..86dce2dd54 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -398,7 +398,7 @@ def from_pretrained(
             # Forward config to model initialization
             model_kwargs["config"] = config
 
-        elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
+        if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
             for key, value in config.items():
                 if key not in model_kwargs:
                     model_kwargs[key] = value
diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py
index 2c7601f47f..48511b08d7 100644
--- a/tests/test_hub_mixin_pytorch.py
+++ b/tests/test_hub_mixin_pytorch.py
@@ -3,7 +3,7 @@
 import struct
 import unittest
 from pathlib import Path
-from typing import Any, TypeVar
+from typing import Any, Dict, Optional, TypeVar
 from unittest.mock import Mock, patch
 
 import pytest
@@ -53,10 +53,16 @@ def __init__(
             self.num_classes = num_classes
             self.state = state
             self.not_jsonable = not_jsonable
+
+    class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin):
+        def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs):
+            super().__init__()
+
 else:
     DummyModel = None
     DummyModelWithTags = None
     DummyModelNoConfig = None
+    DummyModelWithConfigAndKwargs = None
 
 
 @requires("torch")
@@ -346,3 +352,12 @@ def forward(self, x):
         b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
         assert a_weight_ptr == b_weight_ptr
         assert a_bias_ptr == b_bias_ptr
+
+    def test_save_pretrained_when_config_and_kwargs_are_passed(self):
+        # Test creating model with config and kwargs => all values are saved together in config.json
+        model = DummyModelWithConfigAndKwargs(num_classes=50, state="layernorm", config={"a": 1}, b=2, c=3)
+        model.save_pretrained(self.cache_dir)
+        assert model._hub_mixin_config == {"num_classes": 50, "state": "layernorm", "a": 1, "b": 2, "c": 3}
+
+        reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir)
+        assert reloaded._hub_mixin_config == model._hub_mixin_config
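
Note: the following is a minimal usage sketch of the behavior this patch fixes, mirroring the test added above. The MyModel class name and the "my-model" save directory are illustrative placeholders; the rest follows the DummyModelWithConfigAndKwargs test in this patch.

# Sketch (assumes huggingface_hub with this fix and torch installed).
from typing import Dict, Optional

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class MyModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_classes: int = 42, state: str = "layernorm", config: Optional[Dict] = None, **kwargs):
        super().__init__()


# With the fix, values passed explicitly, via `config`, and via extra kwargs
# are merged into a single config that is saved to config.json.
model = MyModel(num_classes=50, state="layernorm", config={"a": 1}, b=2, c=3)
model.save_pretrained("my-model")
assert model._hub_mixin_config == {"num_classes": 50, "state": "layernorm", "a": 1, "b": 2, "c": 3}

# Reloading restores the same merged config.
reloaded = MyModel.from_pretrained("my-model")
assert reloaded._hub_mixin_config == model._hub_mixin_config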