Skip to content

Commit

Permalink
one more b1 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
joel99 committed Aug 21, 2024
1 parent 93711fb commit dbdd420
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions falcon_challenge/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,35 +659,22 @@ def normalize_signal(x):
for sess_idx in range(len(preds)):
prd, tgt, msk = np.array(preds[sess_idx]), np.array(targets[sess_idx]), np.array(eval_mask[sess_idx])

is_tgt_flattened = prd.shape != tgt.shape
if is_tgt_flattened:
logger.warning(f"Target may already be flattened. Target shape: {tgt.shape} vs Prediction shape: {prd.shape}.")
if prd.shape != msk.shape:
# if prd.shape != tgt.shape or prd.shape != msk.shape:
raise ValueError(f"Targets and predictions have different lengths: {len(tgt)} vs {len(prd)}.")

# Reshape to normalize and compute error at the trial level:
if not is_tgt_flattened:
tgt = tgt.reshape(-1, trial_len, tgt.shape[-2], tgt.shape[-1])
prd = prd.reshape(-1, trial_len, prd.shape[-2], prd.shape[-1])
msk = msk.reshape(-1, trial_len, msk.shape[-2], msk.shape[-1])

samples, frequencies = prd.shape[1], prd.shape[-1]

error_per_trial = []
for trial in range(len(prd)):

sess_sxx_eval_mask = msk[trial].reshape(samples, frequencies)
if not is_tgt_flattened:
original_sxx_masked = tgt[trial].reshape(samples, frequencies)[sess_sxx_eval_mask]
else:
original_sxx_masked = tgt[trial]
reconstructed_sxx_masked = prd[trial].reshape(samples, frequencies)[sess_sxx_eval_mask]
is_tgt_flattened = prd.shape != tgt.shape
msk = msk.flatten()
if is_tgt_flattened:
logger.warning(f"Target may already be flattened. Target shape: {tgt.shape} vs Prediction shape: {prd.shape}.")
else:
tgt = tgt.flatten()
tgt = tgt[msk]
prd = prd.flatten()
prd = prd[msk]

# Calculate spectrogram reconstruction error
error_per_trial.append(mean_squared_error(normalize_signal(original_sxx_masked), normalize_signal(reconstructed_sxx_masked)))

error_per_session.append(np.mean(error_per_trial))
# Calculate spectrogram reconstruction error
error_per_session.append(mean_squared_error(normalize_signal(tgt), normalize_signal(prd)))

base_metrics = {
"MSE Mean": np.mean(error_per_session),
Expand Down

0 comments on commit dbdd420

Please sign in to comment.