diff --git a/lib/sequentia/classifiers/knn/knn_classifier.py b/lib/sequentia/classifiers/knn/knn_classifier.py index 444f8002..d3fcf3aa 100644 --- a/lib/sequentia/classifiers/knn/knn_classifier.py +++ b/lib/sequentia/classifiers/knn/knn_classifier.py @@ -241,11 +241,11 @@ def load(cls, path, encoding='utf-8', metric=euclidean, weighting=(lambda x: 1)) with h5py.File(path, 'r') as f: # Deserialize the model hyper-parameters params = f['params'] - clf = cls(k=int(params['k'][()]), radius=int(params['radius'][()]), metric=metric) + clf = cls(k=int(params['k'][()]), radius=int(params['radius'][()]), metric=metric, weighting=weighting) # Deserialize the training data and labels X, y = f['data']['X'], f['data']['y'] - clf._X = [np.array(X[k]) for k in X.keys()] + clf._X = [np.array(X[k]) for k in sorted(X.keys(), key=lambda k: int(k))] clf._y = [label.decode(encoding) for label in y] return clf \ No newline at end of file