
Commit

rm app key
admin committed Oct 14, 2024
1 parent aca9c01 commit 89e4c5d
Showing 6 changed files with 24 additions and 47 deletions.
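In effect, the commit drops the `HubApi().login(APP_KEY)` step (and the `trust_remote_code=True` flag) in front of every `MsDataset.load` call, so the datasets are fetched without a ModelScope app key. A minimal sketch of the loading pattern that remains, assuming `DATASET` resolves to `EMusicGen` and using a literal cache path in place of `f"{TEMP_DIR}/cache"`:

```python
from modelscope.msdatasets import MsDataset

# No HubApi().login(APP_KEY) beforehand: the dataset is loaded anonymously.
ds = MsDataset.load(
    "monetjoe/EMusicGen",    # f"monetjoe/{DATASET}" in the repo
    subset_name="Analysis",  # as in embedder.py; other scripts pass their own subset
    split="train",
    cache_dir="./cache",     # the repo uses f"{TEMP_DIR}/cache"
)
dataset = list(ds)
```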
5 changes: 5 additions & 0 deletions README.md
@@ -26,6 +26,11 @@ git clone git@github.com:monetjoe/EMusicGen.git
cd EMusicGen
```

+## Train
+```bash
+python train.py
+```
+
## Success rate
| Dataset | Rough4Q | VGMIDI | EMOPIA |
| :-----: | :-----: | :----: | :----: |
6 changes: 1 addition & 5 deletions embedder.py
@@ -3,13 +3,11 @@
import warnings
import torch
import torch.nn as nn
-import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
-from modelscope.hub.api import HubApi
from modelscope.msdatasets import MsDataset
from sklearn.svm import LinearSVC
from tqdm import tqdm
-from utils import APP_KEY, TEMP_DIR, OUTPUT_PATH
+from utils import TEMP_DIR, OUTPUT_PATH
from config import *


@@ -38,13 +36,11 @@ def forward(self, x):


def data():
-HubApi().login(APP_KEY)
ds = MsDataset.load(
f"monetjoe/{DATASET}",
subset_name="Analysis",
split="train",
cache_dir=f"{TEMP_DIR}/cache",
-trust_remote_code=True,
)
dataset = list(ds)
p90 = int(len(dataset) * 0.9)
4 changes: 1 addition & 3 deletions infer.py
@@ -9,13 +9,11 @@
import subprocess
import soundfile as sf
from modelscope import snapshot_download
-from modelscope.hub.api import HubApi
from transformers import GPT2Config
from music21 import converter, interval, clef, stream
-from utils import Patchilizer, TunesFormer, DEVICE, MSCORE, APP_KEY
+from utils import Patchilizer, TunesFormer, DEVICE, MSCORE
from config import *

-HubApi().login(APP_KEY)
EMUSICGEN_WEIGHTS_DIR = snapshot_download(f"monetjoe/{DATASET}", cache_dir=TEMP_DIR)


5 changes: 1 addition & 4 deletions rl.py
@@ -7,9 +7,8 @@
from transformers import GPT2Config
from torch.distributions import Categorical
from modelscope.msdatasets import MsDataset
-from modelscope.hub.api import HubApi
from modelscope import snapshot_download
-from utils import TunesFormer, Patchilizer, DEVICE, APP_KEY
+from utils import TunesFormer, Patchilizer, DEVICE
from generate import infer_abc
from config import *

@@ -21,12 +20,10 @@ def __init__(self, subset: str):
self.subset = subset

def prepare_prompts(self):
-HubApi().login(APP_KEY)
ds = MsDataset.load(
f"monetjoe/{DATASET}",
subset_name=self.subset,
cache_dir=TEMP_DIR,
-trust_remote_code=True,
)
dataset = list(ds["train"]) + list(ds["test"])
prompt_set = {"A:Q1\n", "A:Q2\n", "A:Q3\n", "A:Q4\n", ""}
14 changes: 4 additions & 10 deletions train.py
@@ -8,9 +8,8 @@
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
-from utils import Patchilizer, TunesFormer, PatchilizedData, DEVICE, APP_KEY
+from utils import Patchilizer, TunesFormer, PatchilizedData, DEVICE
from modelscope.msdatasets import MsDataset
-from modelscope.hub.api import HubApi
from modelscope import snapshot_download
from tqdm import tqdm
from transformers import GPT2Config, get_scheduler
@@ -36,12 +35,10 @@ def init(bsz=4):
max_position_embeddings=PATCH_SIZE,
vocab_size=128,
)
-model: nn.Module = TunesFormer(
-patch_config, char_config, SHARE_WEIGHTS).to(DEVICE)
+model: nn.Module = TunesFormer(patch_config, char_config, SHARE_WEIGHTS).to(DEVICE)
# print parameter number
print(
f"Parameter Number: {sum(p.numel()
for p in model.parameters() if p.requires_grad)}"
f"Parameter Number: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
)
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
@@ -163,12 +160,10 @@ def train(subset: str, dld_mode="reuse_dataset_if_exists", bsz=1):
clean_caches(subset)

# load data
-HubApi().login(APP_KEY)
dataset = MsDataset.load(
f"monetjoe/{DATASET}",
subset_name=subset,
cache_dir=f"{TEMP_DIR}/cache",
-trust_remote_code=True,
download_mode=dld_mode,
)
classes = dataset["test"].features["label"].names
@@ -286,8 +281,7 @@ def train(subset: str, dld_mode="reuse_dataset_if_exists", bsz=1):
)
break

print(f"Best Eval Epoch : {str(best_epoch)
}\nMin Eval Loss : {str(min_eval_loss)}")
print(f"Best Eval Epoch: {str(best_epoch)}\nMin Eval Loss: {str(min_eval_loss)}")


if __name__ == "__main__":
37 changes: 12 additions & 25 deletions utils.py
@@ -11,7 +11,6 @@

os.environ["MODELSCOPE_LOG_LEVEL"] = "40"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-APP_KEY = os.getenv("ms_app_key")
MSCORE = os.getenv("mscore")


@@ -45,8 +44,7 @@ def bar2patch(self, bar, patch_size=PATCH_SIZE):
"""
Convert a bar into a patch of specified length.
"""
-patch = [self.bos_token_id] + [ord(c)
-for c in bar] + [self.eos_token_id]
+patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
patch = patch[:patch_size]
patch += [self.pad_token_id] * (patch_size - len(patch))
return patch
@@ -79,15 +77,13 @@ def encode(

for line in lines:
if len(line) > 1 and (
-(line[0].isalpha() and line[1] ==
-":") or line.startswith("%%score")
+(line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
):
if body:
bars = self.split_bars(body)
patches.extend(
self.bar2patch(
bar + "\n" if idx == len(bars) -
1 else bar, patch_size
bar + "\n" if idx == len(bars) - 1 else bar, patch_size
)
for idx, bar in enumerate(bars)
)
@@ -104,10 +100,8 @@ def encode(
)

if add_special_patches:
-bos_patch = [self.bos_token_id] * \
-(patch_size - 1) + [self.eos_token_id]
-eos_patch = [self.bos_token_id] + \
-[self.eos_token_id] * (patch_size - 1)
+bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
+eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
patches = [bos_patch] + patches + [eos_patch]

return patches[:patch_length]
@@ -216,8 +210,7 @@ def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
tokens = tokens.reshape(1, -1)

# Get input embeddings
-tokens = torch.nn.functional.embedding(
-tokens, self.base.transformer.wte.weight)
+tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

# Concatenate the encoded patch with the input embeddings
tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
@@ -226,8 +219,7 @@ def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
outputs = self.base(inputs_embeds=tokens)

# Get probabilities of next token
-probs = torch.nn.functional.softmax(
-outputs.logits.squeeze(0)[-1], dim=-1)
+probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

return probs

@@ -249,8 +241,7 @@ def __init__(self, encoder_config, decoder_config, share_weights=False):
encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
)

-max_context_size = max(
-encoder_config.max_length, decoder_config.max_length)
+max_context_size = max(encoder_config.max_length, decoder_config.max_length)

max_position_embeddings = max(
encoder_config.max_position_embeddings,
@@ -281,8 +272,7 @@ def forward(
:return: the decoded patches
"""
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
-encoded_patches = self.patch_level_decoder(
-patches)["last_hidden_state"]
+encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

return self.char_level_decoder(
encoded_patches.squeeze(0)[:-1, :],
@@ -305,8 +295,7 @@ def generate(
:return: the generated patches
"""
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
-encoded_patches = self.patch_level_decoder(
-patches)["last_hidden_state"]
+encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

if tokens == None:
tokens = torch.tensor([self.bos_token_id], device=self.device)
@@ -323,17 +312,15 @@ def generate(
n_seed = None

prob = (
-self.char_level_decoder.generate(
-encoded_patches[0][-1], tokens)
+self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
.cpu()
.detach()
.numpy()
)

prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
-token = temperature_sampling(
-prob, temperature=temperature, seed=n_seed)
+token = temperature_sampling(prob, temperature=temperature, seed=n_seed)

generated_patch.append(token)
if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
