-
Notifications
You must be signed in to change notification settings - Fork 0
/
knn_classification.m
76 lines (59 loc) · 2.26 KB
/
knn_classification.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function [acc, corr, cmat] = knn_classification(D, classes, k)
% Performs a k-nearest neighbor classification experiment. If there is a
% tie, the nearest neighbor determines the class
%
% This file is part of the HUB TOOLBOX available at
% http://ofai.at/research/impml/projects/hubology.html
% https://github.com/OFAI/hub-toolbox-matlab/
% (c) 2013, Dominik Schnitzer <dominik.schnitzer@ofai.at>
% (c) 2016, Roman Feldbauer <roman.feldbauer@ofai.at>
%
% Usage:
% [acc, corr, cmat] = knn_classification(D, classes, k) - Use the distance
% matrix D (NxN) and the classes and perform a leave-one-out k-NN
% experiment for every neighborhood size in k.
%
% Inputs:
%   D       - N x N distance matrix; row i holds the distances from item i
%             to all items. The diagonal is ignored (an item is never its
%             own neighbor).
%   classes - vector of N numeric class labels. Labels are mapped
%             internally to 1..C in ascending order of the unique labels.
%   k       - vector of neighborhood sizes to evaluate. Each entry must be
%             an integer between 1 and N-1.
%
% Outputs:
%   acc  - numel(k) x 1 vector; acc(j) is the fraction of items classified
%          correctly with k(j) neighbors.
%   corr - N x numel(k) matrix; corr(i, j) is 1 if item i was classified
%          correctly with k(j) neighbors and 0 otherwise.
%   cmat - C x C x numel(k) confusion matrix, one C x C slice per entry of
%          k (rows: true class, columns: predicted class, both in ascending
%          order of the unique labels). For a scalar k this is an ordinary
%          C x C matrix. Earlier versions summed all k into a single C x C
%          matrix; sum(cmat, 3) reproduces that.
%
% Equal distances are broken at random (randperm), so results can differ
% between runs when D contains ties.

n = size(D, 1);
nk = numel(k);

if numel(classes) ~= n
    error('knn_classification:sizeMismatch', ...
        'classes must contain one label per row of D (%d).', n);
end
kk = k(:);
if any(kk < 1) || any(kk > n - 1) || any(kk ~= round(kk))
    error('knn_classification:invalidK', ...
        'Each k must be an integer between 1 and N-1 (N = %d).', n);
end

% Map arbitrary labels to consecutive indices 1..C (sorted label order).
[cl, ~, classes] = unique(classes(:));
C = numel(cl);

corr = zeros(n, nk);
cmat = zeros(C, C, nk);

for i = 1:n
    seed_class = classes(i);
    row = D(i, :);
    row(i) = +Inf;  % exclude the query item itself
    % Randomize the order before the (stable) sort so that points at equal
    % distance are ordered at random (especially relevant for SNN rescaling).
    rp = randperm(size(D, 2));
    [~, d2idx] = sort(row(rp), 'ascend');
    idx = rp(d2idx);
    for j = 1:nk
        nn_class = classes(idx(1:k(j)));
        votes = accumarray(nn_class(:), 1, [C, 1]);
        winners = find(votes == max(votes));
        if numel(winners) > 1
            % Tie: use the class of the single nearest neighbor. NOTE: this
            % class is not necessarily one of the tied majority classes.
            predicted = nn_class(1);
        else
            % Majority vote.
            predicted = winners;
        end
        corr(i, j) = (predicted == seed_class);
        cmat(seed_class, predicted, j) = cmat(seed_class, predicted, j) + 1;
    end
end

% Derive accuracy from integer counts instead of repeatedly adding 1/n,
% which accumulates floating-point rounding error.
if n > 0
    acc = (sum(corr, 1) / n).';
else
    acc = zeros(nk, 1);
end
end