Skip to content

Commit

Permalink
repalce set_random_seed with transformers's set_seed
Browse files Browse the repository at this point in the history
  • Loading branch information
statelesshz committed Nov 7, 2023
1 parent e228d96 commit 2216942
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 12 deletions.
4 changes: 2 additions & 2 deletions finetune/sft/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
from constant import SFT
from transformers import set_seed
from utils.data.data_utils import create_prompt_dataset
from utils.ds_utils import get_train_ds_config
from utils.model.model_utils import create_hf_model
Expand All @@ -36,7 +37,6 @@
load_hf_tokenizer,
print_rank_0,
save_hf_format,
set_random_seed,
to_device,
)

Expand Down Expand Up @@ -244,7 +244,7 @@ def main():
)

# If passed along, set the training seed now.
set_random_seed(args.seed)
set_seed(args.seed)
torch.distributed.barrier()

# load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family
Expand Down
11 changes: 1 addition & 10 deletions finetune/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn as nn
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import AutoTokenizer, set_seed
from transformers import AutoTokenizer


def print_rank_0(msg, rank=0):
Expand Down Expand Up @@ -99,15 +99,6 @@ def save_hf_format(model, tokenizer, args, sub_folder=""):
copy(os.path.join(source, "tokenization_yi.py"), target)


def set_random_seed(seed):
if seed is not None:
set_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def get_all_reduce_mean(tensor):
torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
tensor = tensor / torch.distributed.get_world_size()
Expand Down

0 comments on commit 2216942

Please sign in to comment.