Skip to content

Commit

Permalink
t5 model conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
hadipash committed Sep 23, 2024
1 parent f3155e1 commit 369110d
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 86 deletions.
7 changes: 3 additions & 4 deletions examples/opensora_hpcai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,9 @@ Please prepare the model checkpoints of T5, VAE, and STDiT and put them under `m

- T5: Download the [DeepFloyd/t5-v1_1-xxl](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) folder and put it under `models/`

Convert to ms checkpoint:
```
python tools/convert_t5.py --src models/t5-v1_1-xxl/pytorch_model-00001-of-00002.bin models/t5-v1_1-xxl/pytorch_model-00002-of-00002.bin --target models/t5-v1_1-xxl/model.ckpt
Convert to `safetensors` checkpoint (required by `mindone.transformers`):
```shell
python tools/convert_t5.py --model_dir ./models/t5-v1_1-xxl/
```

- VAE: Download the safetensor checkpoint from [here]((https://huggingface.co/stabilityai/sd-vae-ft-ema/tree/main))
Expand Down
211 changes: 129 additions & 82 deletions examples/opensora_hpcai/tools/convert_t5.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,142 @@
"""
Modified from
https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py
"""
import argparse
import json
import os
from collections import defaultdict
from typing import Dict, List

import torch
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
from tqdm import tqdm

import mindspore as ms


def _load_torch_ckpt(ckpt_file):
source_data = torch.load(ckpt_file, map_location="cpu")
if "state_dict" in source_data:
source_data = source_data["state_dict"]
return source_data


def _load_huggingface_safetensor(ckpt_file):
from safetensors import safe_open

db_state_dict = {}
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys():
db_state_dict[key] = f.get_tensor(key)
return db_state_dict


LOAD_PYTORCH_FUNCS = {"others": _load_torch_ckpt, "safetensors": _load_huggingface_safetensor}


def load_torch_ckpt(ckpt_path_list):
total_params = {}
for ckpt_path in ckpt_path_list:
extension = ckpt_path.split(".")[-1]
if extension not in LOAD_PYTORCH_FUNCS.keys():
extension = "others"
torch_params = LOAD_PYTORCH_FUNCS[extension](ckpt_path)
total_params.update(torch_params)
return total_params


def convert_pt_name_to_ms(content: str) -> str:
# embedding table name conversion
content = content.replace("shared.weight", "shared.embedding_table")
content = content.replace("encoder.embed_tokens.weight", "encoder.embed_tokens.embedding_table")
content = content.replace(
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
"encoder.block.0.layer.0.SelfAttention.relative_attention_bias.embedding_table",
)
return content


def torch_to_ms_weight(source_fp, target_fp):
source_data = load_torch_ckpt(source_fp)
if "ema" in source_data:
source_data = source_data["ema"]
if "state_dict" in source_data:
source_data = source_data["state_dict"]
target_data = []
for _name_pt in tqdm(source_data, total=len(source_data)):
_name_ms = convert_pt_name_to_ms(_name_pt)
_source_data = source_data[_name_pt].cpu().detach().numpy()
target_data.append({"name": _name_ms, "data": ms.Tensor(_source_data)})
ms.save_checkpoint(target_data, target_fp)
def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
*,
preferred_names: List[str] = None,
discard_names: List[str] = None,
) -> Dict[str, List[str]]:
if preferred_names is None:
preferred_names = []
preferred_names = set(preferred_names)
if discard_names is None:
discard_names = []
discard_names = set(discard_names)

shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
if not complete_names:
if len(shared) == 1:
# Force contiguous
name = list(shared)[0]
state_dict[name] = state_dict[name].clone()
complete_names = {name}
else:
raise RuntimeError(
f"Error while trying to find names to remove to save state dict, but found no suitable name to keep"
f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model"
f" since you could be storing much more memory than needed. Please refer to"
f" https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
)

keep_name = sorted(list(complete_names))[0]

# Mechanism to preferentially select keys to keep
# coming from the on-disk file to allow
# loading models saved with a different choice
# of keep_name
preferred = complete_names.difference(discard_names)
if preferred:
keep_name = sorted(list(preferred))[0]

if preferred_names:
preferred = preferred_names.intersection(complete_names)
if preferred:
keep_name = sorted(list(preferred))[0]
for name in sorted(shared):
if name != keep_name:
to_remove[keep_name].append(name)
return to_remove


def check_file_size(sf_filename: str, pt_filename: str):
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size

if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
"""
)


def rename(pt_filename: str) -> str:
filename, ext = os.path.splitext(pt_filename)
local = f"{filename}.safetensors"
local = local.replace("pytorch_model", "model")
return local


def convert_multi(filename: str, *, folder: str):
with open(filename, "r") as f:
data = json.load(f)

filenames = set(data["weight_map"].values())
for filename in tqdm(filenames):
sf_filename = rename(filename)
sf_filename = os.path.join(folder, sf_filename)
convert_file(os.path.join(folder, filename), sf_filename)

index = os.path.join(folder, "model.safetensors.index.json")
with open(index, "w") as f:
newdata = {k: v for k, v in data.items()}
newmap = {k: rename(v) for k, v in data["weight_map"].items()}
newdata["weight_map"] = newmap
json.dump(newdata, f, indent=4)

print("Completed.")


def convert_file(pt_filename: str, sf_filename: str):
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
if "state_dict" in loaded:
loaded = loaded["state_dict"]
to_removes = _remove_duplicate_names(loaded)

metadata = {"format": "pt"}

# a problem with safetensors and transformers
# https://github.com/huggingface/safetensors/issues/202
for to_remove in to_removes.values():
for tr in to_remove:
loaded[tr] = loaded[tr].clone()

# Force tensors to be contiguous
loaded = {k: v.contiguous() for k, v in loaded.items()}

dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
save_file(loaded, sf_filename, metadata=metadata)
# check_file_size(sf_filename, pt_filename) # same problem as above
reloaded = load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
if not torch.equal(pt_tensor, sf_tensor):
raise RuntimeError(f"The output tensors do not match for key {k}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--src",
"-s",
nargs="+",
default=[],
help="a list of paths to source torch checkpoints, which ends with .pt or .bin",
)
parser.add_argument(
"--target",
"-t",
type=str,
help="Filename to save. Specify folder, e.g., ./models, or file path which ends with .ckpt, e.g., ./models/t5-v1_1-xxl.ckpt",
)

parser.add_argument("--model_dir", type=str, required=True, help="Path to a folder containing the model.")
args = parser.parse_args()
filename = os.path.join(args.model_path, "pytorch_model.bin.index.json")

if not args.target.endswith(".ckpt"):
os.makedirs(args.target, exist_ok=True)
target_fp = os.path.join(args.target, os.path.basename(args.src).split(".")[0] + ".ckpt")
else:
target_fp = args.target

if os.path.exists(target_fp):
print(f"Warnings: {target_fp} will be overwritten!")

torch_to_ms_weight(args.src, target_fp)
print(f"Converted weight saved to {target_fp}")
convert_multi(filename, folder=args.model_dir)

0 comments on commit 369110d

Please sign in to comment.