diff --git a/config.py b/config.py
index 058e5c7..ac80460 100644
--- a/config.py
+++ b/config.py
@@ -12,9 +12,9 @@
 LOAD_FROM_CHECKPOINT = True  # Whether to load weights from a checkpoint
 # Whether to share weights between the encoder and decoder
 SHARE_WEIGHTS = False
-WEIGHT_URL = 'https://huggingface.co/MuGeminorum/hoyoGPT/resolve/main/weights.pth'
-WEIGHT_URL_ZH = 'https://www.modelscope.cn/api/v1/models/MuGeminorum/hoyoGPT/repo?Revision=master&FilePath=weights.pth'
+WEIGHT_URL = 'https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth'
+WEIGHT_URL_ZH = 'https://www.modelscope.cn/models/monetjoe/tunesformer-mirror/repo?Revision=master&FilePath=weights.pth'
 OUTPUT_PATH = './output'
 WEIGHT_PATH = f'{OUTPUT_PATH}/weights.pth'
-LOG_PATH = f'{OUTPUT_PATH}/logs.txt'
+LOG_PATH = f'{OUTPUT_PATH}/logs.jsonl'
 PROMPT_PATH = 'prompt.txt'
diff --git a/generate.py b/generate.py
index 2deb8bd..d5021e7 100644
--- a/generate.py
+++ b/generate.py
@@ -9,20 +9,43 @@
 
 
 def get_args(parser):
-    parser.add_argument('-num_tunes', type=int, default=3,
-                        help='the number of independently computed returned tunes')
-    parser.add_argument('-max_patch', type=int, default=128,
-                        help='integer to define the maximum length in tokens of each tune')
-    parser.add_argument('-top_p', type=float, default=0.8,
-                        help='float to define the tokens that are within the sample operation of text generation')
-    parser.add_argument('-top_k', type=int, default=8,
-                        help='integer to define the tokens that are within the sample operation of text generation')
-    parser.add_argument('-temperature', type=float, default=1.2,
-                        help='the temperature of the sampling operation')
-    parser.add_argument('-seed', type=int, default=None,
-                        help='seed for randomstate')
-    parser.add_argument('-show_control_code', type=bool,
-                        default=True, help='whether to show control code')
+    parser.add_argument(
+        "-num_tunes",
+        type=int,
+        default=1,
+        help="the number of independently computed returned tunes",
+    )
+    parser.add_argument(
+        "-max_patch",
+        type=int,
+        default=128,
+        help="integer to define the maximum length in tokens of each tune",
+    )
+    parser.add_argument(
+        "-top_p",
+        type=float,
+        default=0.8,
+        help="float to define the tokens that are within the sample operation of text generation",
+    )
+    parser.add_argument(
+        "-top_k",
+        type=int,
+        default=8,
+        help="integer to define the tokens that are within the sample operation of text generation",
+    )
+    parser.add_argument(
+        "-temperature",
+        type=float,
+        default=1.2,
+        help="the temperature of the sampling operation",
+    )
+    parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
+    parser.add_argument(
+        "-show_control_code",
+        type=bool,
+        default=False,
+        help="whether to show control code",
+    )
     args = parser.parse_args()
 
     return args
@@ -35,14 +58,14 @@ def generate_abc(args):
         num_hidden_layers=PATCH_NUM_LAYERS,
         max_length=PATCH_LENGTH,
         max_position_embeddings=PATCH_LENGTH,
-        vocab_size=1
+        vocab_size=1,
     )
 
     char_config = GPT2Config(
         num_hidden_layers=CHAR_NUM_LAYERS,
         max_length=PATCH_SIZE,
         max_position_embeddings=PATCH_SIZE,
-        vocab_size=128
+        vocab_size=128,
     )
 
     model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
@@ -56,11 +79,11 @@ def generate_abc(args):
         download()
 
     checkpoint = torch.load(filename)
-    model.load_state_dict(checkpoint['model'])
+    model.load_state_dict(checkpoint["model"])
    model = model.to(device)
     model.eval()
 
-    with open(PROMPT_PATH, 'r') as f:
+    with open(PROMPT_PATH, "r") as f:
         prompt = f.read()
 
     tunes = ""
@@ -72,19 +95,19 @@ def generate_abc(args):
     seed = args.seed
     show_control_code = args.show_control_code
 
-    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
+    print(" HYPERPARAMETERS ".center(60, "#"), "\n")
     args = vars(args)
 
     for key in args.keys():
-        print(f'{key}: {str(args[key])}')
+        print(f"{key}: {str(args[key])}")
 
-    print('\n', " OUTPUT TUNES ".center(60, "#"))
+    print("\n", " OUTPUT TUNES ".center(60, "#"))
     start_time = time.time()
 
     for i in range(num_tunes):
         tune = f"X:{str(i + 1)}\n{prompt}"
 
-        lines = re.split(r'(\n)', tune)
+        lines = re.split(r"(\n)", tune)
         tune = ""
         skip = False
         for line in lines:
@@ -99,8 +122,7 @@
             skip = True
 
     input_patches = torch.tensor(
-        [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
-        device=device
+        [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device
     )
 
     if tune == "":
@@ -108,10 +130,10 @@
 
     else:
         prefix = patchilizer.decode(input_patches[0])
-        remaining_tokens = prompt[len(prefix):]
+        remaining_tokens = prompt[len(prefix) :]
         tokens = torch.tensor(
-            [patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens],
-            device=device
+            [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
+            device=device,
         )
 
     while input_patches.shape[1] < max_patch:
@@ -121,7 +143,7 @@
             top_p=top_p,
             top_k=top_k,
             temperature=temperature,
-            seed=seed
+            seed=seed,
         )
 
         tokens = None
@@ -135,17 +157,15 @@
             if next_bar == "":
                 break
 
-            next_bar = remaining_tokens+next_bar
+            next_bar = remaining_tokens + next_bar
             remaining_tokens = ""
 
             predicted_patch = torch.tensor(
-                patchilizer.bar2patch(next_bar),
-                device=device
+                patchilizer.bar2patch(next_bar), device=device
             ).unsqueeze(0)
 
             input_patches = torch.cat(
-                [input_patches, predicted_patch.unsqueeze(0)],
-                dim=1
+                [input_patches, predicted_patch.unsqueeze(0)], dim=1
             )
 
         else:
@@ -157,7 +177,7 @@
     print("Generation time: {:.2f} seconds".format(time.time() - start_time))
     timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
     create_dir()
-    with open(f'{OUTPUT_PATH}/{timestamp}.abc', 'w') as f:
+    with open(f"{OUTPUT_PATH}/{timestamp}.abc", "w") as f:
         f.write(tunes)
 
 
diff --git a/train.py b/train.py
index 34f0e1c..f54c396 100644
--- a/train.py
+++ b/train.py
@@ -1,5 +1,6 @@
 import os
 import time
+import json
 from utils import *
 from torch.utils.data import DataLoader
 from modelscope.msdatasets import MsDataset
@@ -215,10 +216,16 @@ def eval_epoch(model):  # do one epoch for eval
 
     train_loss = train_epoch(model, is_autocast, scaler)
     eval_loss = eval_epoch(model)
-    with open(LOG_PATH, "a") as f:
-        f.write(
-            f"Epoch {str(epoch)}\ntrain_loss: {str(train_loss)}\neval_loss: {str(eval_loss)}\ntime: {time.asctime(time.localtime(time.time()))}\n\n"
+    with open(LOG_PATH, "a", encoding="utf-8") as jsonl_file:
+        json_str = json.dumps(
+            {
+                "epoch": str(epoch),
+                "train_loss": str(train_loss),
+                "eval_loss": str(eval_loss),
+                "time": f"{time.asctime(time.localtime(time.time()))}",
+            }
         )
+        jsonl_file.write(json_str + "\n")
 
     if eval_loss < min_eval_loss:
         best_epoch = epoch
diff --git a/utils.py b/utils.py
index b181fa3..9cf52e6 100644
--- a/utils.py
+++ b/utils.py
@@ -30,7 +30,7 @@ def download(url=WEIGHT_URL, filename='./output/weights.pth'):
 
     chunk_size = 1024
     with open(filename, 'wb') as file, tqdm(
-        desc=f"Downloading weights to '{filename}' from HF...",
+        desc=f"Downloading weights to '{filename}'...",
         total=total_size,
         unit='B',
         unit_scale=true,
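
With `LOG_PATH` switched to `./output/logs.jsonl`, train.py now appends one self-contained JSON object per epoch instead of free-form text, so downstream tooling can parse the log rather than scrape it. Below is a minimal reader sketch; the path and the `epoch`/`train_loss`/`eval_loss`/`time` keys come from the hunks above, but the script itself is not part of the repo:

```python
# Minimal sketch (not in the repo): parse the per-epoch records that
# train.py now appends to LOG_PATH, one JSON object per line.
import json

with open("./output/logs.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        # train.py serializes every field as a string, so cast before use.
        print(
            f"epoch {record['epoch']}: "
            f"train_loss={float(record['train_loss']):.4f} "
            f"eval_loss={float(record['eval_loss']):.4f}"
        )
```

One object per line keeps the file valid JSONL even across resumed runs, which is why train.py opens it with mode `"a"`.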
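The config.py hunk keeps two sources for the same checkpoint: `WEIGHT_URL` on Hugging Face and `WEIGHT_URL_ZH` on the ModelScope mirror. A hypothetical fallback wrapper is sketched below; `download(url=..., filename=...)` matches the signature shown in the utils.py hunk, but the wrapper itself is not in the repo and assumes `download()` raises on a failed request:

```python
# Hypothetical helper (not in the repo): try the Hugging Face URL first and
# fall back to the ModelScope mirror, assuming download() raises on failure.
from config import WEIGHT_URL, WEIGHT_URL_ZH, WEIGHT_PATH
from utils import download


def fetch_weights():
    try:
        download(url=WEIGHT_URL, filename=WEIGHT_PATH)
    except Exception:
        download(url=WEIGHT_URL_ZH, filename=WEIGHT_PATH)
```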