Skip to content

Commit

Permalink
fix(transformers): fix some bugs in 'modeling_utils'
Browse files Browse the repository at this point in the history
  • Loading branch information
Cui-yshoho committed Oct 17, 2024
1 parent 324e49c commit 4906abd
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions mindone/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,18 @@ def _get_pt2ms_mappings(m):
return mappings


def _get_pt2ms_mapped_kv(mappings, key_pt, value_pt=None, prefix=""):
if key_pt.startswith(prefix):
key_ms, value_mapping = mappings.get(key_pt[len(prefix) :], (key_pt[len(prefix) :], lambda x: x))
key_ms = prefix + key_ms
def _get_pt2ms_mapped_kv(
mappings, key_pt, value_pt=None, prefix="", remove_prefix_from_model=False, add_prefix_to_model=False
):
if remove_prefix_from_model:
key_ms, value_mapping = mappings.get(prefix + key_pt, prefix + key_pt, lambda x: x)
key_ms = key_ms[len(prefix) :]
elif add_prefix_to_model:
if key_pt.startswith(prefix):
key_ms, value_mapping = mappings.get(key_pt[len(prefix) :], (key_pt[len(prefix) :], lambda x: x))
key_ms = prefix + key_ms
else:
key_ms, value_mapping = mappings.get(key_pt, (key_pt, lambda x: x))
else:
key_ms, value_mapping = mappings.get(key_pt, (key_pt, lambda x: x))

Expand All @@ -98,14 +106,16 @@ def _get_pt2ms_mapped_kv(mappings, key_pt, value_pt=None, prefix=""):
return key_ms, value_mapping(value_pt)


def _convert_state_dict(m, state_dict_pt, prefix=""):
def _convert_state_dict(m, state_dict_pt, prefix="", remove_prefix_from_model=False, add_prefix_to_model=False):
if not state_dict_pt:
return state_dict_pt
pt2ms_mappings = _get_pt2ms_mappings(m)
state_dict_ms = {}
while state_dict_pt:
name_pt, data_pt = state_dict_pt.popitem()
name_ms, data_ms = _get_pt2ms_mapped_kv(pt2ms_mappings, name_pt, data_pt, prefix)
name_ms, data_ms = _get_pt2ms_mapped_kv(
pt2ms_mappings, name_pt, data_pt, prefix, remove_prefix_from_model, add_prefix_to_model
)
if name_ms is not None:
state_dict_ms[name_ms] = data_ms
return state_dict_ms
Expand Down Expand Up @@ -1770,11 +1780,6 @@ def _load_pretrained_model(
dtype=None,
keep_in_fp32_modules=None,
):
# Mapping loaded_keys from pt to ms
pt2ms_mappings = _get_pt2ms_mappings(model)
loaded_keys = [
_get_pt2ms_mapped_kv(pt2ms_mappings, s, None, f"{model.base_model_prefix}.")[0] for s in loaded_keys
]
# Retrieve missing & unexpected_keys
model_state_dict = {k: v for k, v in model.parameters_and_names()}
expected_keys = list(model_state_dict.keys())
Expand All @@ -1800,6 +1805,16 @@ def _load_pretrained_model(
elif add_prefix_to_model:
expected_keys = [".".join([prefix, s]) for s in expected_keys]

# Mapping loaded_keys from pt to ms
pt2ms_mappings = _get_pt2ms_mappings(model)
loaded_keys = [
_get_pt2ms_mapped_kv(pt2ms_mappings, s, None, f"{prefix}.", remove_prefix_from_model, add_prefix_to_model)[
0
]
for s in loaded_keys
]
# TODO: Do we need to move 'original_loaded_keys = loaded_keys' here?

missing_keys = sorted(set(expected_keys) - set(loaded_keys))
unexpected_keys = set(loaded_keys) - set(expected_keys)

Expand Down

0 comments on commit 4906abd

Please sign in to comment.