Skip to content

Commit

Permalink
add temporal generalization, general integration with RSAtoolbox
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymj committed Nov 1, 2024
1 parent b9aa374 commit bfd3ab6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 26 deletions.
25 changes: 25 additions & 0 deletions mvEEG/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,31 @@ def decode_across_time(self, X_train, X_test, y_train, y_test):
return accs, accs_shuff, conf_mats, confidence_scores

def temporally_generalize(self, X_train, X_test, y_train, y_test):
"""
Perform temporal generalization by training and testing a classifier
across different time points and calculating accuracy, shuffled accuracy,
confusion matrices, and confidence scores.
Parameters:
X_train : ndarray
Training data of shape (n_samples, n_features, n_times).
X_test : ndarray
Testing data of shape (n_samples, n_features, n_times).
y_train : ndarray
Labels for the training data of shape (n_samples,).
y_test : ndarray
Labels for the testing data of shape (n_samples,).
Returns:
accs : ndarray
Accuracy scores of shape (n_times, n_times) for each pair of train and test times.
accs_shuff : ndarray
Shuffled accuracy scores of shape (n_times, n_times) for each pair of train and test times.
conf_mats : ndarray
Confusion matrices of shape (n_labels, n_labels, n_times, n_times) for each pair of train and test times.
confidence_scores : ndarray
Confidence scores of shape (n_labels, n_times, n_times) for each pair of train and test times.
"""
ntimes = X_train.shape[2]
accs = np.full((ntimes, ntimes), np.nan)
accs_shuff = np.full((ntimes, ntimes), np.nan)
Expand Down
46 changes: 20 additions & 26 deletions mvEEG/crossnobis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from sklearn.covariance import LedoitWolf
from sklearn.preprocessing import LabelEncoder
from rsatoolbox.rdm.calc import _calc_rdm_crossnobis_single


class Crossnobis:
Expand All @@ -14,7 +15,9 @@ class Crossnobis:
"""

def __init__(self):
def __init__(self, labels):
self.labels = labels
self.n_labels = len(labels)
pass

def _mean_by_condition(self, X, conds):
Expand Down Expand Up @@ -55,26 +58,7 @@ def _means_and_prec(self, X, conds):

return cond_means, inv_cov

def _calc_rdm_crossnobis_single(self, X_train, X_test, precision):
"""
Calculates RDM using LDC using means from x and y, and covariance
Largely taken from https://github.com/rsagroup/rsatoolbox/blob/main/src/rsatoolbox/rdm/calc.py#L469
Updated to return the signed square root of the RDM because
LDC is an estimator of the squared mahalonobis distance
Args:
X_train (np.ndarray, shape (n_conditions, n_channels)): Condition averages for training data (first measure)
meas2 (np.ndarray, shape (n_conditions, n_channels)): Condition averages for testing data (second measure)
noise (np.ndarray, shape (n_channels, n_channels)): Precision (inverse covariance) matrix
Returns:
rdm (np.ndarray, shape (n_conditions, n_conditions)): RDM
"""
kernel = X_train @ precision @ X_test.T
rdm = np.expand_dims(np.diag(kernel), 0) + np.expand_dims(np.diag(kernel), 1) - kernel - kernel.T
return np.sign(rdm) * np.sqrt(np.abs(rdm))

def crossnobis_single(self, X_train, X_test, y_train, y_test):
def crossnobis(self, X_train, X_test, y_train, y_test):
"""
Wrapper function to calculate crossnobis RDM over a single fold
Uses condition means from both train and test, but only uses the training
Expand All @@ -92,7 +76,7 @@ def crossnobis_single(self, X_train, X_test, y_train, y_test):
"""
means_train, noise_train = self._means_and_prec(X_train, y_train)
means_test = self._mean_by_condition(X_test, y_test)
rdm = self._calc_rdm_crossnobis_single(means_train, means_test, noise_train)
rdm = _calc_rdm_crossnobis_single(means_train, means_test, noise_train)
return rdm

def crossnobis_across_time(self, X_train, X_test, y_train, y_test):
Expand All @@ -111,9 +95,19 @@ def crossnobis_across_time(self, X_train, X_test, y_train, y_test):
ntimes = X_train.shape[2]

rdm = np.stack(
[
self.crossnobis_single(X_train[:, :, itime], X_test[:, :, itime], y_train, y_test)
for itime in range(ntimes)
]
[self.crossnobis(X_train[:, :, itime], X_test[:, :, itime], y_train, y_test) for itime in range(ntimes)]
)
return rdm


def temporally_generalize(self, X_train, X_test, y_train, y_test):
ntimes = X_train.shape[2]
rdms = np.full((self.n_labels, self.n_labels, ntimes, ntimes), np.nan)

for itime in range(ntimes): # train times
means_i, noise_i = self._means_and_prec(X_train[:, :, itime], y_train)
for jtime in range(ntimes): # test times
means_j = self._mean_by_condition(X_test[:, :, jtime], y_test)
rdms[:, itime, jtime] = _calc_rdm_crossnobis_single(means_i, means_j, noise_i)

return rdms

0 comments on commit bfd3ab6

Please sign in to comment.