Skip to content

Commit

Permalink
feat(transformers): support adapter loading for from_pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
townwish4git committed Oct 15, 2024
1 parent 4082c8f commit ddf1348
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion mindone/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
cached_file,
download_url,
extract_commit_hash,
find_adapter_config_file,
has_file,
is_offline_mode,
is_remote_url,
Expand Down Expand Up @@ -1311,6 +1312,8 @@ def from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")

if use_auth_token is not None:
warnings.warn(
Expand All @@ -1323,6 +1326,9 @@ def from_pretrained(
)
token = use_auth_token

if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
adapter_kwargs["token"] = token

if use_safetensors is None and not is_safetensors_available():
use_safetensors = False

Expand All @@ -1348,6 +1354,25 @@ def from_pretrained(
else:
commit_hash = getattr(config, "_commit_hash", None)

# Always True: if is_peft_available():
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)

if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]

from_pt = not (from_tf | from_flax)

user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
Expand Down Expand Up @@ -1643,7 +1668,7 @@ def from_pretrained(
with safe_open(resolved_archive_file, framework="np") as f:
metadata = f.metadata()

if metadata.get("format") == "pt":
if metadata.get("format") in ("np", "pt"):
pass
elif metadata.get("format") == "tf":
from_tf = True
Expand Down Expand Up @@ -1742,6 +1767,14 @@ def from_pretrained(
keep_in_fp32_modules=keep_in_fp32_modules,
)

if _adapter_model_path is not None:
model.load_adapter(
_adapter_model_path,
adapter_name=adapter_name,
token=token,
adapter_kwargs=adapter_kwargs,
)

# Set model in evaluation mode to deactivate DropOut modules by default
model.set_train(False)

Expand Down

0 comments on commit ddf1348

Please sign in to comment.