Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX Move logistic regression parameter validation to fit() #6109

Draft
wants to merge 2 commits into
base: branch-24.12
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions python/cuml/cuml/linear_model/logistic_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -211,68 +211,54 @@ class LogisticRegression(UniversalBase,
verbose=verbose,
output_type=output_type)

if penalty not in supported_penalties:
raise ValueError("`penalty` " + str(penalty) + " not supported.")

if solver not in supported_solvers:
raise ValueError("Only quasi-newton `qn` solver is "
" supported, not %s" % solver)
self.solver = solver

self.C = C

if penalty == "none":
warnings.warn(
"The 'none' option was deprecated in version 24.06, and will "
"be removed in 25.08. Use None instead.",
FutureWarning
)
penalty = None
self.penalty = penalty

self.tol = tol
self.fit_intercept = fit_intercept
self.max_iter = max_iter
self.linesearch_max_iter = linesearch_max_iter
self.l1_ratio = None
if self.penalty == "elasticnet":
if l1_ratio is None:
raise ValueError(
"l1_ratio has to be specified for" "loss='elasticnet'"
)
if l1_ratio < 0.0 or l1_ratio > 1.0:
msg = "l1_ratio value has to be between 0.0 and 1.0"
raise ValueError(msg.format(l1_ratio))
self.l1_ratio = l1_ratio

l1_strength, l2_strength = self._get_qn_params()

loss = "sigmoid"
self.l1_ratio = l1_ratio

if class_weight is not None:
self._build_class_weights(class_weight)
else:
self.class_weight = None

self.solver_model = QN(
loss=loss,
fit_intercept=self.fit_intercept,
l1_strength=l1_strength,
l2_strength=l2_strength,
max_iter=self.max_iter,
linesearch_max_iter=self.linesearch_max_iter,
tol=self.tol,
verbose=self.verbose,
handle=self.handle,
)

if logger.should_log_for(logger.level_debug):
self.verb_prefix = "CY::"
logger.debug(self.verb_prefix + "Estimator parameters:")
logger.debug(pprint.pformat(self.__dict__))
else:
self.verb_prefix = ""

def _validate_params(self):
    """Validate hyperparameters prior to fitting.

    Checks ``penalty``, ``solver`` and ``l1_ratio`` and emits the
    deprecation warning for ``penalty="none"``.

    Raises
    ------
    ValueError
        If ``penalty`` or ``solver`` is unsupported, or if
        ``l1_ratio`` is missing or outside ``[0, 1]`` when
        ``penalty="elasticnet"``.
    """
    if self.penalty not in supported_penalties:
        raise ValueError(
            "`penalty` " + str(self.penalty) + " not supported."
        )

    if self.solver not in supported_solvers:
        # Bug fix: the adjacent literals previously joined as
        # "is  supported" (double space).
        raise ValueError(
            "Only quasi-newton `qn` solver is supported, not %s"
            % self.solver
        )

    if self.penalty == "none":
        # NOTE(review): since validation moved from __init__ to fit(),
        # this warning now fires on every fit() call rather than once
        # at construction — confirm that is intended.
        warnings.warn(
            "The 'none' option was deprecated in version 24.06, and will "
            "be removed in 25.08. Use None instead.",
            FutureWarning
        )
        # Normalized locally only; self.penalty is deliberately left
        # untouched so get_params()/set_params() round-trip the value
        # the user supplied.
        # NOTE(review): _get_qn_params() presumably reads self.penalty —
        # verify it also treats the string "none" as "no penalty".
        penalty = None
    else:
        penalty = self.penalty

    if penalty == "elasticnet":
        if self.l1_ratio is None:
            # Bug fix: the two adjacent string literals previously
            # concatenated without a space ("forloss='elasticnet'").
            raise ValueError(
                "l1_ratio has to be specified for loss='elasticnet'"
            )
        if self.l1_ratio < 0.0 or self.l1_ratio > 1.0:
            # Bug fix: the message had no '{}' placeholder, so
            # .format(...) was a silent no-op that dropped the value.
            msg = "l1_ratio value has to be between 0.0 and 1.0, got {}"
            raise ValueError(msg.format(self.l1_ratio))
betatim marked this conversation as resolved.
Show resolved Hide resolved

@generate_docstring(X='dense_sparse')
@cuml.internals.api_base_return_any(set_output_dtype=True)
@enable_device_interop
Expand All @@ -282,6 +268,22 @@ class LogisticRegression(UniversalBase,
Fit the model with X and y.

"""
self._validate_params()

l1_strength, l2_strength = self._get_qn_params()
self.solver_model = QN(
loss="sigmoid",
fit_intercept=self.fit_intercept,
l1_strength=l1_strength,
l2_strength=l2_strength,
max_iter=self.max_iter,
linesearch_max_iter=self.linesearch_max_iter,
tol=self.tol,
verbose=self.verbose,
handle=self.handle,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
handle=self.handle,
handle=self.handle,
output_type=self.output_type

If I'm not mistaken, that's what 99% of the pytest failures are complaining about

output_type=self.output_type,
)

self.n_features_in_ = X.shape[1] if X.ndim == 2 else 1
if hasattr(X, 'index'):
self.feature_names_in_ = X.index
Expand Down
Loading