Commit

save lora-ga init model for w0=w0-A0*B0;tiny fix
fclearner committed Aug 12, 2024
1 parent 1af0d47 commit 223d284
Showing 3 changed files with 19 additions and 13 deletions.
4 changes: 2 additions & 2 deletions wenet/finetune/lora/config.yaml
@@ -1,13 +1,13 @@
 init_batch_size: 2
-init_iters: 4
+init_iters: 8
 init_config:
   mode: "gradient" # option: "simple", "svd", "gradient"
   lora_A: "unit" # option: "gaussian", "kaiming", "fan_out_kaiming", "xavier", "zeros", "unit", "orthogonal"
   lora_A_std: 0.01 # only needed when lora_A is "gaussian"
   lora_B: "unit" # option: "gaussian", "kaiming", "fan_out_kaiming", "xavier", "zeros", "unit", "orthogonal"
   lora_B_std: 0.01 # only needed when lora_B is "gaussian"
   scale: "stable" # option: "default", "stable", "unit", "normalized", "gd", "weightS"
-  stable_gamma: 64 # only needed when scale is "stable"
+  stable_gamma: 2 # only needed when scale is "stable"
   direction: "ArB2r" # option: "ArBr", "A2rBr", "ArB2r"(only needed when mode is "gradient")
   dtype: "fp32" # option: "bf16", "fp32"
   norm_clip: false # norm clipping
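
For reference, a minimal sketch of how this file is consumed by reinit_lora in the train_utils.py change further down; the path and key names come from this diff, while the standalone-script framing and the print statements are illustrative.

from types import SimpleNamespace

import yaml

with open("wenet/finetune/lora/config.yaml", "r") as f:
    lora_config = yaml.safe_load(f)

# Gradient estimation presumably runs init_iters batches of init_batch_size
# utterances each: 8 * 2 = 16 with the new values (previously 4 * 2 = 8).
print(lora_config["init_iters"] * lora_config["init_batch_size"])

# Only the nested init_config block reaches reinit_lora_modules, wrapped in a
# SimpleNamespace so its fields are read as attributes.
init_cfg = SimpleNamespace(**lora_config["init_config"])
print(init_cfg.mode, init_cfg.scale, init_cfg.stable_gamma)  # gradient stable 2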
6 changes: 4 additions & 2 deletions wenet/utils/init_model.py
@@ -199,8 +199,6 @@ def init_model(args, configs):

     if hasattr(args, 'use_lora') and args.use_lora:
         inject_lora_to_model(model, configs['lora_conf'])
-        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
-            load_checkpoint(model, args.lora_ckpt_path)

     # If specify checkpoint, load some info from checkpoint
     if hasattr(args, 'checkpoint') and args.checkpoint is not None:
@@ -211,6 +209,10 @@ def init_model(args, configs):
         infos = {}
     configs["init_infos"] = infos

+    if hasattr(args, 'use_lora') and args.use_lora:
+        if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
+            load_checkpoint(model, args.lora_ckpt_path)
+
     print(configs)
     # Trye to tie some weights
     if hasattr(model, 'tie_or_clone_weights'):
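
One plausible reading of this reorder, shown as a self-contained sketch: if the main checkpoint (for example the lora_init.pt saved by this commit) also contains LoRA tensors, loading it after the adapter checkpoint would overwrite the adapter, so the adapter checkpoint is now loaded last. TinyLoRALinear and the fake checkpoint dicts below are illustrative, not wenet code.

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    def __init__(self, dim=4, r=2):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim, dim))
        self.lora_A = nn.Parameter(torch.zeros(r, dim))
        self.lora_B = nn.Parameter(torch.zeros(dim, r))

m = TinyLoRALinear()
# Stand-ins for args.checkpoint (full model) and args.lora_ckpt_path (adapter only).
full_ckpt = {k: torch.ones_like(v) for k, v in m.state_dict().items()}
lora_ckpt = {"lora_A": torch.full((2, 4), 2.0), "lora_B": torch.full((4, 2), 2.0)}

# Order after this commit: main checkpoint first, adapter second.
m.load_state_dict(full_ckpt)
m.load_state_dict(lora_ckpt, strict=False)
print(m.lora_A[0, 0].item())  # 2.0 -- adapter weights survive

# Previous order: the full checkpoint, loaded last, clobbers the adapter.
m.load_state_dict(lora_ckpt, strict=False)
m.load_state_dict(full_ckpt)
print(m.lora_A[0, 0].item())  # 1.0 -- adapter init is lost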
22 changes: 13 additions & 9 deletions wenet/utils/train_utils.py
@@ -884,21 +884,21 @@ def freeze_modules(model, args):
                 logging.debug("{} module is freezed".format(name))


-def reinit_lora(model, args, dataset_configs, tokenizer, seed=777):
+def reinit_lora(model, args, configs, tokenizer, seed=777):
     from tqdm import tqdm
     from wenet.finetune.lora.utils import estimate_gradient, reinit_lora_modules
     from wenet.finetune.lora.layers import LoRALayer
     from types import SimpleNamespace

     logging.info("reinit lora modules.")
     with open(args.lora_init_yaml, 'r') as file:
-        config = yaml.safe_load(file)
+        lora_config = yaml.safe_load(file)

     generator = torch.Generator()
     generator.manual_seed(seed)
-    dataset_conf = copy.deepcopy(dataset_configs['dataset_conf'])
-    dataset_conf['batch_conf']['batch_size'] = config['init_batch_size']
-    dataset_type = dataset_configs.get('dataset', 'asr')
+    dataset_conf = copy.deepcopy(configs['dataset_conf'])
+    dataset_conf['batch_conf']['batch_size'] = lora_config['init_batch_size']
+    dataset_type = configs.get('dataset', 'asr')
     dataset = init_dataset(dataset_type, args.data_type, args.train_data,
                            tokenizer, dataset_conf, True)
     dataloader = DataLoader(dataset,
@@ -909,14 +909,18 @@ def reinit_lora(model, args, dataset_configs, tokenizer, seed=777):
                             generator=generator,
                             prefetch_factor=args.prefetch)
     additional_kwargs = {}
-    if config["init_config"]["mode"] == "gradient":
-        named_grads = estimate_gradient(model, dataloader, config['init_iters'])
+    if lora_config["init_config"]["mode"] == "gradient":
+        named_grads = estimate_gradient(model, dataloader,
+                                        lora_config['init_iters'])
         additional_kwargs["named_grads"] = named_grads
-    config = SimpleNamespace(**config["init_config"])
+    lora_config = SimpleNamespace(**lora_config["init_config"])
     for name, module in tqdm(
             model.named_modules(),
             desc="Reinitializing Lora",
             total=len(list(model.named_modules())),
     ):
         if isinstance(module, LoRALayer):
-            reinit_lora_modules(name, module, config, **additional_kwargs)
+            reinit_lora_modules(name, module, lora_config, **additional_kwargs)
+    # lora_init_model needs to be saved, w0 = w0 - A0 * B0
+    save_checkpoint(model, os.path.join(args.model_dir, "lora_init.pt"),
+                    infos={"tag":"lora_init", **configs})
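
To spell out the "w0 = w0 - A0 * B0" comment above: with LoRA-GA the adapters start non-zero, so the frozen base weight is shifted by their product and the adjusted model has to be saved once as lora_init.pt. Below is a minimal numeric sketch of that bookkeeping; the tensor shapes, the 0.1 scale, and the omitted LoRA scaling factor are assumptions for illustration, not wenet's exact layout.

import torch

torch.manual_seed(0)
d_out, d_in, r = 8, 6, 2
w0 = torch.randn(d_out, d_in)      # pretrained weight
B0 = torch.randn(d_out, r) * 0.1   # gradient-based LoRA init (non-zero)
A0 = torch.randn(r, d_in) * 0.1

# Shift the base weight so that base + adapter reproduces the original weight.
w0_init = w0 - B0 @ A0             # this is what lora_init.pt has to capture

x = torch.randn(3, d_in)
y_pretrained = x @ w0.t()
y_lora_start = x @ (w0_init + B0 @ A0).t()
print(torch.allclose(y_pretrained, y_lora_start))  # True: step-0 output unchanged

# Because w0_init differs from the pretrained w0, merging the trained adapter
# later must start from the saved lora_init.pt, not from the original model.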
