Skip to content

Commit

Permalink
[patch:lib] Fix issues with KNNClassifier deserialization (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
eonu authored May 21, 2020
1 parent af98254 commit eab6fbf
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/sequentia/classifiers/knn/knn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,11 @@ def load(cls, path, encoding='utf-8', metric=euclidean, weighting=(lambda x: 1))
with h5py.File(path, 'r') as f:
# Deserialize the model hyper-parameters
params = f['params']
clf = cls(k=int(params['k'][()]), radius=int(params['radius'][()]), metric=metric)
clf = cls(k=int(params['k'][()]), radius=int(params['radius'][()]), metric=metric, weighting=weighting)

# Deserialize the training data and labels
X, y = f['data']['X'], f['data']['y']
clf._X = [np.array(X[k]) for k in X.keys()]
clf._X = [np.array(X[k]) for k in sorted(X.keys(), key=lambda k: int(k))]
clf._y = [label.decode(encoding) for label in y]

return clf

0 comments on commit eab6fbf

Please sign in to comment.