-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
132 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,142 @@ | ||
""" | ||
Modified from | ||
https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py | ||
""" | ||
import argparse | ||
import json | ||
import os | ||
from collections import defaultdict | ||
from typing import Dict, List | ||
|
||
import torch | ||
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file | ||
from tqdm import tqdm | ||
|
||
import mindspore as ms | ||
|
||
|
||
def _load_torch_ckpt(ckpt_file): | ||
source_data = torch.load(ckpt_file, map_location="cpu") | ||
if "state_dict" in source_data: | ||
source_data = source_data["state_dict"] | ||
return source_data | ||
|
||
|
||
def _load_huggingface_safetensor(ckpt_file): | ||
from safetensors import safe_open | ||
|
||
db_state_dict = {} | ||
with safe_open(ckpt_file, framework="pt", device="cpu") as f: | ||
for key in f.keys(): | ||
db_state_dict[key] = f.get_tensor(key) | ||
return db_state_dict | ||
|
||
|
||
LOAD_PYTORCH_FUNCS = {"others": _load_torch_ckpt, "safetensors": _load_huggingface_safetensor} | ||
|
||
|
||
def load_torch_ckpt(ckpt_path_list): | ||
total_params = {} | ||
for ckpt_path in ckpt_path_list: | ||
extension = ckpt_path.split(".")[-1] | ||
if extension not in LOAD_PYTORCH_FUNCS.keys(): | ||
extension = "others" | ||
torch_params = LOAD_PYTORCH_FUNCS[extension](ckpt_path) | ||
total_params.update(torch_params) | ||
return total_params | ||
|
||
|
||
def convert_pt_name_to_ms(content: str) -> str: | ||
# embedding table name conversion | ||
content = content.replace("shared.weight", "shared.embedding_table") | ||
content = content.replace("encoder.embed_tokens.weight", "encoder.embed_tokens.embedding_table") | ||
content = content.replace( | ||
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", | ||
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.embedding_table", | ||
) | ||
return content | ||
|
||
|
||
def torch_to_ms_weight(source_fp, target_fp): | ||
source_data = load_torch_ckpt(source_fp) | ||
if "ema" in source_data: | ||
source_data = source_data["ema"] | ||
if "state_dict" in source_data: | ||
source_data = source_data["state_dict"] | ||
target_data = [] | ||
for _name_pt in tqdm(source_data, total=len(source_data)): | ||
_name_ms = convert_pt_name_to_ms(_name_pt) | ||
_source_data = source_data[_name_pt].cpu().detach().numpy() | ||
target_data.append({"name": _name_ms, "data": ms.Tensor(_source_data)}) | ||
ms.save_checkpoint(target_data, target_fp) | ||
def _remove_duplicate_names( | ||
state_dict: Dict[str, torch.Tensor], | ||
*, | ||
preferred_names: List[str] = None, | ||
discard_names: List[str] = None, | ||
) -> Dict[str, List[str]]: | ||
if preferred_names is None: | ||
preferred_names = [] | ||
preferred_names = set(preferred_names) | ||
if discard_names is None: | ||
discard_names = [] | ||
discard_names = set(discard_names) | ||
|
||
shareds = _find_shared_tensors(state_dict) | ||
to_remove = defaultdict(list) | ||
for shared in shareds: | ||
complete_names = set([name for name in shared if _is_complete(state_dict[name])]) | ||
if not complete_names: | ||
if len(shared) == 1: | ||
# Force contiguous | ||
name = list(shared)[0] | ||
state_dict[name] = state_dict[name].clone() | ||
complete_names = {name} | ||
else: | ||
raise RuntimeError( | ||
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep" | ||
f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model" | ||
f" since you could be storing much more memory than needed. Please refer to" | ||
f" https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." | ||
) | ||
|
||
keep_name = sorted(list(complete_names))[0] | ||
|
||
# Mechanism to preferentially select keys to keep | ||
# coming from the on-disk file to allow | ||
# loading models saved with a different choice | ||
# of keep_name | ||
preferred = complete_names.difference(discard_names) | ||
if preferred: | ||
keep_name = sorted(list(preferred))[0] | ||
|
||
if preferred_names: | ||
preferred = preferred_names.intersection(complete_names) | ||
if preferred: | ||
keep_name = sorted(list(preferred))[0] | ||
for name in sorted(shared): | ||
if name != keep_name: | ||
to_remove[keep_name].append(name) | ||
return to_remove | ||
|
||
|
||
def check_file_size(sf_filename: str, pt_filename: str): | ||
sf_size = os.stat(sf_filename).st_size | ||
pt_size = os.stat(pt_filename).st_size | ||
|
||
if (sf_size - pt_size) / pt_size > 0.01: | ||
raise RuntimeError( | ||
f"""The file size different is more than 1%: | ||
- {sf_filename}: {sf_size} | ||
- {pt_filename}: {pt_size} | ||
""" | ||
) | ||
|
||
|
||
def rename(pt_filename: str) -> str: | ||
filename, ext = os.path.splitext(pt_filename) | ||
local = f"{filename}.safetensors" | ||
local = local.replace("pytorch_model", "model") | ||
return local | ||
|
||
|
||
def convert_multi(filename: str, *, folder: str): | ||
with open(filename, "r") as f: | ||
data = json.load(f) | ||
|
||
filenames = set(data["weight_map"].values()) | ||
for filename in tqdm(filenames): | ||
sf_filename = rename(filename) | ||
sf_filename = os.path.join(folder, sf_filename) | ||
convert_file(os.path.join(folder, filename), sf_filename) | ||
|
||
index = os.path.join(folder, "model.safetensors.index.json") | ||
with open(index, "w") as f: | ||
newdata = {k: v for k, v in data.items()} | ||
newmap = {k: rename(v) for k, v in data["weight_map"].items()} | ||
newdata["weight_map"] = newmap | ||
json.dump(newdata, f, indent=4) | ||
|
||
print("Completed.") | ||
|
||
|
||
def convert_file(pt_filename: str, sf_filename: str): | ||
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) | ||
if "state_dict" in loaded: | ||
loaded = loaded["state_dict"] | ||
to_removes = _remove_duplicate_names(loaded) | ||
|
||
metadata = {"format": "pt"} | ||
|
||
# a problem with safetensors and transformers | ||
# https://github.com/huggingface/safetensors/issues/202 | ||
for to_remove in to_removes.values(): | ||
for tr in to_remove: | ||
loaded[tr] = loaded[tr].clone() | ||
|
||
# Force tensors to be contiguous | ||
loaded = {k: v.contiguous() for k, v in loaded.items()} | ||
|
||
dirname = os.path.dirname(sf_filename) | ||
os.makedirs(dirname, exist_ok=True) | ||
save_file(loaded, sf_filename, metadata=metadata) | ||
# check_file_size(sf_filename, pt_filename) # same problem as above | ||
reloaded = load_file(sf_filename) | ||
for k in loaded: | ||
pt_tensor = loaded[k] | ||
sf_tensor = reloaded[k] | ||
if not torch.equal(pt_tensor, sf_tensor): | ||
raise RuntimeError(f"The output tensors do not match for key {k}") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"--src", | ||
"-s", | ||
nargs="+", | ||
default=[], | ||
help="a list of paths to source torch checkpoints, which ends with .pt or .bin", | ||
) | ||
parser.add_argument( | ||
"--target", | ||
"-t", | ||
type=str, | ||
help="Filename to save. Specify folder, e.g., ./models, or file path which ends with .ckpt, e.g., ./models/t5-v1_1-xxl.ckpt", | ||
) | ||
|
||
parser.add_argument("--model_dir", type=str, required=True, help="Path to a folder containing the model.") | ||
args = parser.parse_args() | ||
filename = os.path.join(args.model_path, "pytorch_model.bin.index.json") | ||
|
||
if not args.target.endswith(".ckpt"): | ||
os.makedirs(args.target, exist_ok=True) | ||
target_fp = os.path.join(args.target, os.path.basename(args.src).split(".")[0] + ".ckpt") | ||
else: | ||
target_fp = args.target | ||
|
||
if os.path.exists(target_fp): | ||
print(f"Warnings: {target_fp} will be overwritten!") | ||
|
||
torch_to_ms_weight(args.src, target_fp) | ||
print(f"Converted weight saved to {target_fp}") | ||
convert_multi(filename, folder=args.model_dir) |