diff --git a/python/cuml/cuml/linear_model/logistic_regression.pyx b/python/cuml/cuml/linear_model/logistic_regression.pyx index 164821a5bd..90b275fc8e 100644 --- a/python/cuml/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/cuml/linear_model/logistic_regression.pyx @@ -211,61 +211,20 @@ class LogisticRegression(UniversalBase, verbose=verbose, output_type=output_type) - if penalty not in supported_penalties: - raise ValueError("`penalty` " + str(penalty) + " not supported.") - - if solver not in supported_solvers: - raise ValueError("Only quasi-newton `qn` solver is " - " supported, not %s" % solver) self.solver = solver - self.C = C - - if penalty == "none": - warnings.warn( - "The 'none' option was deprecated in version 24.06, and will " - "be removed in 25.08. Use None instead.", - FutureWarning - ) - penalty = None self.penalty = penalty - self.tol = tol self.fit_intercept = fit_intercept self.max_iter = max_iter self.linesearch_max_iter = linesearch_max_iter - self.l1_ratio = None - if self.penalty == "elasticnet": - if l1_ratio is None: - raise ValueError( - "l1_ratio has to be specified for" "loss='elasticnet'" - ) - if l1_ratio < 0.0 or l1_ratio > 1.0: - msg = "l1_ratio value has to be between 0.0 and 1.0" - raise ValueError(msg.format(l1_ratio)) - self.l1_ratio = l1_ratio - - l1_strength, l2_strength = self._get_qn_params() - - loss = "sigmoid" + self.l1_ratio = l1_ratio if class_weight is not None: self._build_class_weights(class_weight) else: self.class_weight = None - self.solver_model = QN( - loss=loss, - fit_intercept=self.fit_intercept, - l1_strength=l1_strength, - l2_strength=l2_strength, - max_iter=self.max_iter, - linesearch_max_iter=self.linesearch_max_iter, - tol=self.tol, - verbose=self.verbose, - handle=self.handle, - ) - if logger.should_log_for(logger.level_debug): self.verb_prefix = "CY::" logger.debug(self.verb_prefix + "Estimator parameters:") @@ -273,6 +232,33 @@ class LogisticRegression(UniversalBase, else: self.verb_prefix = "" + def _validate_params(self): + if self.penalty not in supported_penalties: + raise ValueError("`penalty` " + str(self.penalty) + " not supported.") + + if self.solver not in supported_solvers: + raise ValueError("Only quasi-newton `qn` solver is " + " supported, not %s" % self.solver) + + if self.penalty == "none": + warnings.warn( + "The 'none' option was deprecated in version 24.06, and will " + "be removed in 25.08. Use None instead.", + FutureWarning + ) + penalty = None + else: + penalty = self.penalty + + if penalty == "elasticnet": + if self.l1_ratio is None: + raise ValueError( + "l1_ratio has to be specified for" "loss='elasticnet'" + ) + if self.l1_ratio < 0.0 or self.l1_ratio > 1.0: + msg = "l1_ratio value has to be between 0.0 and 1.0" + raise ValueError(msg.format(self.l1_ratio)) + @generate_docstring(X='dense_sparse') @cuml.internals.api_base_return_any(set_output_dtype=True) @enable_device_interop @@ -282,6 +268,22 @@ class LogisticRegression(UniversalBase, Fit the model with X and y. """ + self._validate_params() + + l1_strength, l2_strength = self._get_qn_params() + self.solver_model = QN( + loss="sigmoid", + fit_intercept=self.fit_intercept, + l1_strength=l1_strength, + l2_strength=l2_strength, + max_iter=self.max_iter, + linesearch_max_iter=self.linesearch_max_iter, + tol=self.tol, + verbose=self.verbose, + handle=self.handle, + output_type=self.output_type, + ) + self.n_features_in_ = X.shape[1] if X.ndim == 2 else 1 if hasattr(X, 'index'): self.feature_names_in_ = X.index