Skip to content

Commit

Permalink
refactor: update pydantic and typer version.
Browse files Browse the repository at this point in the history
  • Loading branch information
Madson Luiz Dantas Dias (UFC) committed Mar 18, 2024
1 parent 721f8ef commit 5d6384e
Show file tree
Hide file tree
Showing 3 changed files with 2,286 additions and 1,917 deletions.
14 changes: 6 additions & 8 deletions fcmeans/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from joblib import Parallel, delayed
import numpy as np
from numpy.typing import NDArray
from pydantic import BaseModel, Extra, Field, validate_arguments
from pydantic import BaseModel, ConfigDict, Field, validate_call
import tqdm


Expand Down Expand Up @@ -39,22 +39,20 @@ class FCM(BaseModel):
ReferenceError: If called without the model being trained
"""

class Config:
extra = Extra.allow
arbitrary_types_allowed = True
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

n_clusters: int = Field(5, ge=1)
max_iter: int = Field(150, ge=1, le=1000)
m: float = Field(2.0, ge=1.0)
error: float = Field(1e-5, ge=1e-9)
random_state: Optional[int] = None
trained: bool = Field(False, const=True)
trained: bool = False
n_jobs: int = Field(1, ge=1)
verbose: Optional[bool] = False
distance: Optional[Union[DistanceOptions, Callable]] = DistanceOptions.euclidean
distance_params: Optional[Dict] = {}

@validate_arguments(config=dict(arbitrary_types_allowed=True))
@validate_call(config=dict(arbitrary_types_allowed=True))
def fit(self, X: NDArray) -> None:
"""Train the fuzzy-c-means model
Expand All @@ -76,7 +74,7 @@ def fit(self, X: NDArray) -> None:
break
self.trained = True

@validate_arguments(config=dict(arbitrary_types_allowed=True))
@validate_call(config=dict(arbitrary_types_allowed=True))
def soft_predict(self, X: NDArray) -> NDArray:
"""Soft predict of FCM
Expand All @@ -95,7 +93,7 @@ def soft_predict(self, X: NDArray) -> NDArray:
u_dist = np.vstack(u_dist).T
return 1 / u_dist

@validate_arguments(config=dict(arbitrary_types_allowed=True))
@validate_call(config=dict(arbitrary_types_allowed=True))
def predict(self, X: NDArray) -> NDArray:
"""Predict the closest cluster each sample in X belongs to.
Expand Down
Loading

0 comments on commit 5d6384e

Please sign in to comment.