diff --git a/examples/opensora_hpcai/README.md b/examples/opensora_hpcai/README.md index 6b4bfd1bef..d171b27b1c 100644 --- a/examples/opensora_hpcai/README.md +++ b/examples/opensora_hpcai/README.md @@ -248,10 +248,9 @@ Please prepare the model checkpoints of T5, VAE, and STDiT and put them under `m - T5: Download the [DeepFloyd/t5-v1_1-xxl](https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main) folder and put it under `models/` - Convert to ms checkpoint: - ``` - python tools/convert_t5.py --src models/t5-v1_1-xxl/pytorch_model-00001-of-00002.bin models/t5-v1_1-xxl/pytorch_model-00002-of-00002.bin --target models/t5-v1_1-xxl/model.ckpt - + Convert to `safetensors` checkpoint (required by `mindone.transformers`): + ```shell + python tools/convert_t5.py --model_dir ./models/t5-v1_1-xxl/ ``` - VAE: Download the safetensor checkpoint from [here]((https://huggingface.co/stabilityai/sd-vae-ft-ema/tree/main)) diff --git a/examples/opensora_hpcai/tools/convert_t5.py b/examples/opensora_hpcai/tools/convert_t5.py index 65c896eada..1c76b89fd6 100644 --- a/examples/opensora_hpcai/tools/convert_t5.py +++ b/examples/opensora_hpcai/tools/convert_t5.py @@ -1,95 +1,142 @@ +""" +Modified from +https://github.com/huggingface/safetensors/blob/main/bindings/python/convert.py +""" import argparse +import json import os +from collections import defaultdict +from typing import Dict, List import torch +from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file from tqdm import tqdm -import mindspore as ms - -def _load_torch_ckpt(ckpt_file): - source_data = torch.load(ckpt_file, map_location="cpu") - if "state_dict" in source_data: - source_data = source_data["state_dict"] - return source_data - - -def _load_huggingface_safetensor(ckpt_file): - from safetensors import safe_open - - db_state_dict = {} - with safe_open(ckpt_file, framework="pt", device="cpu") as f: - for key in f.keys(): - db_state_dict[key] = f.get_tensor(key) - return db_state_dict - - -LOAD_PYTORCH_FUNCS = {"others": _load_torch_ckpt, "safetensors": _load_huggingface_safetensor} - - -def load_torch_ckpt(ckpt_path_list): - total_params = {} - for ckpt_path in ckpt_path_list: - extension = ckpt_path.split(".")[-1] - if extension not in LOAD_PYTORCH_FUNCS.keys(): - extension = "others" - torch_params = LOAD_PYTORCH_FUNCS[extension](ckpt_path) - total_params.update(torch_params) - return total_params - - -def convert_pt_name_to_ms(content: str) -> str: - # embedding table name conversion - content = content.replace("shared.weight", "shared.embedding_table") - content = content.replace("encoder.embed_tokens.weight", "encoder.embed_tokens.embedding_table") - content = content.replace( - "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", - "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.embedding_table", - ) - return content - - -def torch_to_ms_weight(source_fp, target_fp): - source_data = load_torch_ckpt(source_fp) - if "ema" in source_data: - source_data = source_data["ema"] - if "state_dict" in source_data: - source_data = source_data["state_dict"] - target_data = [] - for _name_pt in tqdm(source_data, total=len(source_data)): - _name_ms = convert_pt_name_to_ms(_name_pt) - _source_data = source_data[_name_pt].cpu().detach().numpy() - target_data.append({"name": _name_ms, "data": ms.Tensor(_source_data)}) - ms.save_checkpoint(target_data, target_fp) +def _remove_duplicate_names( + state_dict: Dict[str, torch.Tensor], + *, + preferred_names: List[str] = None, + discard_names: List[str] = None, +) -> Dict[str, List[str]]: + if preferred_names is None: + preferred_names = [] + preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + if len(shared) == 1: + # Force contiguous + name = list(shared)[0] + state_dict[name] = state_dict[name].clone() + complete_names = {name} + else: + raise RuntimeError( + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep" + f" for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model" + f" since you could be storing much more memory than needed. Please refer to" + f" https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if preferred_names: + preferred = preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +def check_file_size(sf_filename: str, pt_filename: str): + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """ + ) + + +def rename(pt_filename: str) -> str: + filename, ext = os.path.splitext(pt_filename) + local = f"{filename}.safetensors" + local = local.replace("pytorch_model", "model") + return local + + +def convert_multi(filename: str, *, folder: str): + with open(filename, "r") as f: + data = json.load(f) + + filenames = set(data["weight_map"].values()) + for filename in tqdm(filenames): + sf_filename = rename(filename) + sf_filename = os.path.join(folder, sf_filename) + convert_file(os.path.join(folder, filename), sf_filename) + + index = os.path.join(folder, "model.safetensors.index.json") + with open(index, "w") as f: + newdata = {k: v for k, v in data.items()} + newmap = {k: rename(v) for k, v in data["weight_map"].items()} + newdata["weight_map"] = newmap + json.dump(newdata, f, indent=4) + + print("Completed.") + + +def convert_file(pt_filename: str, sf_filename: str): + loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) + if "state_dict" in loaded: + loaded = loaded["state_dict"] + to_removes = _remove_duplicate_names(loaded) + + metadata = {"format": "pt"} + + # a problem with safetensors and transformers + # https://github.com/huggingface/safetensors/issues/202 + for to_remove in to_removes.values(): + for tr in to_remove: + loaded[tr] = loaded[tr].clone() + + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata=metadata) + # check_file_size(sf_filename, pt_filename) # same problem as above + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") if __name__ == "__main__": parser = argparse.ArgumentParser() - - parser.add_argument( - "--src", - "-s", - nargs="+", - default=[], - help="a list of paths to source torch checkpoints, which ends with .pt or .bin", - ) - parser.add_argument( - "--target", - "-t", - type=str, - help="Filename to save. Specify folder, e.g., ./models, or file path which ends with .ckpt, e.g., ./models/t5-v1_1-xxl.ckpt", - ) - + parser.add_argument("--model_dir", type=str, required=True, help="Path to a folder containing the model.") args = parser.parse_args() + filename = os.path.join(args.model_path, "pytorch_model.bin.index.json") - if not args.target.endswith(".ckpt"): - os.makedirs(args.target, exist_ok=True) - target_fp = os.path.join(args.target, os.path.basename(args.src).split(".")[0] + ".ckpt") - else: - target_fp = args.target - - if os.path.exists(target_fp): - print(f"Warnings: {target_fp} will be overwritten!") - - torch_to_ms_weight(args.src, target_fp) - print(f"Converted weight saved to {target_fp}") + convert_multi(filename, folder=args.model_dir)