Skip to content

Commit

Permalink
Support alpha=0 in Ridge
Browse files Browse the repository at this point in the history
When `alpha = 0`, `Ridge` is equivalent to a `LinearRegression`.
Previously we checked if alpha was positive, scikit-learn instead
requires that alpha is non-negative. This updates the check and adds a
test.
  • Loading branch information
jcrist committed Jan 17, 2025
1 parent d95cae5 commit a507869
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
7 changes: 3 additions & 4 deletions python/cuml/cuml/linear_model/ridge.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -241,9 +241,8 @@ class Ridge(UniversalBase,
self.intercept_value = 0.0

def _check_alpha(self, alpha):
if alpha <= 0.0:
msg = "alpha value has to be positive"
raise TypeError(msg.format(alpha))
if alpha < 0.0:
raise ValueError(f"alpha must be non-negative, got {alpha}")

def _get_algorithm_int(self, algorithm):
if self.solver not in ['svd', 'eig', 'cd']:
Expand Down
11 changes: 11 additions & 0 deletions python/cuml/cuml/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ def test_ridge_regression_model(datatype, algorithm, nrows, column_info):
)


def test_ridge_and_least_squares_equal_when_alpha_is_0():
X, y = make_regression(n_samples=5, n_features=4, random_state=0)

ridge = cuRidge(alpha=0.0, fit_intercept=False)
ols = cuLinearRegression(fit_intercept=False)

ridge.fit(X, y)
ols.fit(X, y)
assert array_equal(ridge.coef_, ols.coef_)


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("algorithm", ["eig", "svd"])
@pytest.mark.parametrize(
Expand Down

0 comments on commit a507869

Please sign in to comment.