diff --git a/mtrf/stats.py b/mtrf/stats.py index f63ea13..f5091dc 100644 --- a/mtrf/stats.py +++ b/mtrf/stats.py @@ -366,21 +366,13 @@ def nested_crossval( regularization_split_i = list(regularization)[np.argmax(metric)] else: regularization_split_i = regularization - # model.train( - # [stimulus[i] for i in idx_train_val], - # [response[i] for i in idx_train_val], - # fs, - # tmin, - # tmax, - # regularization_split_i, - # ) model._train( - [x[i] for i in idx_train_val], - [y[i] for i in idx_train_val], - fs, - tmin, - tmax, - regularization_split_i + [x[i] for i in idx_train_val], + [y[i] for i in idx_train_val], + fs, + tmin, + tmax, + regularization_split_i, ) _, t_metric_test = model.predict( [stimulus[i] for i in idx_test],