From 9efca5ba755adf4b0483bc9c5d192fe5336897f1 Mon Sep 17 00:00:00 2001 From: "Eric S. Tellez" Date: Sun, 16 Jan 2022 07:58:58 -0600 Subject: [PATCH] uses the simpler KnnResult api --- Project.toml | 2 +- src/knc.jl | 5 ++--- src/kncproto.jl | 11 +++++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index a6a0f33..8bc30b4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KNearestCenters" uuid = "4dca28ae-43b8-11eb-1f5e-d55054101997" authors = ["Eric S. Tellez"] -version = "0.6.1" +version = "0.6.2" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/knc.jl b/src/knc.jl index 110550d..2a6a91c 100644 --- a/src/knc.jl +++ b/src/knc.jl @@ -63,13 +63,12 @@ end function predict(nc::Knc, x, res::KnnResult=reuse!(nc.res)) C = nc.centers - st = initialstate(res) for i in eachindex(C) d = -evaluate(nc.config.kernel, x, C[i], nc.dmax[i]) - st = push!(res, st, i, d) + push!(res, i, d) end - argmin(res, st) + argmin(res) end function Base.broadcastable(nc::Knc) diff --git a/src/kncproto.jl b/src/kncproto.jl index 818e666..3ac4365 100644 --- a/src/kncproto.jl +++ b/src/kncproto.jl @@ -247,12 +247,12 @@ function KncProto( end """ - most_frequent_label(nc::KncProto, res::KnnResult, st::KnnResultState) + most_frequent_label(nc::KncProto, res::KnnResult) Summary function that computes the label as the most frequent label among labels of the k nearest prototypes (categorical labels) """ -function most_frequent_label(nc::KncProto, res::KnnResult, st::KnnResultState) - c = counts([nc.class_map[id] for id in idview(res, st)], 1:nc.nclasses) +function most_frequent_label(nc::KncProto, res::KnnResult) + c = counts([nc.class_map[id] for id in idview(res)], 1:nc.nclasses) findmax(c)[end] end @@ -273,14 +273,13 @@ Predicts the class of `x` using the label of the `k` nearest centers under the ` """ function predict(nc::KncProto, x, res::KnnResult=reuse!(nc.res)) C = nc.centers - st = initialstate(res) dmax = nc.dmax for i in eachindex(C) s = evaluate(nc.config.kernel, x, C[i], dmax[i]) - st = push!(res, st, i, -s) + push!(res, i, -s) end - most_frequent_label(nc, res, st) + most_frequent_label(nc, res) end function Base.broadcastable(nc::KncProto)