From 87817c14a8c03771da1d505ad92ad1956add1f0b Mon Sep 17 00:00:00 2001 From: Jin Dou <25721564+powerfulbean@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:11:05 +0800 Subject: [PATCH 1/4] fix bug of decoder training in nested_crossval https://github.com/powerfulbean/mTRFpy/issues/32 --- mtrf/stats.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/mtrf/stats.py b/mtrf/stats.py index bfc98c3..abb9d90 100644 --- a/mtrf/stats.py +++ b/mtrf/stats.py @@ -266,13 +266,21 @@ def nested_crossval( regularization_split_i = list(regularization)[np.argmax(metric)] else: regularization_split_i = regularization - model.train( - [stimulus[i] for i in idx_train_val], - [response[i] for i in idx_train_val], - fs, - tmin, - tmax, - regularization_split_i, + # model.train( + # [stimulus[i] for i in idx_train_val], + # [response[i] for i in idx_train_val], + # fs, + # tmin, + # tmax, + # regularization_split_i, + # ) + model._train( + [x[i] for i in idx_train_val], + [y[i] for i in idx_train_val], + fs, + tmin, + tmax, + regularization_split_i ) _, metric_test[split_i] = model.predict( [stimulus[i] for i in idx_test], [response[i] for i in idx_test] From b64cdeb875819bffe9d3e77080b1ccf4a90c9ce7 Mon Sep 17 00:00:00 2001 From: Jin Dou <25721564+powerfulbean@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:29:09 +0800 Subject: [PATCH 2/4] Update stats.py fix code style bug --- mtrf/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mtrf/stats.py b/mtrf/stats.py index abb9d90..41d641c 100644 --- a/mtrf/stats.py +++ b/mtrf/stats.py @@ -280,7 +280,7 @@ def nested_crossval( fs, tmin, tmax, - regularization_split_i + regularization_split_i, ) _, metric_test[split_i] = model.predict( [stimulus[i] for i in idx_test], [response[i] for i in idx_test] From 6a75d477465deecf4e1a00fc833b02aab9b8d79a Mon Sep 17 00:00:00 2001 From: Jin Dou <25721564+powerfulbean@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:42:14 +0800 Subject: [PATCH 3/4] fix style check bug --- mtrf/stats.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mtrf/stats.py b/mtrf/stats.py index 41d641c..46134b0 100644 --- a/mtrf/stats.py +++ b/mtrf/stats.py @@ -275,11 +275,11 @@ def nested_crossval( # regularization_split_i, # ) model._train( - [x[i] for i in idx_train_val], - [y[i] for i in idx_train_val], - fs, - tmin, - tmax, + [x[i] for i in idx_train_val], + [y[i] for i in idx_train_val], + fs, + tmin, + tmax, regularization_split_i, ) _, metric_test[split_i] = model.predict( From 6263e038e8eab6f2e6ee860046ea1de6155393b5 Mon Sep 17 00:00:00 2001 From: Jin Dou <25721564+powerfulbean@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:31:07 +0800 Subject: [PATCH 4/4] Update stats.py remove commented out code --- mtrf/stats.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mtrf/stats.py b/mtrf/stats.py index 46134b0..3479a4b 100644 --- a/mtrf/stats.py +++ b/mtrf/stats.py @@ -266,14 +266,6 @@ def nested_crossval( regularization_split_i = list(regularization)[np.argmax(metric)] else: regularization_split_i = regularization - # model.train( - # [stimulus[i] for i in idx_train_val], - # [response[i] for i in idx_train_val], - # fs, - # tmin, - # tmax, - # regularization_split_i, - # ) model._train( [x[i] for i in idx_train_val], [y[i] for i in idx_train_val],