use jsonl log
monetjoe committed May 25, 2024
1 parent ddfb14e commit 11a2464
Showing 4 changed files with 68 additions and 41 deletions.
6 changes: 3 additions & 3 deletions config.py
@@ -12,9 +12,9 @@
LOAD_FROM_CHECKPOINT = True # Whether to load weights from a checkpoint
# Whether to share weights between the encoder and decoder
SHARE_WEIGHTS = False
WEIGHT_URL = 'https://huggingface.co/MuGeminorum/hoyoGPT/resolve/main/weights.pth'
WEIGHT_URL_ZH = 'https://www.modelscope.cn/api/v1/models/MuGeminorum/hoyoGPT/repo?Revision=master&FilePath=weights.pth'
WEIGHT_URL = 'https://huggingface.co/sander-wood/tunesformer/resolve/main/weights.pth'
WEIGHT_URL_ZH = 'https://www.modelscope.cn/models/monetjoe/tunesformer-mirror/repo?Revision=master&FilePath=weights.pth'
OUTPUT_PATH = './output'
WEIGHT_PATH = f'{OUTPUT_PATH}/weights.pth'
LOG_PATH = f'{OUTPUT_PATH}/logs.txt'
LOG_PATH = f'{OUTPUT_PATH}/logs.jsonl'
PROMPT_PATH = 'prompt.txt'
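
For context, the download helper in utils.py (changed further down in this commit) defaults to these constants. A minimal usage sketch, assuming config.py and utils.py are importable from the repository root:

from config import WEIGHT_URL, WEIGHT_PATH
from utils import download

# Fetch the checkpoint from the new tunesformer URL into ./output/weights.pth;
# both arguments simply make the function's defaults explicit.
download(url=WEIGHT_URL, filename=WEIGHT_PATH)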
88 changes: 54 additions & 34 deletions generate.py
@@ -9,20 +9,43 @@


def get_args(parser):
parser.add_argument('-num_tunes', type=int, default=3,
help='the number of independently computed returned tunes')
parser.add_argument('-max_patch', type=int, default=128,
help='integer to define the maximum length in tokens of each tune')
parser.add_argument('-top_p', type=float, default=0.8,
help='float to define the tokens that are within the sample operation of text generation')
parser.add_argument('-top_k', type=int, default=8,
help='integer to define the tokens that are within the sample operation of text generation')
parser.add_argument('-temperature', type=float, default=1.2,
help='the temperature of the sampling operation')
parser.add_argument('-seed', type=int, default=None,
help='seed for randomstate')
parser.add_argument('-show_control_code', type=bool,
default=True, help='whether to show control code')
parser.add_argument(
"-num_tunes",
type=int,
default=1,
help="the number of independently computed returned tunes",
)
parser.add_argument(
"-max_patch",
type=int,
default=128,
help="integer to define the maximum length in tokens of each tune",
)
parser.add_argument(
"-top_p",
type=float,
default=0.8,
help="float to define the tokens that are within the sample operation of text generation",
)
parser.add_argument(
"-top_k",
type=int,
default=8,
help="integer to define the tokens that are within the sample operation of text generation",
)
parser.add_argument(
"-temperature",
type=float,
default=1.2,
help="the temperature of the sampling operation",
)
parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
parser.add_argument(
"-show_control_code",
type=bool,
default=False,
help="whether to show control code",
)
args = parser.parse_args()

return args
@@ -35,14 +58,14 @@ def generate_abc(args):
num_hidden_layers=PATCH_NUM_LAYERS,
max_length=PATCH_LENGTH,
max_position_embeddings=PATCH_LENGTH,
vocab_size=1
vocab_size=1,
)

char_config = GPT2Config(
num_hidden_layers=CHAR_NUM_LAYERS,
max_length=PATCH_SIZE,
max_position_embeddings=PATCH_SIZE,
vocab_size=128
vocab_size=128,
)

model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
@@ -56,11 +79,11 @@ def generate_abc(args):
download()

checkpoint = torch.load(filename)
model.load_state_dict(checkpoint['model'])
model.load_state_dict(checkpoint["model"])
model = model.to(device)
model.eval()

with open(PROMPT_PATH, 'r') as f:
with open(PROMPT_PATH, "r") as f:
prompt = f.read()

tunes = ""
@@ -72,19 +95,19 @@ def generate_abc(args):
seed = args.seed
show_control_code = args.show_control_code

print(" HYPERPARAMETERS ".center(60, "#"), '\n')
print(" HYPERPARAMETERS ".center(60, "#"), "\n")
args = vars(args)

for key in args.keys():
print(f'{key}: {str(args[key])}')
print(f"{key}: {str(args[key])}")

print('\n', " OUTPUT TUNES ".center(60, "#"))
print("\n", " OUTPUT TUNES ".center(60, "#"))

start_time = time.time()

for i in range(num_tunes):
tune = f"X:{str(i + 1)}\n{prompt}"
lines = re.split(r'(\n)', tune)
lines = re.split(r"(\n)", tune)
tune = ""
skip = False
for line in lines:
@@ -99,19 +122,18 @@ def generate_abc(args):
skip = True

input_patches = torch.tensor(
[patchilizer.encode(prompt, add_special_patches=True)[:-1]],
device=device
[patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device
)

if tune == "":
tokens = None

else:
prefix = patchilizer.decode(input_patches[0])
remaining_tokens = prompt[len(prefix):]
remaining_tokens = prompt[len(prefix) :]
tokens = torch.tensor(
[patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens],
device=device
[patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
device=device,
)

while input_patches.shape[1] < max_patch:
@@ -121,7 +143,7 @@ def generate_abc(args):
top_p=top_p,
top_k=top_k,
temperature=temperature,
seed=seed
seed=seed,
)
tokens = None

@@ -135,17 +157,15 @@ def generate_abc(args):
if next_bar == "":
break

next_bar = remaining_tokens+next_bar
next_bar = remaining_tokens + next_bar
remaining_tokens = ""

predicted_patch = torch.tensor(
patchilizer.bar2patch(next_bar),
device=device
patchilizer.bar2patch(next_bar), device=device
).unsqueeze(0)

input_patches = torch.cat(
[input_patches, predicted_patch.unsqueeze(0)],
dim=1
[input_patches, predicted_patch.unsqueeze(0)], dim=1
)

else:
@@ -157,7 +177,7 @@ def generate_abc(args):
print("Generation time: {:.2f} seconds".format(time.time() - start_time))
timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
create_dir()
with open(f'{OUTPUT_PATH}/{timestamp}.abc', 'w') as f:
with open(f"{OUTPUT_PATH}/{timestamp}.abc", "w") as f:
f.write(tunes)


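For reference, a minimal sketch of how the two functions in generate.py are typically wired together; the driver below is an assumption (generate.py most likely contains an equivalent __main__ block), not part of this diff:

import argparse
from generate import get_args, generate_abc  # hypothetical external driver

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = get_args(parser)  # parses -num_tunes, -max_patch, -top_p, -top_k, -temperature, -seed, -show_control_code
    generate_abc(args)       # writes the generated tunes to ./output/<timestamp>.abc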
13 changes: 10 additions & 3 deletions train.py
@@ -1,5 +1,6 @@
import os
import time
import json
from utils import *
from torch.utils.data import DataLoader
from modelscope.msdatasets import MsDataset
@@ -215,10 +216,16 @@ def eval_epoch(model): # do one epoch for eval
train_loss = train_epoch(model, is_autocast, scaler)
eval_loss = eval_epoch(model)

with open(LOG_PATH, "a") as f:
f.write(
f"Epoch {str(epoch)}\ntrain_loss: {str(train_loss)}\neval_loss: {str(eval_loss)}\ntime: {time.asctime(time.localtime(time.time()))}\n\n"
with open(LOG_PATH, "a", encoding="utf-8") as jsonl_file:
json_str = json.dumps(
{
"epoch": str(epoch),
"train_loss": str(train_loss),
"eval_loss": str(eval_loss),
"time": f"{time.asctime(time.localtime(time.time()))}",
}
)
jsonl_file.write(json_str + "\n")

if eval_loss < min_eval_loss:
best_epoch = epoch
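Because every epoch now lands on its own JSON line, the log can be read back without custom parsing. A minimal reading sketch, assuming the keys written above and the LOG_PATH from config.py; note the values are serialized as strings, so they need converting:

import json
from config import LOG_PATH  # './output/logs.jsonl' after this commit

with open(LOG_PATH, "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]

# Example: recover the epoch with the lowest eval loss.
best = min(records, key=lambda r: float(r["eval_loss"]))
print(best["epoch"], best["eval_loss"])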
2 changes: 1 addition & 1 deletion utils.py
@@ -30,7 +30,7 @@ def download(url=WEIGHT_URL, filename='./output/weights.pth'):
chunk_size = 1024

with open(filename, 'wb') as file, tqdm(
desc=f"Downloading weights to '{filename}' from HF...",
desc=f"Downloading weights to '{filename}'...",
total=total_size,
unit='B',
unit_scale=True,
