Skip to content

Commit

Permalink
Skip dispatching to GPU for unimplemented metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
betatim committed Jan 14, 2025
1 parent 47bac70 commit 2788304
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
16 changes: 16 additions & 0 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,22 @@ class UMAP(UniversalBase,
_cpu_estimator_import_path = 'umap.UMAP'
embedding_ = CumlArrayDescriptor(order='C')

_hyperparam_interop_translator = {
"metric": {
"sokalsneath": "NotImplemented",
"rogerstanimoto": "NotImplemented",
"sokalmichener": "NotImplemented",
"yule": "NotImplemented",
"ll_dirichlet": "NotImplemented",
"russelrao": "NotImplemented",
"kulsinski": "NotImplemented",
"dice": "NotImplemented",
"wminkowski": "NotImplemented",
"mahalanobis": "NotImplemented",
"haversine": "NotImplemented",
}
}

@device_interop_preparation
def __init__(self, *,
n_neighbors=15,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,10 +46,32 @@ def test_umap_min_dist(manifold_data, min_dist):


@pytest.mark.parametrize(
"metric", ["euclidean", "manhattan", "chebyshev", "cosine"]
"metric",
[
"euclidean",
"manhattan",
"chebyshev",
"cosine",
# These metrics are currently not supported in cuml,
# we test them here to make sure no exception is raised
"sokalsneath",
"rogerstanimoto",
"sokalmichener",
"yule",
"ll_dirichlet",
"russellrao",
"kulsinski",
"dice",
"wminkowski",
"mahalanobis",
"haversine",
],
)
def test_umap_metric(manifold_data, metric):
X = manifold_data
# haversine only works for 2D data
if metric == "haversine":
X = X[:, :2]
umap = UMAP(metric=metric, random_state=42)
X_embedded = umap.fit_transform(X)
trust = trustworthiness(X, X_embedded, n_neighbors=5)
Expand Down

0 comments on commit 2788304

Please sign in to comment.