Skip to content

Commit

Permalink
Merge pull request #106 from amarquand/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
amarquand authored Nov 9, 2022
2 parents 829ba14 + 93ea62a commit 36b7ce9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 38 deletions.
9 changes: 4 additions & 5 deletions pcntoolkit/normative.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,10 +1015,10 @@ def transfer(covfile, respfile, testcov=None, testresp=None, maskfile=None,
print("Warning: redundant batch effect parameterisation. Using HBR syntax")

yhat, s2 = nm.predict(Xte, X, Y[:, i],
adaptcovfile = covfile,
adaptrespfile = respfile,
adaptvargroupfile = trbefile,
testvargroupfile = tsbefile,
adaptcov = X,
adaptresp = Y[:, i],
adaptvargroup = batch_effects_train,
testvargroup = batch_effects_test,
**kwargs)

if testcov is not None:
Expand Down Expand Up @@ -1060,7 +1060,6 @@ def transfer(covfile, respfile, testcov=None, testresp=None, maskfile=None,
Z = (Yte - Yhat) / np.sqrt(S2)

print("Evaluating the model ...")
#results = evaluate(Yte, Yhat, S2=S2, mY=mY, sY=sY)
if meta_data and not warp:
results = evaluate(Yte, Yhat, S2=S2, mY=mY, sY=sY)
else:
Expand Down
79 changes: 46 additions & 33 deletions pcntoolkit/normative_model/norm_blr.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,59 +170,72 @@ def predict(self, Xs, X=None, y=None, **kwargs):
theta = self.theta # always use the estimated coefficients
# remove from kwargs to avoid downstream problems
kwargs.pop('theta', None)



Phis = create_poly_basis(Xs, self._model_order)

if X is None:
Phi =None
Phi = None
else:
Phi = create_poly_basis(X, self._model_order)

# process variance groups for the test data
if 'testvargroupfile' in kwargs:
var_groups_test_file = kwargs.pop('testvargroupfile')
if var_groups_test_file.endswith('.pkl'):
var_groups_te = pd.read_pickle(var_groups_test_file)
else:
var_groups_te = np.loadtxt(var_groups_test_file)
if 'testvargroup' in kwargs:
var_groups_te = kwargs.pop('testvargroup')
else:
var_groups_te = None
if 'testvargroupfile' in kwargs:
var_groups_test_file = kwargs.pop('testvargroupfile')
if var_groups_test_file.endswith('.pkl'):
var_groups_te = pd.read_pickle(var_groups_test_file)
else:
var_groups_te = np.loadtxt(var_groups_test_file)
else:
var_groups_te = None

# process test variance covariates
if 'testvarcovfile' in kwargs:
var_cov_test_file = kwargs.get('testvarcovfile')
if var_cov_test_file.endswith('.pkl'):
var_cov_te = pd.read_pickle(var_cov_test_file)
else:
var_cov_te = np.loadtxt(var_cov_test_file)
if 'testvarcov' in kwargs:
var_cov_te = kwargs.pop('testvarcov')
else:
var_cov_te = None
if 'testvarcovfile' in kwargs:
var_cov_test_file = kwargs.get('testvarcovfile')
if var_cov_test_file.endswith('.pkl'):
var_cov_te = pd.read_pickle(var_cov_test_file)
else:
var_cov_te = np.loadtxt(var_cov_test_file)
else:
var_cov_te = None

# do we want to adjust the responses?
if 'adaptrespfile' in kwargs:
y_adapt = fileio.load(kwargs.pop('adaptrespfile'))
if len(y_adapt.shape) == 1:
y_adapt = y_adapt[:, np.newaxis]
if 'adaptresp' in kwargs:
y_adapt = kwargs.pop('adaptresp')
else:
y_adapt = None
if 'adaptrespfile' in kwargs:
y_adapt = fileio.load(kwargs.pop('adaptrespfile'))
if len(y_adapt.shape) == 1:
y_adapt = y_adapt[:, np.newaxis]
else:
y_adapt = None

if 'adaptcovfile' in kwargs:
X_adapt = fileio.load(kwargs.pop('adaptcovfile'))
if 'adaptcov' in kwargs:
X_adapt = kwargs.pop('adaptcov')
Phi_adapt = create_poly_basis(X_adapt, self._model_order)
else:
Phi_adapt = None

if 'adaptvargroupfile' in kwargs:
var_groups_adapt_file = kwargs.pop('adaptvargroupfile')
if var_groups_adapt_file.endswith('.pkl'):
var_groups_ad = pd.read_pickle(var_groups_adapt_file)
if 'adaptcovfile' in kwargs:
X_adapt = fileio.load(kwargs.pop('adaptcovfile'))
Phi_adapt = create_poly_basis(X_adapt, self._model_order)
else:
var_groups_ad = np.loadtxt(var_groups_adapt_file)
Phi_adapt = None

if 'adaptvargroup' in kwargs:
var_groups_ad = kwargs.pop('adaptvargroup')
else:
var_groups_ad = None

if 'adaptvargroupfile' in kwargs:
var_groups_adapt_file = kwargs.pop('adaptvargroupfile')
if var_groups_adapt_file.endswith('.pkl'):
var_groups_ad = pd.read_pickle(var_groups_adapt_file)
else:
var_groups_ad = np.loadtxt(var_groups_adapt_file)
else:
var_groups_ad = None

if y_adapt is None:
yhat, s2 = self.blr.predict(theta, Phi, y, Phis,
Expand Down

0 comments on commit 36b7ce9

Please sign in to comment.