From f085fd84ae458b911938c3c97bfec2455a295a98 Mon Sep 17 00:00:00 2001 From: "Eric S. Tellez" Date: Mon, 2 May 2022 12:38:20 -0500 Subject: [PATCH] improves support for new database abstractions --- Project.toml | 4 ++-- src/knc.jl | 2 +- src/kncproto.jl | 13 ++++++------- test/loaddata.jl | 4 ++-- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 8bc30b4..572e71c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KNearestCenters" uuid = "4dca28ae-43b8-11eb-1f5e-d55054101997" authors = ["Eric S. Tellez"] -version = "0.6.2" +version = "0.7.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -17,7 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] CategoricalArrays = "0.8, 0.9, 0.10" -KCenters = "0.5.0" +KCenters = "0.6" MLDataUtils = "0.5" Parameters = "0.12" SearchModels = "0.3" diff --git a/src/knc.jl b/src/knc.jl index 2a6a91c..dddde43 100644 --- a/src/knc.jl +++ b/src/knc.jl @@ -54,7 +54,7 @@ end Creates a Knc classifier using the given configuration and data. """ -function Knc(config::KncConfig, X, y::CategoricalArray; verbose=true) +function Knc(config::KncConfig, X::AbstractDatabase, y::CategoricalArray; verbose=true) # computes a set of #labels centers using labels for clustering D = kcenters(config.kernel.dist, X, y, config.centerselection) @assert length(levels(y)) == length(D.centers) diff --git a/src/kncproto.jl b/src/kncproto.jl index 3ac4365..2985b45 100644 --- a/src/kncproto.jl +++ b/src/kncproto.jl @@ -126,9 +126,8 @@ end Creates a KncProto classifier using the given configuration and data. """ -function KncProto(config::KncProtoConfig, X, y::CategoricalArray; verbose=true) +function KncProto(config::KncProtoConfig, X::AbstractDatabase, y::CategoricalArray; verbose=true) config.ncenters == 0 && error("invalid ncenter $ncenters; ncenters <= -2 or 2 <= ncenters; please use plain Knc otherwise") - X = convert(AbstractDatabase, X) if config.ncenters > 0 # computes a set of ncenters for all dataset verbose && println(stderr, "KncProto> clustering data without knowing labels", config) @@ -144,7 +143,7 @@ function KncProto(config::KncProtoConfig, X, y::CategoricalArray; verbose=true) ncenters = abs(config.ncenters) verbose && println(stderr, "KncProto> clustering data with label division", config) nclasses = length(levels(y)) - centers = VectorDatabase(eltype(X)) + centers = eltype(X)[] dmax = Float32[] class_map = Int32[] nclasses = length(levels(y)) @@ -170,19 +169,19 @@ function KncProto(config::KncProtoConfig, X, y::CategoricalArray; verbose=true) end end - KncProto(config, centers, dmax, class_map, convert(Int32, nclasses), KnnResult(1)) + KncProto(config, VectorDatabase(centers), dmax, class_map, convert(Int32, nclasses), KnnResult(1)) end end function KncProto( config::KncProtoConfig, input_clusters::ClusteringData, - train_X, + train_X::AbstractDatabase, train_y::CategoricalArray; verbose=false ) train_X = convert(AbstractDatabase, train_X) - centers = VectorDatabase(eltype(train_X)) # clusters + centers = eltype(train_X)[] # clusters classes = Int32[] # class mapping between clusters and classes dmax = Float32[] ncenters = length(input_clusters.centers) @@ -243,7 +242,7 @@ function KncProto( end verbose && println(stderr, "finished with $(length(centers)) centers; started with $(length(input_clusters.centers))") - KncProto(config, centers, dmax, classes, convert(Int32, nclasses), KnnResult(1)) + KncProto(config, VectorDatabase(centers), dmax, classes, convert(Int32, nclasses), KnnResult(1)) end """ diff --git a/test/loaddata.jl b/test/loaddata.jl index 68e00ed..75921fc 100644 --- a/test/loaddata.jl +++ b/test/loaddata.jl @@ -1,6 +1,6 @@ # This file is a part of KNearestCenters.jl -using Test +using Test, SimilaritySearch function loadiris() url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" @@ -19,7 +19,7 @@ function loadiris() push!(y, arr[end]) end - X, y + VectorDatabase(X), y end