Skip to content

Commit

Permalink
Fix bootstrap fixed random state (#232)
Browse files Browse the repository at this point in the history
* Fix bootstrap fixed random state (closes #231)

Use old-school optional type

Enable colour output in CI for pytest

Fix typo in arg name

Rethink strategy; provide way to fix random seed but still call resample multiple times

Be explicit about number of samples

* Attempt to stratify bootstrap if classification props are present

* Pass `targets_groups` up to Ensemble from member

* use first prop for stratification

---------

Co-authored-by: ppdebreuck <pierre-paul.debreuck@student.uclouvain.be>
  • Loading branch information
ml-evs and ppdebreuck authored Oct 31, 2024
1 parent 8d41f94 commit 53a9a6b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ on:
push:
branches:
- master

env:
# Enable color output for pytest
FORCE_COLOR: true

jobs:

lint:
Expand Down
39 changes: 28 additions & 11 deletions modnet/models/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
bootstrap=True,
models=None,
modnet_models=None,
random_state: Optional[int] = None,
**kwargs,
):
"""
Expand All @@ -55,11 +56,13 @@ def __init__(
n_models: number of inner MODNetModels, each model has the same architecture defined by the args nd kwargs.
bootstrap: whether to bootstrap the samples for each inner MODNet fit.
models: List of user provided MODNetModels. Enables to have different architectures. n_models is discarded in this case.
random_state: fix a random state for use with this model.
modnet_model: Deprecated. Same argument as models. For backward compatibility only.
**kwargs: See MODNetModel
"""
self.__modnet_version__ = __version__
self.bootstrap = bootstrap
self.random_state = random_state
if modnet_models is not None and models is None:
models = modnet_models
if models is None:
Expand All @@ -74,6 +77,7 @@ def __init__(
self.targets = self.models[0].targets
self.weights = self.models[0].weights
self.num_classes = self.models[0].num_classes
self.targets_groups = self.models[0].targets_groups
self.out_act = self.models[0].out_act

def fit(
Expand All @@ -89,18 +93,31 @@ def fit(

if self.bootstrap:
LOG.info("Generating bootstrap data...")
if self.random_state is None:
random_state = self.n_models * [None]
else:
random_state = np.arange(self.n_models) + self.random_state

# Loop over all targets and check if any involve classification, if so, stratify
stratify = None
for prop in self.targets_groups:
if self.num_classes[prop[0]] >= 2: # Classification
stratify = training_data.df_targets[prop[0]].array
break
train_indices = [
resample(
np.arange(len(training_data.df_targets)),
replace=True,
n_samples=len(training_data.df_targets),
random_state=random_state[i],
stratify=stratify,
)
for i in range(self.n_models)
]

train_datas = [
training_data.split(
(
resample(
np.arange(len(training_data.df_targets)),
replace=True,
random_state=2943,
),
[],
)
)[0]
for _ in range(self.n_models)
training_data.split((train_indices[i], []))[0]
for i in range(self.n_models)
]
else:
train_datas = [training_data for _ in range(self.n_models)]
Expand Down

0 comments on commit 53a9a6b

Please sign in to comment.