Skip to content

Commit

Permalink
Fix ModelHubMixin when kwargs and config are both passed (#2138)
Browse files · Browse the repository at this point in the history
  • Loading branch information
Wauplin authored Mar 21, 2024
1 parent 5ad7855 commit 3252e27
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def from_pretrained(
# Forward config to model initialization
model_kwargs["config"] = config

elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in cls._hub_mixin_init_parameters.values()):
for key, value in config.items():
if key not in model_kwargs:
model_kwargs[key] = value
Expand Down
17 changes: 16 additions & 1 deletion tests/test_hub_mixin_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import struct
import unittest
from pathlib import Path
from typing import Any, TypeVar
from typing import Any, Dict, Optional, TypeVar
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -53,10 +53,16 @@ def __init__(
self.num_classes = num_classes
self.state = state
self.not_jsonable = not_jsonable

class DummyModelWithConfigAndKwargs(nn.Module, PyTorchModelHubMixin):
    """Dummy model accepting an explicit `config` kwarg alongside arbitrary `**kwargs`.

    Used to exercise the mixin path where both an explicit `config` parameter and
    extra keyword arguments are passed at init time (regression test for #2138).
    All init parameters are intentionally ignored: only the mixin's bookkeeping of
    the arguments (into `_hub_mixin_config`) matters for the tests.
    """

    def __init__(
        self,
        num_classes: int = 42,
        state: str = "layernorm",
        config: Optional[Dict] = None,
        **kwargs,
    ):
        # No layers/attributes needed — the test only inspects the saved config.
        super().__init__()

else:
DummyModel = None
DummyModelWithTags = None
DummyModelNoConfig = None
DummyModelWithConfigAndKwargs = None


@requires("torch")
Expand Down Expand Up @@ -346,3 +352,12 @@ def forward(self, x):
b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
assert a_weight_ptr == b_weight_ptr
assert a_bias_ptr == b_bias_ptr

def test_save_pretrained_when_config_and_kwargs_are_passed(self):
    """Explicit init kwargs, the `config` dict, and extra `**kwargs` must all be
    merged into a single config.json and survive a save/load round-trip."""
    expected_config = {"num_classes": 50, "state": "layernorm", "a": 1, "b": 2, "c": 3}

    # `a` comes from the explicit config dict; `b`/`c` are caught by **kwargs.
    model = DummyModelWithConfigAndKwargs(num_classes=50, state="layernorm", config={"a": 1}, b=2, c=3)
    model.save_pretrained(self.cache_dir)
    assert model._hub_mixin_config == expected_config

    # Reloading from disk must reconstruct the exact same merged config.
    reloaded = DummyModelWithConfigAndKwargs.from_pretrained(self.cache_dir)
    assert reloaded._hub_mixin_config == expected_config

0 comments on commit 3252e27

Please sign in to comment.