# -*- coding: utf-8 -*-
# Copyright 2021 National Institute of Information and Communication Technology (Raj Dabre)
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the
# Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
# The above copyright notice and this permission notice shall
# be included in all copies or substantial portions of the
# Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY
# KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
# OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
## Basic imports
import os
import sys
import argparse
import time
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
##
## Huggingface imports
import transformers
from transformers import AutoTokenizer, MBartTokenizer, MBart50Tokenizer, BartTokenizer, AlbertTokenizer, BarthezTokenizer
from transformers import MBartForConditionalGeneration, BartForConditionalGeneration, MBartConfig, BartConfig, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from transformers import AdamW
##
## Pytorch imports
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
try:
import wandb
except ImportError:
raise ImportError("wandb is not installed. Install it with: pip install wandb")
try:
import bitsandbytes as bnb
except ImportError:
bnb = None
print("bitsandbytes is not installed. Don't use the flag --adam_8bit.")
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP, MixedPrecision, FullStateDictConfig, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType, LocalStateDictConfig
from torch.distributed._shard.checkpoint import (
FileSystemReader,
FileSystemWriter,
save_state_dict,
load_state_dict,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
from torch.distributed.fsdp import BackwardPrefetch
from functools import partial
try:
from torchdistx import deferred_init
from torchdistx.optimizers import AnyPrecisionAdamW
except ImportError:
print("torchdistx not installed. Large models will load REALLLLYYYY SLOWLY!")
deferred_init = None
AnyPrecisionAdamW = None ## The --adam_anyprecision flag cannot be used without torchdistx.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
##
## Our imports
from common_utils import *
##
## Other imports
import math
import random
import numpy as np
import sacrebleu
from rouge_score import rouge_scorer
import gc
import functools
from contextlib import nullcontext
import shutil
##
## Seed setting here
torch.manual_seed(621311)
##
## Get torch version
torch_version = torch.__version__
##
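## Overview of model_create_load_run_save: set up the tokenizer(s), shard the training data, build or load the (m)BART config and model, wrap it in DDP or FSDP, create the optimizer and LR scheduler, optionally restore a checkpoint, and then run the training loop with periodic evaluation and checkpointing.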
def model_create_load_run_save(gpu, args, train_files, dev_files, ewc_files):
"""The main function which does the overall training. Should be split into multiple parts in the future. Currently monolithc intentionally."""
rank = args.nr * args.gpus + gpu ## The rank of the current process out of the total number of processes indicated by world_size.
print("Launching process:", rank)
sys.stdout.flush()
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
if args.use_official_pretrained_tokenizer or args.use_official_pretrained: # If we use an official model then we are using its tokenizer by default.
if "mbart" in args.pretrained_model or "IndicBART" in args.pretrained_model:
if "50" in args.pretrained_model:
tok = MBart50Tokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=False)
elif "IndicBART" in args.pretrained_model:
tok = AlbertTokenizer.from_pretrained(args.tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
else:
tok = MBartTokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=False)
else:
tok = BartTokenizer.from_pretrained(args.tokenizer_name_or_path, use_fast=False)
tgt_tok = None
decoding_tok = tok
else:
if "albert" in args.tokenizer_name_or_path:
tok = AlbertTokenizer.from_pretrained(args.tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
if args.tgt_tokenizer_name_or_path is not None:
tgt_tok = AlbertTokenizer.from_pretrained(args.tgt_tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
else:
tgt_tok = None
elif "mbart" in args.tokenizer_name_or_path:
tok = MBartTokenizer.from_pretrained(args.tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
if args.tgt_tokenizer_name_or_path is not None:
tgt_tok = MBartTokenizer.from_pretrained(args.tgt_tokenizer_name_or_path, do_lower_case=False, use_fast=False, keep_accents=True)
else:
tgt_tok = None
if tgt_tok is None:
decoding_tok = tok
else:
decoding_tok = tgt_tok
## Fast tokenizers are avoided because their behavior differs from the slow ones in undesirable ways. Accents must be kept, otherwise segmentation breaks for languages with accented characters. Lower-casing is disabled because we want to train on the original case; enable it only if you are fine with the model ignoring case.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=False) ## In case we do summarization.
tok.save_pretrained(args.model_path+"_deploy") ## Save the tokenizer for future use.
# Copy the specially_added_tokens file into the deploy folder. This file exists when we aren't using official pretrained models. We are not going to support this for separate tokenizers.
if os.path.exists(args.tokenizer_name_or_path+"/specially_added_tokens"):
shutil.copyfile(args.tokenizer_name_or_path+"/specially_added_tokens", args.model_path+"_deploy/specially_added_tokens")
print("Tokenizer is:", tok)
if args.tgt_tokenizer_name_or_path is not None:
print("Target tokenizer is:", tgt_tok) # We are not going to save the target tokenizer because it is not compatible with what hugingface expects.
if args.shard_files and rank == 0: ## First shard the data using process 0 aka the prime process or master process. Other processes will wait.
shard_files_bi(train_files, tok, args, additional_tokenizer=tgt_tok)
if args.ewc_importance != 0.0:
shard_files_bi(ewc_files, tok, args, additional_tokenizer=tgt_tok)
if rank == 0:
# handle quitting
with open(args.model_path + ".quitflag", "w") as f:
f.write("0")
# handle annealing
with open(args.model_path + ".anneal", "w") as f:
f.write("0")
# dist.barrier() ## Stop other processes from proceeding till sharding is done. ## Barriers before loading a model are undesirable; they occupy memory for no reason.
if args.supported_languages is not None:
args.supported_languages = args.supported_languages.split(",")
with open(args.model_path+"_deploy/supported_languages.txt", "w") as f:
for supported_pair in args.supported_languages:
f.write(supported_pair.replace("-", " ")+"\n")
print(f"Running DDP/FSDP checkpoint example on rank {rank}.")
sys.stdout.flush()
if args.fp16: ## Although the code supports FP16/AMP training, it tends to be unstable in distributed setups so use this carefully.
print("We will do fp16 training")
if args.use_fsdp:
scaler = ShardedGradScaler(args.init_scale) ## This is the scaler for FSDP. It is different from the one in torch.cuda.amp.
else:
scaler = torch.cuda.amp.GradScaler(args.init_scale) ## Gradient scaler which will be used with torch's automatic mixed precision
# Get scaler info
scaler_info = scaler.state_dict()
# Print scaler info neatly
print("AMP scaler info:")
for key, value in scaler_info.items():
print(f"{key}: {value}")
# Store current scale value
scale_value = scaler.get_scale()
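# The mixed-precision policy below keeps parameters and gradient all-reduce communication in bfloat16 when the GPU supports it, falling back to float16 otherwise.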
mixed_precision_policy = MixedPrecision(
# Param precision
param_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
# Gradient communication precision.
reduce_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
)
else:
print("We will do fp32 training")
mixed_precision_policy = None
if args.use_fsdp:
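# For full-state-dict checkpointing, the state dict is offloaded to CPU and materialized only on rank 0 to avoid running out of GPU memory while saving.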
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers.
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
optimizer_fullstate_save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
from transformers.models.mbart.modeling_mbart import MBartEncoderLayer, MBartDecoderLayer
if args.auto_wrap_policy == "transformer": ## A block will be kept on a single device. No minimum number of params.
print("We will use transformer auto wrap policy")
mbart_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
MBartEncoderLayer, MBartDecoderLayer,
},
)
else:
print("We will use size based auto wrap policy with min params:", args.fsdp_min_params)
mbart_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(args.fsdp_min_params))
if args.activation_checkpointing:
print("We will use activation checkpointing for FSDP.")
non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
check_fn = lambda submodule: (isinstance(submodule, MBartEncoderLayer) or isinstance(submodule, MBartDecoderLayer))
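# check_fn selects the encoder/decoder layers to checkpoint; the wrapping itself is applied after the model is wrapped in FSDP, via apply_activation_checkpointing further below.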
if args.sharding_strategy == "FULL_SHARD":
print("We will use full sharding")
sharding_strategy = ShardingStrategy.FULL_SHARD
elif args.sharding_strategy == "SHARD_GRAD_OP":
print("We will use gradient and optimizer sharding")
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
elif args.sharding_strategy == "HYBRID_SHARD":
print("We will use hybrid sharding. Model is sharded on a node and then each node forms a replica.")
sharding_strategy = ShardingStrategy.HYBRID_SHARD
elif args.sharding_strategy == "_HYBRID_SHARD_ZERO2":
print("Similar to hybrid sharding except that only optimizer and gradient sharding is done over a node.")
sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
else:
raise ValueError("Invalid sharding strategy")
backward_prefetch_policy = BackwardPrefetch.BACKWARD_PRE if args.backward_prefetch else None
torch.backends.cuda.matmul.allow_tf32 = args.allow_tf32_matmul
if args.nodes_per_hsdp_group > 1:
print("We will use HSDP replicas of size:", args.nodes_per_hsdp_group*args.gpus, "GPUs and there are total", args.nodes//args.nodes_per_hsdp_group, "HSDP replicas")
assert args.nodes % args.nodes_per_hsdp_group == 0, "The number of nodes should be divisible by the number of nodes per HSDP group"
hsdp_replica_id = rank//(args.nodes_per_hsdp_group*args.gpus)
print("HSDP replica id is:", hsdp_replica_id)
intranode_process_group, _ = dist.new_subgroups(group_size=args.nodes_per_hsdp_group*args.gpus)
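# dist.new_group is a collective call, so every rank must create every inter-node group; each rank then keeps only the group whose local offset matches its own.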
for local_rank in range(args.nodes_per_hsdp_group*args.gpus):
internode_ranks = [hsdp_replica_num*args.nodes_per_hsdp_group*args.gpus+local_rank for hsdp_replica_num in range(args.nodes//args.nodes_per_hsdp_group)]
grp = dist.new_group(ranks=internode_ranks, backend='nccl')
if local_rank == (rank%(args.nodes_per_hsdp_group*args.gpus)):
internode_process_group = grp
print("Process groups are:", dist.get_process_group_ranks(intranode_process_group), dist.get_process_group_ranks(internode_process_group))
process_group = (intranode_process_group, internode_process_group)
else:
print("We will use a single node for each HSDP group")
process_group = None
cpu_offload = dist.fsdp.CPUOffload(offload_params=args.fsdp_cpu_offload)
if args.fsdp_cpu_offload:
print("We will use CPU offloading for FSDP")
# dist.barrier()
sys.stdout.flush()
if args.encoder_tying_config is not None:
print("We will use recurrently stacked layers for the encoder with configuration:", args.encoder_tying_config)
if args.decoder_tying_config is not None:
print("We will use recurrently stacked layers for the decoder with configuration:", args.decoder_tying_config)
if args.unidirectional_encoder:
print("Using unidirectional encoder.")
torch.cuda.set_device(gpu) ## Set the device to the current GPU. This is different from the rank so keep this in mind.
if rank == 0:
writer = SummaryWriter(args.model_path+".tflogs")
if args.wb:
run = wandb.init(
project=args.wb_project,
name=args.wb_run,
config=vars(args),
save_code=True,
)
print("Initialization scheme is:", args.initialization_scheme)
if args.initialization_scheme == "static":
print("Static initialization scheme is used. We will use the init_std value of:", args.init_std)
dist.barrier()
if args.use_official_pretrained:
if "mbart" in args.pretrained_model or "IndicBART" in args.pretrained_model:
config = MBartConfig.from_pretrained(args.pretrained_model)
config.init_std = args.init_std # We should set the init_std to be different when using adaptors or newer params.
config.initialization_scheme = args.initialization_scheme # We should set the initialization_scheme to be different when using adaptors or newer params.
config.dropout = args.dropout ## We should set dropouts manually
config.attention_dropout = args.attention_dropout ## We should set dropouts manually
config.activation_dropout = args.activation_dropout ## We should set dropouts manually
config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.prompt_tuning = args.prompt_tuning ## We should set prompt_tuning_info manually
config.prompt_projection_hidden_size=args.prompt_projection_hidden_size
config.prompt_init_std=args.prompt_init_std ## We should set prompt_init_std manually
config.layernorm_prompt_projection=args.layernorm_prompt_projection ## We should set layernorm_prompt_projection manually
config.no_projection_prompt=args.no_projection_prompt ## We should set no_projection_prompt manually
config.use_tanh_activation_prompt=args.use_tanh_activation_prompt ## We should set use_tanh_activation_prompt manually
config.residual_connection_prompt=args.residual_connection_prompt ## We should set residual_connection_prompt manually
config.num_prompts = args.num_prompts ## We should set num_prompts manually
config.prompt_dropout = args.prompt_dropout ## We should set prompt_dropout manually
config.recurrent_projections = args.recurrent_projections ## We should set recurrent_projections manually
config.adaptor_tuning = args.adaptor_tuning ## We should set adaptor_tuning_info manually
config.deep_adaptor_tuning = args.deep_adaptor_tuning ## We should set deep_adaptor_tuning_info manually
config.deep_adaptor_tuning_ffn_only = args.deep_adaptor_tuning_ffn_only ## We should set deep_adaptor_tuning_info manually
config.adaptor_dropout = args.adaptor_dropout ## We should set adaptor_dropout manually
config.adaptor_activation_function = args.adaptor_activation_function ## We should set adaptor_activation_function manually
config.parallel_adaptors = args.parallel_adaptors ## We should set parallel_adaptors_info manually
config.layernorm_adaptor_input = args.layernorm_adaptor_input ## We should set layernorm_adaptor_input_info manually
config.adaptor_scaling_factor = args.adaptor_scaling_factor ## We should set adaptor_scaling_factor_info manually
config.residual_connection_adaptor = args.residual_connection_adaptor ## We should set residual_connection_adaptor_info manually
config.encoder_adaptor_tying_config = args.encoder_adaptor_tying_config ## We should set encoder_tying_config manually
config.decoder_adaptor_tying_config = args.decoder_adaptor_tying_config ## We should set decoder_tying_config manually
config.adaptor_hidden_size = args.adaptor_hidden_size ## We should set adaptor_hidden_size manually
config.moe_adaptors=args.moe_adaptors ## We should set moe_adaptors_info manually
config.num_moe_adaptor_experts=args.num_moe_adaptor_experts ## We should set num_moe_adaptor_experts_info manually
config.hypercomplex = args.hypercomplex ## We should set hypercomplex manually
config.hypercomplex_n = args.hypercomplex_n ## We should set hypercomplex_n manually
config.ia3_adaptors = args.ia3_adaptors ## We should set ia3_adaptors info manually
config.lora_adaptors = args.lora_adaptors ## We should set lora_adaptors info manually
config.lora_adaptor_rank = args.lora_adaptor_rank ## We should set lora_adaptor_rank info manually
config.softmax_bias_tuning = args.softmax_bias_tuning ## We should set softmax_bias_tuning_info manually
config.gradient_checkpointing = args.gradient_checkpointing ## We should set gradient_checkpointing_info manually
config.sparsify_attention = args.sparsify_attention
config.sparsify_ffn = args.sparsify_ffn
config.num_sparsify_blocks = args.num_sparsify_blocks
config.sparsification_temperature = args.sparsification_temperature
model = deferred_init.deferred_init(MBartForConditionalGeneration.from_pretrained, args.pretrained_model, config=config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration.from_pretrained(args.pretrained_model, config=config) ## We may use FBs official model and fine-tune it for our purposes.
config.architectures = ["MBartForConditionalGeneration"]
config.save_pretrained(args.model_path+"_deploy") ## Save the config as a json file to ensure easy loading during future fine tuning of the model.
elif "bart" in args.pretrained_model:
config = BartConfig.from_pretrained(args.pretrained_model)
config.init_std = args.init_std # We should set the init_std to be different when using adaptors or newer params.
config.initialization_scheme = args.initialization_scheme # We should set the initialization_scheme to be different when using adaptors or newer params.
config.dropout = args.dropout ## We should set dropouts manually
config.attention_dropout = args.attention_dropout ## We should set dropouts manually
config.activation_dropout = args.activation_dropout ## We should set dropouts manually
config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
config.gradient_checkpointing = args.gradient_checkpointing ## We should set gradient_checkpointing_info manually
model = deferred_init.deferred_init(BartForConditionalGeneration.from_pretrained, args.pretrained_model, config=config, force_bos_token_to_be_generated=True) if (args.use_fsdp and deferred_init is not None) else BartForConditionalGeneration.from_pretrained(args.pretrained_model, config=config, force_bos_token_to_be_generated=True) ## We may use FBs official model and fine-tune it for our purposes.
config.architectures = ["BartForConditionalGeneration"]
config.save_pretrained(args.model_path+"_deploy") ## Save the config as a json file to ensure easy loading during future fine tuning of the model.
else: # We are going to manually specify a config for our locally trained model.
config = MBartConfig(vocab_size=len(tok), target_vocab_size=len(tgt_tok) if tgt_tok is not None else 0, init_std=args.init_std, initialization_scheme=args.initialization_scheme, encoder_layers=args.encoder_layers, decoder_layers=args.decoder_layers, dropout=args.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, encoder_attention_heads=args.encoder_attention_heads, decoder_attention_heads=args.decoder_attention_heads, encoder_ffn_dim=args.encoder_ffn_dim, decoder_ffn_dim=args.decoder_ffn_dim, d_model=args.d_model, embed_low_rank_dim=args.embed_low_rank_dim, no_embed_norm=args.no_embed_norm, scale_embedding=args.scale_embedding, pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], encoder_tying_config=args.encoder_tying_config, decoder_tying_config=args.decoder_tying_config, gradient_checkpointing=args.gradient_checkpointing, multilayer_softmaxing=args.multilayer_softmaxing, wait_k=args.wait_k, additional_source_wait_k=args.additional_source_wait_k, unidirectional_encoder=args.unidirectional_encoder, multi_source=args.multi_source, multi_source_method=args.multi_source_method, mid_fusion_layers=args.mid_fusion_layers, bottleneck_mid_fusion_tokens=args.bottleneck_mid_fusion_tokens, softmax_temperature=args.softmax_temperature, temperature_calibration=args.temperature_calibration, encoder_layerdrop=args.layerdrop, decoder_layerdrop=args.layerdrop, no_scale_attention_embedding=args.no_scale_attention_embedding, positional_encodings=args.positional_encodings, alibi_encoding=args.alibi_encoding, asymmetric_alibi_encoding=args.asymmetric_alibi_encoding, rope_encoding=args.rope_encoding, num_domains_for_domain_classifier=args.num_domains_for_domain_classifier, gradient_reversal_for_domain_classifier=args.gradient_reversal_for_domain_classifier, activation_function=args.activation_function, no_positional_encoding_encoder=args.no_positional_encoding_encoder, no_positional_encoding_decoder=args.no_positional_encoding_decoder, postnorm_encoder=args.postnorm_encoder, postnorm_decoder=args.postnorm_decoder, use_moe=args.use_moe, num_experts=args.num_experts, expert_ffn_size=args.expert_ffn_size, prompt_tuning=args.prompt_tuning, prompt_dropout=args.prompt_dropout, prompt_projection_hidden_size=args.prompt_projection_hidden_size, prompt_init_std=args.prompt_init_std, layernorm_prompt_projection=args.layernorm_prompt_projection, no_projection_prompt=args.no_projection_prompt, use_tanh_activation_prompt=args.use_tanh_activation_prompt, residual_connection_prompt=args.residual_connection_prompt, num_prompts=args.num_prompts, recurrent_projections=args.recurrent_projections, adaptor_tuning=args.adaptor_tuning, deep_adaptor_tuning=args.deep_adaptor_tuning, adaptor_activation_function=args.adaptor_activation_function, deep_adaptor_tuning_ffn_only=args.deep_adaptor_tuning_ffn_only, adaptor_dropout=args.adaptor_dropout, parallel_adaptors = args.parallel_adaptors, layernorm_adaptor_input = args.layernorm_adaptor_input, adaptor_scaling_factor = args.adaptor_scaling_factor, residual_connection_adaptor = args.residual_connection_adaptor, encoder_adaptor_tying_config=args.encoder_adaptor_tying_config, decoder_adaptor_tying_config=args.decoder_adaptor_tying_config, adaptor_hidden_size=args.adaptor_hidden_size, moe_adaptors=args.moe_adaptors, num_moe_adaptor_experts=args.num_moe_adaptor_experts, hypercomplex=args.hypercomplex, 
hypercomplex_n=args.hypercomplex_n, ia3_adaptors=args.ia3_adaptors, lora_adaptors=args.lora_adaptors, lora_adaptor_rank=args.lora_adaptor_rank, softmax_bias_tuning=args.softmax_bias_tuning, sparsify_attention=args.sparsify_attention, sparsify_ffn=args.sparsify_ffn, num_sparsify_blocks=args.num_sparsify_blocks, sparsification_temperature=args.sparsification_temperature, tokenizer_class="AlbertTokenizer" if "albert" in args.tokenizer_name_or_path else "MBartTokenizer") ## Configuration. It is saved to the deploy folder just below.
config.architectures = ["MBartForConditionalGeneration"]
config.save_pretrained(args.model_path+"_deploy") ## Save the config as a json file to ensure easy loading during future fine tuning of the model.
model = deferred_init.deferred_init(MBartForConditionalGeneration, config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration(config)
model.train()
if args.distillation: ## When distilling we need a parent model. The creation of the model is in the same way as the child. This model is immediately loaded with some pretrained params and then loaded into the GPU.
print("We will do distillation from a parent model.")
if args.use_official_parent_pretrained:
if "mbart" in args.parent_pretrained_model or "IndicBART" in args.pretrained_model:
parent_config = MBartConfig.from_pretrained(args.parent_pretrained_model)
parent_config.dropout = args.parent_dropout ## We should set dropouts manually
parent_config.attention_dropout = args.parent_attention_dropout ## We should set dropouts manually
parent_config.activation_dropout = args.parent_activation_dropout ## We should set dropouts manually
parent_config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_model = deferred_init.deferred_init(MBartForConditionalGeneration.from_pretrained, args.parent_pretrained_model, config=parent_config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration.from_pretrained(args.parent_pretrained_model, config=parent_config) ## We may use FBs official model and fine-tune it for our purposes.
elif "bart" in args.parent_pretrained_model:
parent_config = BartConfig.from_pretrained(args.parent_pretrained_model)
parent_config.dropout = args.parent_dropout ## We should set dropouts manually
parent_config.attention_dropout = args.parent_attention_dropout ## We should set dropouts manually
parent_config.activation_dropout = args.parent_activation_dropout ## We should set dropouts manually
parent_config.encoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_config.decoder_layerdrop = args.layerdrop ## We should set dropouts manually
parent_model = deferred_init.deferred_init(BartForConditionalGeneration.from_pretrained, args.parent_pretrained_model, config=parent_config, force_bos_token_to_be_generated=True) if (args.use_fsdp and deferred_init is not None) else BartForConditionalGeneration.from_pretrained(args.parent_pretrained_model, config=parent_config, force_bos_token_to_be_generated=True) ## We may use FB's official model and fine-tune it for our purposes.
else: ## Its a locally pre-trained parent model.
parent_config = MBartConfig(vocab_size=len(tok), target_vocab_size=len(tgt_tok) if tgt_tok is not None else 0, encoder_layers=args.parent_encoder_layers, decoder_layers=args.parent_decoder_layers, dropout=args.parent_dropout, attention_dropout=args.parent_attention_dropout, activation_dropout=args.parent_activation_dropout, encoder_attention_heads=args.parent_encoder_attention_heads, decoder_attention_heads=args.parent_decoder_attention_heads, encoder_ffn_dim=args.parent_encoder_ffn_dim, decoder_ffn_dim=args.parent_decoder_ffn_dim, d_model=args.parent_d_model, no_embed_norm=args.no_embed_norm, scale_embedding=args.scale_embedding, pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], encoder_tying_config=args.encoder_tying_config, decoder_tying_config=args.decoder_tying_config, wait_k=args.wait_k, additional_source_wait_k=args.additional_source_wait_k, unidirectional_encoder=args.unidirectional_encoder, multi_source=args.multi_source, multi_source_method=args.multi_source_method, mid_fusion_layers=args.mid_fusion_layers, bottleneck_mid_fusion_tokens=args.bottleneck_mid_fusion_tokens, softmax_temperature=args.softmax_temperature, temperature_calibration=args.temperature_calibration, encoder_layerdrop=args.layerdrop, decoder_layerdrop=args.layerdrop, no_scale_attention_embedding=args.no_scale_attention_embedding, positional_encodings=args.positional_encodings, alibi_encoding=args.alibi_encoding, asymmetric_alibi_encoding=args.asymmetric_alibi_encoding, rope_encoding=args.rope_encoding, activation_function=args.activation_function, no_positional_encoding_encoder=args.no_positional_encoding_encoder, no_positional_encoding_decoder=args.no_positional_encoding_decoder, postnorm_encoder=args.postnorm_encoder, postnorm_decoder=args.postnorm_decoder, use_moe=args.use_moe, num_experts=args.num_experts, expert_ffn_size=args.expert_ffn_size)
parent_model = deferred_init.deferred_init(MBartForConditionalGeneration, parent_config) if (args.use_fsdp and deferred_init is not None) else MBartForConditionalGeneration(parent_config)
parent_model.train() ## We do this to enable dropout, but we won't have an optimizer for this model, so we won't train it for now. Future implementations should ask whether we want to do co-distillation or not. By co-distillation I mean that the parent will learn together with the child.
if not args.use_fsdp:
parent_model.cuda(gpu) ## Move the model to the GPU. ## Remove this and see if the DDP extra memory allocation problem goes away.
print("Memory consumed after moving parent model to GPU", round(torch.cuda.memory_allocated(gpu)/(1024**3), 2), "GB")
else:
print("When FSDP is used, the parent model is not moved to the GPU. This is because FSDP does not support moving the model to the GPU. Instead, it moves the model to the CPU and then to the GPU. This is done to save memory. This is done in the FSDP wrapper itself.")
if args.use_fsdp:
parent_model = FSDP(parent_model, mixed_precision=mixed_precision_policy, device_id=torch.cuda.current_device(), auto_wrap_policy=mbart_auto_wrap_policy, sharding_strategy=sharding_strategy, backward_prefetch=backward_prefetch_policy, process_group=process_group, cpu_offload=cpu_offload, forward_prefetch=args.forward_prefetch) # ,
else:
parent_model = DistributedDataParallel(parent_model, device_ids=[gpu], output_device=gpu)
print("Loading a parent model from which distillation will be done.")
dist.barrier()
# configure map_location properly
if not args.use_official_parent_pretrained:
if args.use_fsdp:
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers.
reader = FileSystemReader(args.parent_pretrained_model + "_sharded")
with FSDP.state_dict_type(parent_model, StateDictType.LOCAL_STATE_DICT):
state_dict = parent_model.state_dict()
load_state_dict(state_dict, reader)
parent_model.load_state_dict(state_dict)
del state_dict
else:
reader = FileSystemReader(args.parent_pretrained_model + "_sharded")
with FSDP.state_dict_type(parent_model, StateDictType.LOCAL_STATE_DICT):
state_dict = parent_model.state_dict()
load_state_dict(state_dict, reader)
parent_model.load_state_dict(state_dict)
del state_dict
else:
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
parent_checkpoint_dict = torch.load(args.parent_pretrained_model, map_location=map_location)
if type(parent_checkpoint_dict) == dict:
parent_model.load_state_dict(parent_checkpoint_dict['model']) # We never do any remapping of the parent. We always reuse it as it is.
else:
parent_model.module.load_state_dict(parent_checkpoint_dict) # We never do any remapping of the parent. We always reuse it as it is.
del parent_checkpoint_dict
dist.barrier()
parent_model.train()
torch.cuda.empty_cache()
freeze_params(model, args.freeze_exception_list, rank)
### NOTE: Please freeze params before wrapping the model in DDP. Mandem almost had a stroke trying to figure this out.
if not args.use_fsdp:
model.cuda(gpu) ## Move the model to the GPU.
print("Memory consumed after moving model to GPU", round(torch.cuda.memory_allocated(gpu)/(1024**3), 2), "GB on rank", rank)
else:
if rank == 0:
print("When FSDP is used, the model is not moved to the GPU. This is because FSDP does not support moving the model to the GPU. Instead, it moves the model to the CPU and then to the GPU. This is done to save memory. This is done in the FSDP wrapper itself.")
if rank == 0:
print("Optimizing", [n for n, p in model.named_parameters() if p.requires_grad])
if args.gradient_checkpointing:
print("Using gradient checkpointing")
num_params_to_optimize = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_model_params = sum(p.numel() for p in model.parameters())
print("Number of model parameters:", num_model_params)
print("Total number of params to be optimized are: ", num_params_to_optimize)
print("Percentage of parameters to be optimized: ", 100*num_params_to_optimize/num_model_params)
if args.use_fsdp:
model = FSDP(model, mixed_precision=mixed_precision_policy, device_id=torch.cuda.current_device(), auto_wrap_policy=mbart_auto_wrap_policy, sharding_strategy=sharding_strategy, backward_prefetch=backward_prefetch_policy, process_group=process_group, cpu_offload=cpu_offload, forward_prefetch=args.forward_prefetch) ## This wrapper around the model will enable sharded distributed training. , forward_prefetch=args.forward_prefetch
if args.sharding_strategy == "HYBRID_SHARD":
print("Process groups are", dist.get_process_group_ranks(model.process_group), dist.get_process_group_ranks(model._inter_node_pg))
# assert dist.get_process_group_ranks(intranode_process_group) == dist.get_process_group_ranks(model.process_group)
# assert dist.get_process_group_ranks(internode_process_group) == dist.get_process_group_ranks(model._inter_node_pg)
else:
model = DistributedDataParallel(model, device_ids=[gpu], output_device=gpu) ## This wrapper around the model will enable distributed training.
if rank == 0:
print("Memory consumed after wrapping with DDP/FSDP", round(torch.cuda.memory_allocated(gpu)/(1024**3), 2), "GB")
if args.activation_checkpointing:
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
"weight_decay": args.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
"weight_decay": 0.0,
},
] ## We suppose that weight decay will be used except for biases and layer norm weights.
if args.prompt_tuning:
print("Although the percentage of parameters to be optimized is high, during training the number of actual params during decoding are way way lower.")
if args.adam_8bit:
print("Using an 8-bit AdamW optimizer.")
optimizer = bnb.optim.AdamW8bit(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_eps, betas=(0.9, 0.995)) # Our glorious 8 bit optimizer. All hail our lord and savior Tim Dettmers.
elif args.adam_anyprecision:
print("Using an anyprecision AdamW optimizer.")
optimizer = AnyPrecisionAdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_eps, betas=(0.9, 0.995), use_kahan_summation=True, momentum_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16), variance_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16), compensation_buffer_dtype=(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16)) # Our glorious anyprecision optimizer. All hail our lord and savior Tim Dettmers.
else:
print("Using an 32-bit AdamW optimizer.")
optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_eps) ## Our glorious optimizer.
model.train()
if args.lr_scheduler == "linear":
scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, args.num_batches) ## A linear warmup and decay scheduler.
elif args.lr_scheduler == "cosine":
scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, args.num_batches, num_cycles=args.cosine_scheduler_num_cycles) ## A cosine warmup and decay scheduler.
elif args.lr_scheduler == "cosine_with_restarts":
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, args.warmup_steps, args.num_batches, num_cycles=args.cosine_scheduler_num_cycles)
else:
raise ValueError("Invalid LR scheduler")
while scheduler.get_lr()[0] < 1e-7: ## We want to keep a minimum learning rate, else barely anything will be learned for the initial batch or few, which is a waste of computation. This minimum value is kept at 1e-7 by default, in accordance with previous literature, other implementations and the Paris peace accords.
scheduler.step()
if rank == 0:
print("Initial LR is:", scheduler.get_lr()[0], ", max LR is:", args.lr, ", warmup steps are:", args.warmup_steps, ", total number of batches/steps are:", args.num_batches)
sys.stdout.flush()
if args.pretrained_model != "" and (not args.use_official_pretrained or args.locally_fine_tuned_model_path is not None): ## Here we load a pretrained NMT model or a previous checkpoint in case training crashed. Note the args.locally_fine_tuned_model_path. This is in case we were tuning an official mbart or indicbart or bart model but want to further tine tune it or it crashed and we want to resume training it. FIXME FSDP loading needs to be handled here.
print("Loading from checkpoint. Strict loading by default but if there are missing or non matching keys or if we use prompt or adaptor tuning, they will be ignored when layer remapping or component selection is done. In case of prompt and adaptor tuning, new params are added to the model and hence strict matching of keys is not possible.")
dist.barrier()
# configure map_location properly
if args.locally_fine_tuned_model_path is not None: ## Now that the pretrained_model argument was used to instantiate the model, it can be replaced with the local model path. Remember to specify pure model or the model with the optimizer and scheduler states depending on your requirement by relying on the flag --no_reload_optimizer_ctr_and_scheduler.
args.pretrained_model = args.locally_fine_tuned_model_path
if args.use_fsdp: # With FSDP models I would rather not risk pruning or layer remapping. So I am not going to do it. I am going to load the model as it is. Consider doing this externally before loading the model.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers.
reader = FileSystemReader(args.pretrained_model + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)
del state_dict
else:
reader = FileSystemReader(args.pretrained_model + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict) # Check if strict loading is required here. We ideally dont want it to be so if we add prompts or adaptors.
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings: # This might fail for FSDP so please check. TODO.
model.module.initialize_prompt_params_with_random_embeddings()
if not args.no_reload_optimizer_ctr_and_scheduler:
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
reader = FileSystemReader(args.pretrained_model+ "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
optim_dict = load_sharded_optimizer_state_dict(
model_state_dict=model.state_dict(),
optimizer_key="optim",
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(model, optimizer, optim_dict["optim"])
optimizer.load_state_dict(flattened_osd)
del flattened_osd
del optim_dict
else:
full_optimizer = None
if rank == 0:
full_optimizer = torch.load(args.pretrained_model+ "_optim") ## We now load only the optimizer and scheduler.
sharded_optimizer = FSDP.scatter_full_optim_state_dict(full_optimizer, model)
optimizer.load_state_dict(sharded_optimizer)
scheduler_and_ctr = torch.load(args.pretrained_model + "_scheduler_and_ctr")
scheduler.load_state_dict(scheduler_and_ctr['scheduler'])
ctr = scheduler_and_ctr['ctr']
del scheduler_and_ctr
del sharded_optimizer
del full_optimizer
else:
ctr = 0
del state_dict
else:
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
checkpoint_dict = torch.load(args.pretrained_model, map_location=map_location)
if type(checkpoint_dict) == dict:
model.load_state_dict(remap_embeddings_eliminate_components_and_eliminate_mismatches(model.state_dict(), remap_layers(checkpoint_dict['model'], 4, args, rank), args), strict=True if (args.remap_encoder == "" and args.remap_decoder == "" and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization and not args.prompt_tuning and not args.adaptor_tuning and not args.deep_adaptor_tuning and not args.deep_adaptor_tuning_ffn_only and not args.ia3_adaptors and not args.lora_adaptors and not args.softmax_bias_tuning and not args.sparsify_attention) else False)
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings:
model.module.initialize_prompt_params_with_random_embeddings()
if not args.no_reload_optimizer_ctr_and_scheduler and args.remap_encoder == '' and args.remap_decoder == '' and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization: ## Do not load optimizers, ctr and schedulers when remapping or resuming training.
if 'optimizer' in checkpoint_dict:
print("Reloading optimizer")
optimizer.load_state_dict(checkpoint_dict['optimizer']) ## Dubious
if 'scheduler' in checkpoint_dict:
print("Reloading scheduler")
scheduler.load_state_dict(checkpoint_dict['scheduler']) ## Dubious
if 'ctr' in checkpoint_dict:
print("Reloading ctr. This means we resume training.")
ctr = checkpoint_dict['ctr']
else:
ctr = 0
else:
model.module.load_state_dict(remap_embeddings_eliminate_components_and_eliminate_mismatches(model.state_dict(), remap_layers(checkpoint_dict, 3, args, rank), args), strict=True if (args.remap_encoder == "" and args.remap_decoder == "" and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization and not args.prompt_tuning and not args.adaptor_tuning and not args.deep_adaptor_tuning and not args.deep_adaptor_tuning_ffn_only and not args.ia3_adaptors and not args.lora_adaptors and not args.softmax_bias_tuning and not args.sparsify_attention) else False)
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings:
model.module.initialize_prompt_params_with_random_embeddings()
ctr = 0
del checkpoint_dict
else:
if args.use_official_pretrained:
print("Training from official pretrained model")
if args.prompt_tuning and args.initialize_prompts_with_random_embeddings:
model.module.initialize_prompt_params_with_random_embeddings() ## This might fail for FSDP so we need to check. TODO.
else:
print("Training from scratch")
CHECKPOINT_PATH = args.model_path
ctr=0
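# An initial checkpoint is saved at step 0 (and reloaded just below), so every rank starts from the same on-disk state and the deploy folder always contains a usable model.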
if args.use_fsdp: # For FSDP we will save the model params, optimizer, scheduler and ctr in separate files. This is because FSDP saving everything in a single file is too heavy.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
model_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
optim_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_optim_sharded")
print("Preparing to go into sharded saving for rank", rank)
sys.stdout.flush()
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
print("Gathering on rank", rank)
sys.stdout.flush()
state_dict = model.state_dict()
print("Saving sharded model")
sys.stdout.flush()
save_state_dict(state_dict, model_shard_writer)
sys.stdout.flush()
print("Gathering optimizer on rank", rank)
optim_dict = {"optim": FSDP.optim_state_dict(model, optimizer)}
print("Saving sharded optimizer")
save_state_dict(optim_dict, optim_shard_writer)
print("Saved sharded model and optimizer")
sys.stdout.flush()
del state_dict
del optim_dict
## Also save the full state dict for the model and optimizer.
print("Preparing to go into full saving for rank", rank)
sys.stdout.flush()
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy, optimizer_fullstate_save_policy):
print("Gathering")
sys.stdout.flush()
state_dict = model.state_dict()
optim_dict = FSDP.full_optim_state_dict(model, optimizer)
print("Gathered")
sys.stdout.flush()
if rank == 0:
print("Saving model on rank 0")
sys.stdout.flush()
torch.save(state_dict, CHECKPOINT_PATH)
print("Saving optimizer on rank 0")
sys.stdout.flush()
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
print("Saving scheduler and ctr on rank 0")
sys.stdout.flush()
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': 0}
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
print("Saved model, optimizer, scheduler and ctr")
sys.stdout.flush()
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
save_state_dict(state_dict, shard_writer)
del state_dict
state_dict = None
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy): ## This full state dict is what is messing things up. The model should be saved as local state dicts and then assembled as a full state dict in the end if needed. A presharding and unsharding script may be useful. We have used an offload to CPU policy and this means we hopefully wont run out of memory.
state_dict = model.state_dict()
optim_dict = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
torch.save(state_dict, CHECKPOINT_PATH)
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': 0}
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
if rank == 0:
checkpoint_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'ctr': 0}
torch.save(checkpoint_dict, CHECKPOINT_PATH) ## Save a model by default every eval_every steps. This model will be saved with the same file name each time.
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".pure_model")
os.system("cp "+CHECKPOINT_PATH+".pure_model "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del checkpoint_dict
sys.stdout.flush()
dist.barrier()
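# The checkpoint saved above is now reloaded on all ranks before entering the training loop.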
if args.use_fsdp: ## This is consuming CPU ram. Need an optimization here. We need to make a decision whether we are going to go for a full state dict or a local state dict. If we are going to go for a full state dict, we need to make sure that we are not going to run out of memory.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP allows sharded optimizers so we will load a sharded optimizer.
reader = FileSystemReader(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)
del state_dict
reader = FileSystemReader(CHECKPOINT_PATH + "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
optim_dict = load_sharded_optimizer_state_dict(
model_state_dict=model.state_dict(),
optimizer_key="optim",
storage_reader=reader,
)
flattened_osd = FSDP.optim_state_dict_to_load(model, optimizer, optim_dict["optim"])
optimizer.load_state_dict(flattened_osd)
del flattened_osd
del optim_dict
scheduler_and_ctr = torch.load(CHECKPOINT_PATH + "_scheduler_and_ctr")
scheduler.load_state_dict(scheduler_and_ctr['scheduler'])
ctr = scheduler_and_ctr['ctr']
del scheduler_and_ctr
else:
reader = FileSystemReader(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)
full_optimizer = None
if rank == 0:
full_optimizer = torch.load(CHECKPOINT_PATH+ "_optim") ## We now load only the optimizer and scheduler.
sharded_optimizer = FSDP.scatter_full_optim_state_dict(full_optimizer, model)
optimizer.load_state_dict(sharded_optimizer)
scheduler_and_ctr = torch.load(CHECKPOINT_PATH + "_scheduler_and_ctr")
scheduler.load_state_dict(scheduler_and_ctr['scheduler'])
ctr = scheduler_and_ctr['ctr']
del scheduler_and_ctr
del sharded_optimizer
del full_optimizer
del state_dict
else:
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
checkpoint_dict = torch.load(CHECKPOINT_PATH, map_location=map_location)
model.load_state_dict(checkpoint_dict['model'])
optimizer.load_state_dict(checkpoint_dict['optimizer'])
scheduler.load_state_dict(checkpoint_dict['scheduler'])
ctr = checkpoint_dict['ctr']
del checkpoint_dict
torch.cuda.empty_cache()
dist.barrier()
model.train()
print("Using label smoothing of", args.label_smoothing)
print("Using gradient clipping norm of", args.max_gradient_clip_value)
print("Using softmax temperature of", args.softmax_temperature)
if args.max_ent_weight != -1:
print("Doing entropy maximization during loss computation.")
if args.multistep_optimizer_steps > 1:
print("Using a multistep optimizer where gradients will be accumulated over", args.multistep_optimizer_steps, "batches.")
if args.ewc_importance != 0: ## Set up elastic weight consolidation
print("Using Elastic Weight Consolidation with importance", args.ewc_importance)
print("Number of training batches to compute Fisher coefficients:", args.ewc_samples)
num_batches_tmp = args.num_batches
args.num_batches = args.ewc_samples
print("Learning Fisher coefficients.")
files = ewc_files
datagenerator = generate_batches_bilingual(tok, args, files, rank, tgt_tok=tgt_tok)
ewc_loss = EWC(model, datagenerator, gpu, args.label_smoothing, ignore_index=tok.pad_token_id)
args.num_batches = num_batches_tmp
print("Fisher coefficients learned.")
num_batches_this_optimizer_step = 0
losses = 0
global_sbleu_history = [] ## To save the global evaluation metric history.
max_global_sbleu = 0 ## Maximum global evaluation metric score.
max_global_sbleu_step = 0 ## Step at which we achieved the maximum global evaluation metric score.
individual_sbleu_history = [[dev_pair, []] for dev_pair in dev_files] ## For multilingual NMT settings we suppose that we will keep a track of the histories for individual language pairs being evaluated and this dictionary keeps track of the history.
max_individual_sbleu = [[dev_pair, 0] for dev_pair in dev_files] ## The maximum score per pair.
max_individual_sbleu_step = [[dev_pair, 0] for dev_pair in dev_files] ## The step at which maximum score was achieved per pair.
curr_eval_step = 0
annealing_attempt = 0 ## We use this to limit the number of times annealing will take place. When we anneal the LR is divided by a factor. How this is achieved will be explained below.
inps = [[dev_pair, [inpline.strip() for inpline in open(dev_pair_info[0])][:args.max_eval_batches*args.dev_batch_size]] for dev_pair, dev_pair_info in dev_files] ## Get all inputs for each pair. Select up to args.max_eval_batches*args.dev_batch_size examples.
if args.is_summarization: ## Slight data structure difference for summarization vs translation when computing the evaluation metric. For summarization the metric is Rouge.
refs = [[dev_pair, [[refline.strip() for refline in open(dev_pair_info[1])][:args.max_eval_batches*args.dev_batch_size]]] for dev_pair, dev_pair_info in dev_files] ## Get all references for each input. Select up to args.max_eval_batches*args.dev_batch_size examples.
else:
refs = [[dev_pair, [[refline.strip() for refline in open(dev_pair_info[1])][:args.max_eval_batches*args.dev_batch_size]]] for dev_pair, dev_pair_info in dev_files] ## Get all references for each input. Select up to args.max_eval_batches*args.dev_batch_size examples.
start = time.time()
# We need a tensor to keep track of batch stats. This tensor should be reduced across all processes.
batch_stats = torch.zeros(7, dtype=torch.long, device=gpu)
avg_memory_stats = torch.zeros(2, dtype=torch.float, device=gpu)
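# Main training loop. Every eval_every optimizer steps the model is checkpointed and, unless --no_eval is set, decoded on the dev sets with BLEU/ROUGE tracked for early stopping and LR annealing. Gradients are accumulated over multistep_optimizer_steps batches before each optimizer step.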
for input_ids, input_masks, decoder_input_ids, labels in generate_batches_bilingual(tok, args, train_files, rank, tgt_tok=tgt_tok): #Batches are generated from here. The argument (0.30, 0.40) is a range which indicates the percentage of the source sentence to be masked in case we want masking during training just like we did during BART pretraining. The argument 3.5 is the lambda to the poisson length sampler which indicates the average length of a word sequence that will be masked.
if ctr % args.eval_every == 0 and num_batches_this_optimizer_step == 0: ## We have to evaluate our model every eval_every steps. We don't do evaluation for FSDP.
CHECKPOINT_PATH = args.model_path
# if args.use_fsdp:
# with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy):
# state_dict = model.state_dict()
# optim_state = FSDP.full_optim_state_dict(model, optimizer)
# checkpoint_dict = {'model': state_dict, 'optimizer': optim_state, 'scheduler': scheduler.state_dict(), 'ctr': ctr}
# ## When using FSDP its important to run a dummy batch before running generate.
# dummy_batch = torch.tensor([[tok.bos_token_id, tok.eos_token_id]]).to(gpu)
# dummy_mask = torch.tensor([[1, 1]]).to(gpu)
# dummy_label_mask = torch.tensor([[1, 1]]).to(gpu)
# mod_compute = model(input_ids=dummy_batch, attention_mask=dummy_mask ,decoder_input_ids=dummy_batch, output_hidden_states=args.distillation, output_attentions=args.distillation, additional_input_ids=dummy_batch if args.multi_source else None, additional_input_ids_mask=dummy_mask if args.multi_source else None, label_mask=dummy_mask if args.num_domains_for_domain_classifier > 1 else None) ## Run the model and get logits.
# del dummy_batch, dummy_mask, dummy_label_mask , mod_compute
# else:
if args.use_fsdp:
assert args.no_eval
if not args.no_eval: ## Evaluation is done only on the prime/master process, which is at rank 0. The other processes simply wait.
if rank == 0: ## If you don't care about early stopping and only want to train for a very large number of batches, you can save time by skipping evaluation.
checkpoint_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'ctr': ctr} ## This training state will be saved.
print("Running eval on dev set(s)")
if args.mixed_wait_k:
model.module.config.wait_k = args.wait_k
hyp = [[dev_pair, []] for dev_pair, dev_pair_info in dev_files]
sbleus = []
model.eval() ## We go to eval mode so that there will be no dropout.
for dev_idx, [dev_pair, dev_pair_info] in enumerate(dev_files): ## For each evaluation pair we will decode and compute scores.
slangtlang = dev_pair.strip().split("-")
if args.multi_source: ## In case we do multisource NMT
slang=slangtlang[0]+"-"+slangtlang[1] ## This will be split in the generate_batches_eval function as we expect a triplet.
tlang=slangtlang[2]
else:
slang=slangtlang[0]
tlang=slangtlang[1]
eval_batch_counter = 0
for dev_input_ids, dev_input_masks in generate_batches_eval_bilingual(tok, args, inps[dev_idx][1], slang):
if args.multi_source:
dev_input_ids_parent = dev_input_ids[1]
dev_input_ids = dev_input_ids[0]
dev_input_masks_parent = dev_input_masks[1]
dev_input_masks = dev_input_masks[0]
dev_input_ids_parent = dev_input_ids_parent.to(gpu) ## Move to GPU.
dev_input_masks_parent = dev_input_masks_parent.to(gpu) ## Move to GPU.
if args.prompt_tuning:
dev_input_shape = dev_input_masks.size()
encoder_pad = torch.ones(dev_input_shape[0], args.num_prompts).clone().detach()
dev_input_masks = torch.cat([encoder_pad, dev_input_masks], dim=1)
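## The attention mask is widened by num_prompts columns so that the soft-prompt vectors (presumably prepended to the encoder input inside the model when prompt tuning is enabled) are attended to; the token ids themselves are not modified here.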
dev_input_ids = dev_input_ids.to(gpu) ## Move to GPU.
dev_input_masks = dev_input_masks.to(gpu) ## Move to GPU.
if args.is_summarization and rank==0: ## Summarization decoding can be slow, so it is best to show progress.
print("Decoding batch from a pool of", len(inps[dev_idx][1]), "examples")
with torch.no_grad(): ## torch.no_grad() disables gradient tracking, which avoids allocating memory for gradients and speeds decoding up. It also serves as a safety measure against the model being inadvertently tuned on the development set.
translations = model.module.generate(dev_input_ids, use_cache=True, num_beams=1, max_length=int((len(dev_input_ids[0])*args.max_decode_length_multiplier) if args.max_decode_length_multiplier > 0 else -args.max_decode_length_multiplier), min_length=int((len(dev_input_ids[0])*args.min_decode_length_multiplier) if args.min_decode_length_multiplier > 0 else -args.min_decode_length_multiplier), early_stopping=True, attention_mask=dev_input_masks, pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], decoder_start_token_id=tok([tlang if args.use_official_pretrained else "<2"+tlang+">"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], length_penalty=args.length_penalty, repetition_penalty=args.repetition_penalty, encoder_no_repeat_ngram_size=args.encoder_no_repeat_ngram_size, no_repeat_ngram_size=args.no_repeat_ngram_size, additional_input_ids=dev_input_ids_parent if args.multi_source else None, additional_input_ids_mask=dev_input_masks_parent if args.multi_source else None) ## We translate the batch.
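## Note on the length bounds passed above: a positive max/min_decode_length_multiplier is interpreted as a multiplier of the source length, while a negative value is negated and used as an absolute decode length.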
del dev_input_ids ## Delete to avoid retention.
del dev_input_masks ## Delete to avoid retention.
translations = translations.to('cpu') ## Move to CPU so the GPU copy can be freed.
if args.multi_source:
del dev_input_ids_parent ## Delete to avoid retention.
del dev_input_masks_parent ## Delete to avoid retention.
for translation in translations:
translation = decoding_tok.decode(translation, skip_special_tokens=args.no_skip_special_tokens, clean_up_tokenization_spaces=False) ### Get the raw sentences.
hyp[dev_idx][1].append(translation)
del translations ## Delete to avoid retention.
if args.use_rouge: ## Get the evaluation metric score.
scores = 0
for curr_ref, curr_pred in zip(refs[dev_idx][1][0], hyp[dev_idx][1]):
score = scorer.score(curr_ref, curr_pred)
scores += score['rougeL'].fmeasure
sbleu = scores/len(hyp[dev_idx][1])
metric = 'Rouge'
scorertool = 'RougeScorer'
else:
sbleu = get_sacrebleu(refs[dev_idx][1], hyp[dev_idx][1])
metric = 'BLEU'
scorertool = 'SacreBLEU'
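## A minimal sketch of what get_sacrebleu is assumed to do (the real helper is defined elsewhere in this repo):
##     from sacrebleu.metrics import BLEU
##     def get_sacrebleu(refs, hyps):
##         return BLEU().corpus_score(hyps, refs).score
## where refs is the list of reference sets built earlier and hyps is the flat list of detokenized hypotheses.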
individual_sbleu_history[dev_idx][1].append([sbleu, ctr]) ## Update the score history for this pair.
sbleus.append(sbleu)
print(metric, "score using", scorertool, "after", ctr, "iterations is", round(sbleu, 2), "for language pair", dev_pair)
writer.add_scalar(dev_pair+" bleu/rouge", sbleu, ctr)
if args.wb:
wandb.log({f"{dev_pair} bleu/rouge": sbleu}, step=ctr) ## wandb.log expects a dict mapping metric names to values.
if sbleu > max_individual_sbleu[dev_idx][1]: ## Update the best score and step for this pair and save a model copy. Early stopping is driven by the global score (the average over all pairs), but these per-pair checkpoints are useful if you want the model that performs best on a single pair.
max_individual_sbleu[dev_idx][1] = sbleu
max_individual_sbleu_step[dev_idx][1] = curr_eval_step
print("New peak reached for", dev_pair,". Saving.")
if args.save_intermediate_checkpoints:
torch.save(checkpoint_dict, CHECKPOINT_PATH+".best_dev_bleu."+dev_pair+"."+str(ctr))
# if not args.use_fsdp:
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".best_dev_bleu."+dev_pair+"."+str(ctr)+".pure_model") ## Pure model without any ddp markers or optimizer info.
torch.save(checkpoint_dict, CHECKPOINT_PATH+".best_dev_bleu."+dev_pair)
# if not args.use_fsdp:
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".best_dev_bleu."+dev_pair+".pure_model")
## Global stats
sbleu = sum(sbleus)/len(sbleus) ## The global score.
global_sbleu_history.append([sbleu, ctr]) ## Update the global score history.
print("Global", metric, "score using", scorertool, "after", ctr, "iterations is:", round(sbleu, 2))
writer.add_scalar("global bleu/rouge", sbleu, ctr)
if args.wb:
wandb.log("global bleu/rouge", sbleu, step=ctr)
if sbleu > max_global_sbleu: ## Update the best global score and step, and save a model copy if it improved. Note that this checkpoint may NOT be the best one for every individual pair.
max_global_sbleu = sbleu
max_global_sbleu_step = curr_eval_step
print("New peak reached. Saving.")
if args.save_intermediate_checkpoints:
torch.save(checkpoint_dict, CHECKPOINT_PATH+".best_dev_bleu.global."+str(ctr))
# if not args.use_fsdp:
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".best_dev_bleu.global."+str(ctr)+".pure_model") ## Pure model without any ddp markers or optimizer info.
torch.save(checkpoint_dict, CHECKPOINT_PATH+".best_dev_bleu.global")
# if not args.use_fsdp:
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".best_dev_bleu.global.pure_model") ## Pure model without any ddp markers or optimizer info.
## Copy the global best pure model to the deploy folder.
os.system("cp "+CHECKPOINT_PATH+".best_dev_bleu.global.pure_model "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
if curr_eval_step - max_global_sbleu_step > (args.early_stop_checkpoints + annealing_attempt*args.additional_early_stop_checkpoints_per_anneal_step): ## If the global score has not improved for early_stop_checkpoints evaluations, plus the additional checkpoints granted for each annealing step taken so far, then we stop training.
if annealing_attempt < args.max_annealing_attempts: ## We will only downscale the LR a fixed number of times. Each time we downscale the number of checkpoints to wait for declaring convergence will increase by a fixed value.
annealing_attempt += 1
with open(args.model_path + ".anneal", "w") as f:
f.write(str(annealing_attempt)) # Other processes will see an increase in this value and will downscale the LR.
curr_lr = scheduler.get_lr()[0]
print("LR before annealing is:", curr_lr)
while scheduler.get_lr()[0] > (curr_lr/args.learning_rate_scaling): ## We currently downscale the LR by advancing the scheduler by several steps. This is risky: if the scheduler reaches its maximum number of steps the LR becomes 0, yet the training loop keeps running and nothing is effectively updated. The workaround used here is to set the scheduler's maximum number of steps to a very large value. So far this has not caused problems, but users who do not trust this part of the code should avoid annealing.
scheduler.step()
print("LR after annealing is:", scheduler.get_lr()[0])
else: ## Convergence has been reached and we stop and report the final metrics.
print("We have seemingly converged as", metric, "failed to increase for the following number of checkpoints:", args.early_stop_checkpoints+annealing_attempt*args.additional_early_stop_checkpoints_per_anneal_step, ". You may want to consider increasing the number of tolerance steps, doing additional annealing or having a lower peak learning rate or something else.")
print("Terminating training")
print("Global dev", metric, "history:", [[round(x,2), y] for x,y in global_sbleu_history])
print("Individual", metric, "history:", [[lang_pair, [[round(x,2), y] for x,y in individual_sbleu_info_for_language]] for lang_pair, individual_sbleu_info_for_language in individual_sbleu_history])
with open(args.model_path + ".quitflag", "w") as f:
f.write("1")
curr_eval_step += 1
del checkpoint_dict
model.train() ## Put the model back in training mode where dropout will be done.
else: ## Even when evaluation is skipped, it is prudent to save the model periodically; the saving code below runs every eval_every steps (1000 by default) regardless. Change eval_every to whatever interval you want.
pass
if args.use_fsdp: # For FSDP we save the model parameters, optimizer, scheduler and ctr in separate files, because saving everything in a single file is too heavy under FSDP.
if torch_version.startswith("2."): ## From 2.0 onwards, FSDP supports sharded optimizer state, so we save a sharded optimizer as well.
model_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
optim_shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_optim_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
save_state_dict(state_dict, model_shard_writer)
optim_dict = {"optim": FSDP.optim_state_dict(model, optimizer)}
save_state_dict(optim_dict, optim_shard_writer)
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
model_shard_writer = FileSystemWriter(CHECKPOINT_PATH +"."+str(ctr)+ "_sharded")
optim_shard_writer = FileSystemWriter(CHECKPOINT_PATH +"."+str(ctr)+ "_optim_sharded")
save_state_dict(state_dict, model_shard_writer)
save_state_dict(optim_dict, optim_shard_writer)
del state_dict
del optim_dict
## Also save the full state dict for the model and optimizer.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy, optimizer_fullstate_save_policy):
state_dict = model.state_dict()
optim_dict = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
torch.save(state_dict, CHECKPOINT_PATH)
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': ctr}
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
if ctr % args.save_intermediate_checkpoints_every == 0 and args.save_intermediate_checkpoints:
torch.save(state_dict, CHECKPOINT_PATH +"."+str(ctr))
torch.save(optim_dict, CHECKPOINT_PATH +"."+str(ctr)+ "_optim")
torch.save(scheduler_and_ctr, CHECKPOINT_PATH +"."+str(ctr)+ "_scheduler_and_ctr")
del scheduler_and_ctr
del state_dict
del optim_dict
else:
shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
state_dict = model.state_dict()
if args.save_intermediate_checkpoints and ctr % args.save_intermediate_checkpoints_every == 0:
shard_writer = FileSystemWriter(CHECKPOINT_PATH +"."+str(ctr)+ "_sharded")
save_state_dict(state_dict, shard_writer)
shard_writer = FileSystemWriter(CHECKPOINT_PATH + "_sharded")
save_state_dict(state_dict, shard_writer)
del state_dict
state_dict = None
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, fullstate_save_policy): ## This full state dict is what causes trouble: the model should ideally be saved as local state dicts and assembled into a full state dict at the end if needed (a presharding/unsharding script may be useful). We use a CPU-offload policy here, so we hopefully won't run out of memory.
state_dict = model.state_dict()
optim_dict = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
print("Saving the model")
if args.save_intermediate_checkpoints and ctr % args.save_intermediate_checkpoints_every == 0:
print("Saving an intermediate checkpoint")
torch.save(state_dict, CHECKPOINT_PATH+"."+str(ctr))
sys.stdout.flush()
torch.save(state_dict, CHECKPOINT_PATH)
scheduler_and_ctr = {'scheduler': scheduler.state_dict(), 'ctr': ctr}
if args.save_intermediate_checkpoints and ctr % args.save_intermediate_checkpoints_every == 0:
torch.save(optim_dict, CHECKPOINT_PATH + "."+str(ctr)+"_optim")
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "."+str(ctr)+"_scheduler_and_ctr")
torch.save(optim_dict, CHECKPOINT_PATH + "_optim")
torch.save(scheduler_and_ctr, CHECKPOINT_PATH + "_scheduler_and_ctr")
os.system("cp "+CHECKPOINT_PATH+" "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del scheduler_and_ctr
del state_dict
del optim_dict
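## For reference, a minimal sketch of how the sharded model checkpoint saved above could be reloaded on each
## rank (the FileSystemReader/load_state_dict names mirror the writer-side API used here and are an assumption,
## not tested code):
##     with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
##         state_dict = model.state_dict()
##         load_state_dict(state_dict=state_dict, storage_reader=FileSystemReader(CHECKPOINT_PATH + "_sharded"))
##         model.load_state_dict(state_dict)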
else:
if rank == 0:
print("Saving the model")
checkpoint_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'ctr': ctr}
if args.save_intermediate_checkpoints and ctr % args.save_intermediate_checkpoints_every == 0:
print("Saving an intermediate checkpoint")
torch.save(checkpoint_dict, CHECKPOINT_PATH+"."+str(ctr))
sys.stdout.flush()
torch.save(checkpoint_dict, CHECKPOINT_PATH) ## Save a model by default every eval_every steps. This model will be saved with the same file name each time.
torch.save(model.module.state_dict(), CHECKPOINT_PATH+".pure_model")
os.system("cp "+CHECKPOINT_PATH+".pure_model "+CHECKPOINT_PATH+"_deploy/pytorch_model.bin")
del checkpoint_dict
# Use a barrier() to make sure the other processes proceed only
# after rank 0 has finished saving.
dist.barrier()
with open(args.model_path + ".quitflag", "r") as f:
if f.read().strip() == "1":
print("All processess to die!")
break
torch.cuda.empty_cache()
# start = time.time() ## All eval and ckpt saving is done here so start counting from here.
dist.barrier()
if not args.use_fsdp and rank != 0: # handle annealing for all other processes in non-fsdp settings.
with open(args.model_path + ".anneal", "r") as f:
master_annealing_attempt = int(f.read().strip())
if master_annealing_attempt > annealing_attempt: