
Commit

upd ablation exps
admin committed Sep 2, 2024
1 parent 332f29d commit 4aa071e
Showing 1 changed file with 93 additions and 39 deletions.
132 changes: 93 additions & 39 deletions infer.py
@@ -7,6 +7,7 @@
import argparse
import warnings
import subprocess
import soundfile as sf
from modelscope import snapshot_download
from transformers import GPT2Config
from music21 import converter, interval, clef, stream
@@ -110,14 +111,21 @@ def transpose_octaves_abc(abc_notation: str, out_xml_file: str, offset=-12):
return xml2abc(out_xml_file), out_xml_file


def adjust_volume(in_audio: str, dB_change: int):
y, sr = sf.read(in_audio)
sf.write(in_audio, y * 10 ** (dB_change / 20), sr)
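
The 10 ** (dB_change / 20) factor converts a decibel change into a linear amplitude gain, so +6 dB roughly doubles the waveform and -6 dB roughly halves it. A minimal standalone sketch of the same conversion; the clipping guard is an added assumption, not part of this commit:

import numpy as np
import soundfile as sf

def apply_gain_db(path: str, db_change: float):
    gain = 10 ** (db_change / 20)      # +6 dB -> ~2.0x, -6 dB -> ~0.5x
    y, sr = sf.read(path)              # float waveform and sample rate
    y = np.clip(y * gain, -1.0, 1.0)   # assumed safeguard against clipping
    sf.write(path, y, sr)              # overwrite in place, like adjust_volume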


def generate_music(
args,
emo: str,
weights: str,
outdir=TEMP_DIR,
fix_t=True,
fix_m=True,
fix_p=True,
fix_tempo=True,
fix_mode=True,
fix_pitch=True,
fix_std=True,
fix_volume=True,
):
patchilizer = Patchilizer()
patch_config = GPT2Config(
@@ -137,7 +145,7 @@ def generate_music(
model.load_state_dict(checkpoint["model"])
model = model.to(DEVICE)
model.eval()
prompt = f"A:{emo}\n"
prompt = ""
tunes = ""
num_tunes = args.num_tunes
max_patch = args.max_patch
@@ -151,6 +159,24 @@
for arg in args_dict.keys():
print(f"{arg}: {str(args_dict[arg])}")

# fix mode / pitch_std
if fix_mode and fix_std:
prompt = f"A:{emo}\n"

elif fix_mode:
if emo == "Q1" or emo == "Q4":
prompt = "A:" + random.choice(["Q1", "Q4"]) + "\n"

elif emo == "Q2" or emo == "Q3":
prompt = "A:" + random.choice(["Q2", "Q3"]) + "\n"

elif fix_std:
if emo == "Q1" or emo == "Q2":
prompt = "A:" + random.choice(["Q1", "Q2"]) + "\n"

elif emo == "Q3" or emo == "Q4":
prompt = "A:" + random.choice(["Q3", "Q4"]) + "\n"

print("\n", " Output tunes ".center(60, "#"))
start_time = time.time()
for i in range(num_tunes):
@@ -172,7 +198,8 @@
skip = True

input_patches = torch.tensor(
[patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
[patchilizer.encode(prompt, add_special_patches=True)[:-1]],
device=DEVICE,
)
if tune == "":
tokens = None
@@ -207,10 +234,12 @@
next_bar = remaining_tokens + next_bar
remaining_tokens = ""
predicted_patch = torch.tensor(
patchilizer.bar2patch(next_bar), device=DEVICE
patchilizer.bar2patch(next_bar),
device=DEVICE,
).unsqueeze(0)
input_patches = torch.cat(
[input_patches, predicted_patch.unsqueeze(0)], dim=1
[input_patches, predicted_patch.unsqueeze(0)],
dim=1,
)

else:
@@ -221,7 +250,7 @@

# fix tempo
tempo = ""
if fix_t:
if fix_tempo:
tempo = f"Q:{random.randint(88, 132)}\n"
if emo == "Q1":
tempo = f"Q:{random.randint(160, 184)}\n"
@@ -238,25 +267,28 @@

tunes = tunes.replace(f"A:{emo}\n", tempo)
# fix mode:major/minor
key = "major" if emo == "Q1" or emo == "Q4" else "minor"
if fix_m:
mode = "major" if emo == "Q1" or emo == "Q4" else "minor"
if fix_mode:
K_val = get_abc_key_val(tunes)
if key == "major" and K_val and "m" in K_val:
if mode == "major" and K_val and "m" in K_val:
tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")

elif key == "minor" and K_val and not "m" in K_val:
elif mode == "minor" and K_val and not "m" in K_val:
tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.lower()}min\n")

print("Generation time: {:.2f} seconds".format(time.time() - start_time))
timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
try:
if key == "minor" and fix_p:
# fix avg_pitch (octave)
if mode == "minor" and fix_pitch:
offset = -12
if emo == "Q2":
offset -= 12

tunes, xml = transpose_octaves_abc(
tunes, f"{outdir}/{timestamp}.musicxml", offset
tunes,
f"{outdir}/{timestamp}.musicxml",
offset,
)
tunes = tunes.replace(title + title, title)
os.rename(xml, f"{outdir}/[{emo}]{timestamp}.musicxml")
@@ -270,7 +302,16 @@
os.remove(xml)

if os.path.exists(audio):
# fix rms vol
if fix_volume:
if emo == "Q1":
adjust_volume(audio, 5)

elif emo == "Q2":
adjust_volume(audio, 10)

return audio

else:
return ""

@@ -286,6 +327,8 @@ def infers(
fix_tempo=True,
fix_mode=True,
fix_pitch=True,
fix_std=True,
fix_volume=True,
):
os.makedirs(outdir, exist_ok=True)
parser = argparse.ArgumentParser()
@@ -296,9 +339,11 @@
emo=emotion,
weights=f"{emusicgen_weights_dir}/{dataset.lower()}/weights.pth",
outdir=outdir,
fix_t=fix_tempo,
fix_m=fix_mode,
fix_p=fix_pitch,
fix_tempo=fix_tempo,
fix_mode=fix_mode,
fix_pitch=fix_pitch,
fix_std=fix_std,
fix_volume=fix_volume,
)


@@ -309,29 +354,36 @@ def add_to_log(message: str, log_file_path=f"{EXPERIMENT_DIR}/success_rates.log"):


def generate_exps(
fix_t=False,
fix_m=False,
fix_p=False,
fix_t=True,
fix_m=True,
fix_p=True,
fix_s=True,
fix_v=True,
total=100,
labels=["Q1", "Q2", "Q3", "Q4"],
):
outdir = f"{EXPERIMENT_DIR}/"
if fix_t and fix_m and fix_p:
outdir += "all"
elif fix_t:
outdir += "tempo"
elif fix_m:
outdir += "mode"
elif fix_p:
outdir += "pitch"
else:
outdir += "none"
subdir = "none"
if not fix_t:
subdir = "tempo"

if not fix_m:
subdir = "mode"

if not fix_p:
subdir = "pitch"

if not fix_s:
subdir = "std"

if not fix_v:
subdir = "volume"

outdir = f"{EXPERIMENT_DIR}/{subdir}"
hit_rate = []
for emo in labels:
success, fail = 0, 0
while success < total // len(labels):
if infers("Rough4Q", emo, outdir, fix_t, fix_m, fix_p):
if infers("Rough4Q", emo, outdir, fix_t, fix_m, fix_p, fix_s, fix_v):
success += 1
else:
fail += 1
@@ -362,10 +414,12 @@ def success_rate(total=100, subset="EMOPIA", labels=["Q1", "Q2", "Q3", "Q4"]):
if os.path.exists(EXPERIMENT_DIR):
shutil.rmtree(EXPERIMENT_DIR)

generate_exps()
generate_exps(fix_t=True)
generate_exps(fix_m=True)
generate_exps(fix_p=True)
generate_exps(fix_t=True, fix_m=True, fix_p=True)
success_rate()
success_rate(subset="VGMIDI")
generate_exps() # no ablation
generate_exps(fix_t=False) # ablate tempo
generate_exps(fix_m=False) # ablate mode
generate_exps(fix_p=False) # ablate avg_pitch (octave)
generate_exps(fix_s=False) # ablate pitch_std
generate_exps(fix_v=False) # ablate volume

success_rate() # calc render success rate for EMOPIA
success_rate(subset="VGMIDI") # calc render success rate for VGMIDI
