
Adding seed for reproducibility and sampling methods #344

Open: wants to merge 64 commits into base branch 1.7.0.

Commits (64):
e4d8871
sampling and seed
rwilfong Jul 31, 2024
22b0318
now it runs
stewarthe6 Jul 19, 2024
30ea360
kfold changes
stewarthe6 Jul 19, 2024
dc1f7c4
seed test
rwilfong Jul 31, 2024
7b13967
ruff linter suggestions
rwilfong Jul 31, 2024
6fb1c62
updated kfoldregression
rwilfong Aug 1, 2024
480e5f1
Merge remote-tracking branch 'upstream/1.7.0' into 1.7.0
stewarthe6 Sep 11, 2024
fc24463
added imblearn to pip requirements
stewarthe6 Sep 11, 2024
561c3bb
unpin imblearn
stewarthe6 Sep 11, 2024
49dc67b
Clean up unused random_state or seed parameters or assignments.
stewarthe6 Sep 11, 2024
b41b7d5
fixed merging error
stewarthe6 Sep 11, 2024
b65ba09
Fixed find and replace bug
stewarthe6 Sep 11, 2024
84babd2
make_dc_model does not need random_state or seed arguments
stewarthe6 Sep 11, 2024
ecf23bd
fhnew changes
rwilfong Sep 12, 2024
a821f6a
Changed constructor of ProductionSplitter to call Splitting's init fu…
stewarthe6 Sep 12, 2024
319b2f0
resolving errors
rwilfong Sep 12, 2024
31f3d5f
removed heads
rwilfong Sep 12, 2024
d074f65
removed unused library
rwilfong Sep 12, 2024
b0ecc05
Merge remote-tracking branch 'upstream/1.7.0' into 1.7.0
stewarthe6 Sep 12, 2024
2992bdf
Added more models for seeding test.
stewarthe6 Sep 12, 2024
ccebaed
Fixed seed for GCNModel. Should pass regularly now.
stewarthe6 Sep 12, 2024
dcc4809
Set seed to guarantee results in class_config_delaney_fit_nn_ecfp.json
stewarthe6 Sep 12, 2024
922bf0c
Moved 'test' from suffix to prefix
stewarthe6 Sep 18, 2024
82838d1
Renamed these test files to start with test_ so they're caught by the…
stewarthe6 Sep 19, 2024
4e471cb
Changed MultitaskScaffoldSplit and GeneticAlgorithm to use a Generate…
stewarthe6 Sep 19, 2024
baa5478
Added test for MTSS seed and fixed a few cases where the wrong random …
stewarthe6 Sep 19, 2024
4eb4ee4
renamed this file to match what's in test_seed_splitting.py
stewarthe6 Sep 19, 2024
4588a9d
renamed this to match the test
stewarthe6 Sep 19, 2024
ff58d02
Removed try except blocks in test code. We need to see these errors
stewarthe6 Sep 24, 2024
0028ed7
Added seed to this test so that it passes more consistently
stewarthe6 Sep 24, 2024
0c83b6b
combined_training_data now accounts for synthetic datasets
stewarthe6 Sep 24, 2024
ada3ea8
accept changes
rwilfong Sep 24, 2024
4dd5d99
integrate changes
rwilfong Sep 24, 2024
0a616b2
set uncertainty false for classification test since it is unsupported…
stewarthe6 Sep 24, 2024
16c2a4a
update branch; Merge branch '1.7.0' of https://github.com/rwilfong/AMPL…
rwilfong Sep 25, 2024
c3b1922
updated tests
rwilfong Sep 25, 2024
f2a30a9
resolve errors
rwilfong Sep 25, 2024
410f03d
Added seed to test_balancing_transformer for more consistent outputs
stewarthe6 Sep 25, 2024
f247893
added a test to make sure that multitask problems don't work with SMOTE
stewarthe6 Sep 25, 2024
2e03fef
Used parameter to determine if SMOTE or undersampling is being used
stewarthe6 Sep 25, 2024
b48ed02
Added a seed to this test for more consistent results
stewarthe6 Sep 25, 2024
567264a
Changed balancing transformer to just check to see if the weights cha…
stewarthe6 Sep 26, 2024
627cc20
Set the seed to make sure the number of positive and negative compoun…
stewarthe6 Sep 26, 2024
8decc0e
Removed unnecessary loop and printed out results from the perf_data test
stewarthe6 Sep 30, 2024
317cc29
accumulate_preds ignores the id parameter for SimpleRegressionPerfDat…
stewarthe6 Sep 30, 2024
5055889
the positive and negative counts are inconsistent, instead just check…
stewarthe6 Sep 30, 2024
6d0abbd
Merge branch 'ATOMScience-org:1.7.0' into 1.7.0
stewarthe6 Oct 2, 2024
16d50f8
Undo transformations before calculating mean and std of predictions
stewarthe6 Oct 28, 2024
3e58819
Merge branch '1.7.0' of github.com:rwilfong/AMPL into 1.7.0
stewarthe6 Oct 28, 2024
0280941
Removed pdb imports
stewarthe6 Oct 28, 2024
a4c2b83
Updated help for 'seed' input
stewarthe6 Nov 27, 2024
8e29047
Removed commented out seed
stewarthe6 Nov 27, 2024
268ba05
model_retrain has an option to either keep or discard the saved seed.…
stewarthe6 Nov 27, 2024
17ba026
Pass on keep_seed argument
stewarthe6 Nov 27, 2024
b2a0c5a
Looping through all folds is redundant
stewarthe6 Dec 2, 2024
60ed670
Added option to keep the same random seed when retraining a model. De…
stewarthe6 Dec 2, 2024
c5e634f
Move common functions to integrative_utilities
stewarthe6 Dec 3, 2024
36c38ec
Move common functions to integrative_utilities
stewarthe6 Dec 3, 2024
d11ee2c
deleted unused imports
stewarthe6 Dec 3, 2024
a089d1f
moved params to json files
stewarthe6 Dec 3, 2024
271c502
Prevent divide by zero case if the model never learns
stewarthe6 Dec 3, 2024
48635cb
Moved pandas import over to integrative_utilities
stewarthe6 Dec 4, 2024
0c67471
Added a seed here for reproducibility
stewarthe6 Dec 4, 2024
524d804
Testing SMOTE and balancing transformer
stewarthe6 Dec 5, 2024
32 changes: 32 additions & 0 deletions atomsci/ddm/docs/PARAMETERS.md
@@ -276,6 +276,14 @@ The AMPL pipeline contains many parameters and options to fit models and make predictions
|*Description:*|True/False flag for setting verbosity|
|*Default:*|FALSE|
|*Type:*|Bool|

- **seed**

|||
|-|-|
|*Description:*|Seed used to initialize the random number generator so that results are reproducible. Default is None, in which case a random seed is generated.|
|*Default:*|None|
|*Type:*|int|
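
For context, a fixed seed would typically be supplied through the parameter set used to build a pipeline. A minimal sketch, assuming AMPL's usual `parameter_parser.wrapper` entry point (dataset values are placeholders):

```python
from atomsci.ddm.pipeline import parameter_parser as parse
from atomsci.ddm.pipeline import model_pipeline as mp

# Placeholder config: dataset_key, response_cols, and splitter are illustrative only.
config = {
    'dataset_key': 'my_dataset.csv',
    'response_cols': 'activity',
    'prediction_type': 'classification',
    'splitter': 'scaffold',
    'seed': 42,  # fixed seed so the split and model initialization are reproducible
}
params = parse.wrapper(config)
pipeline = mp.ModelPipeline(params)
pipeline.train_model()
print(pipeline.seed)  # the seed actually used; also written into the model metadata
```

If `seed` is omitted, a random seed is drawn and logged, so a run can still be reproduced after the fact.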

- **production**

@@ -529,6 +537,30 @@ the model will train for max_epochs regardless of validation error.|
|*Default:*|scaffold|
|*Type:*|str|

- **sampling_method**

|||
|-|-|
|*Description:*|The sampling method for addressing class imbalance in classification datasets. Options are 'undersampling' and 'SMOTE'.|
|*Default:*|None|
|*Type:*|str|

- **sampling_ratio**

|||
|-|-|
|*Description:*|The desired ratio of the minority class to the majority class after sampling. Accepts a strategy name as a string (e.g., 'minority', 'not minority') or a float ratio given as a string (e.g., '0.2', '1.0').|
|*Default:*|auto|
|*Type:*|str|

- **sampling_k_neighbors**

|||
|-|-|
|*Description:*|The number of nearest neighbors to consider when generating synthetic samples (e.g., 5, 7, 9). Used only with the SMOTE sampling method.|
|*Default:*|5|
|*Type:*|int|
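
Taken together, the three sampling parameters might be set as in this sketch (values are placeholders):

```python
# Sketch: parameters enabling SMOTE for an imbalanced classification dataset.
sampling_params = {
    'prediction_type': 'classification',  # sampling applies to classification only
    'sampling_method': 'SMOTE',           # or 'undersampling'
    'sampling_ratio': '0.5',              # desired minority:majority ratio after resampling
    'sampling_k_neighbors': 5,            # SMOTE only; ignored for undersampling
}
```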

- **mtss\_num\_super\_scaffolds**

|||
9 changes: 4 additions & 5 deletions atomsci/ddm/pipeline/model_datasets.py
@@ -16,7 +16,6 @@
import traceback
import sys

from atomsci.ddm.pipeline import random_seed as rs
from collections import defaultdict


@@ -438,7 +437,7 @@ def get_dataset_tasks(self, dset_df):
return self.tasks is not None

# ****************************************************************************************
def split_dataset(self):
def split_dataset(self, random_state=None, seed=None):
"""Splits the dataset into paired training/validation and test subsets, according to the split strategy
selected by the model params. For traditional train/valid/test splits, there is only one training/validation
pair. For k-fold cross-validation splits, there are k different train/valid pairs; the validation sets are
@@ -457,7 +456,7 @@ def split_dataset(self):

# Create object to delegate splitting to.
if self.splitting is None:
self.splitting = split.create_splitting(self.params)
self.splitting = split.create_splitting(self.params, random_state=random_state, seed=seed)
self.train_valid_dsets, self.test_dset, self.train_valid_attr, self.test_attr = \
self.splitting.split_dataset(self.dataset, self.attr, self.params.smiles_col)
if self.train_valid_dsets is None:
@@ -568,7 +567,7 @@ def create_dataset_split_table(self):
return split_df

# ****************************************************************************************
def load_presplit_dataset(self, directory=None):
def load_presplit_dataset(self, directory=None, random_state=None, seed=None):
"""Loads a table of compound IDs assigned to split subsets, and uses them to split
the currently loaded featurized dataset.

@@ -595,7 +594,7 @@ def load_presplit_dataset(self, directory=None):
"""

# Load the split table from the datastore or filesystem
self.splitting = split.create_splitting(self.params)
self.splitting = split.create_splitting(self.params, random_state=random_state, seed=seed)

try:
split_df, split_kv = self.load_dataset_split_table(directory)
70 changes: 63 additions & 7 deletions atomsci/ddm/pipeline/model_pipeline.py
@@ -32,6 +32,8 @@
from atomsci.ddm.pipeline import parameter_parser as parse
from atomsci.ddm.pipeline import model_tracker as trkr
from atomsci.ddm.pipeline import transformations as trans
from atomsci.ddm.pipeline import random_seed as rs
from atomsci.ddm.pipeline import sampling as sample

logging.basicConfig(format='%(asctime)-15s %(message)s')

@@ -154,7 +156,7 @@ class ModelPipeline:
data (ModelDataset object): A data object that featurizes and splits the dataset
"""

def __init__(self, params, ds_client=None, mlmt_client=None):
def __init__(self, params, ds_client=None, mlmt_client=None, random_state=None, seed=None):
"""Initializes ModelPipeline object.

Args:
@@ -188,6 +190,23 @@ def __init__(self, params, ds_client=None, mlmt_client=None):
self.log = logging.getLogger('ATOM')
self.run_mode = 'training' # default, can be overridden later
self.start_time = time.time()

# initialize seed
if seed is None:
seed = getattr(params, 'seed', None)
self.random_gen = rs.RandomStateGenerator(params, seed)
self.seed = self.random_gen.get_seed()
else:
# pass the seed into the RandomStateGenerator
self.random_gen = rs.RandomStateGenerator(seed)
self.seed = self.random_gen.get_seed()

if random_state is None:
self.random_state = self.random_gen.get_random_state()
else:
self.random_state = random_state
# log the seed used
self.log.info('Initiating ModelPipeline with seed {}'.format(self.seed))

# Default dataset_name parameter from dataset_key
if params.dataset_name is None:
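
The `RandomStateGenerator` class used in `__init__` (imported above as `rs`) is not shown in this diff. A minimal sketch consistent with the `get_seed()`/`get_random_state()` calls above, assuming a numpy-backed implementation, might be:

```python
import numpy as np

class RandomStateGenerator:
    """Hypothetical sketch; the real class lives in atomsci/ddm/pipeline/random_seed.py
    and may differ. It only mirrors the behavior ModelPipeline.__init__ relies on."""

    def __init__(self, params=None, seed=None):
        # Tolerate being constructed as (params, seed) or just (seed), as above.
        if seed is None and isinstance(params, (int, np.integer)):
            params, seed = None, params
        if seed is None:
            # No seed supplied: draw one anyway so it can be logged and saved.
            seed = np.random.randint(0, 2**31 - 1)
        self.seed = int(seed)
        self.random_state = np.random.RandomState(self.seed)

    def get_seed(self):
        return self.seed

    def get_random_state(self):
        return self.random_state
```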
@@ -237,7 +256,7 @@

# ****************************************************************************************

def load_featurize_data(self, params=None):
def load_featurize_data(self, params=None, random_state=None, seed=None):
"""Loads the dataset from the datastore or the file system and featurizes it. If we are training
a new model, split the dataset into training, validation and test sets.

@@ -248,6 +267,7 @@ def load_featurize_data(self, params=None):
Args:
params (Namespace): Optional set of parameters to be used for featurization; by default this function
uses the parameters used when the pipeline was created.
seed (int): Optional seed for reproducibility

Side effects:
Sets the following attributes of the ModelPipeline
@@ -266,10 +286,13 @@
self.log.info('Training in production mode. Ignoring '
'previous split and creating production split. '
'Production split will not be saved.')
self.data.split_dataset()
elif not (params.previously_split and self.data.load_presplit_dataset()):
self.data.split_dataset()
self.data.split_dataset(random_state=self.random_state, seed=self.seed)
elif not (params.previously_split and self.data.load_presplit_dataset(random_state=self.random_state, seed=self.seed)):
self.data.split_dataset(random_state=self.random_state, seed=self.seed)
self.data.save_split_dataset()
# write split metadata
self.create_split_metadata()
Review comment (Collaborator): Does the random seed get saved into model metadata? If you wanted to retrain a model 10 times to see how variable the predictions are, would you end up training a model with the same random seed or 10 different ones?

Reply (Collaborator): Good point, I will change that.
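
The commits above ("model_retrain has an option to either keep or discard the saved seed", "Pass on keep_seed argument") address exactly this. A hedged sketch of the two retraining modes, assuming a `train_model_from_tar`-style entry point in `model_retrain` (the exact module path and signature may differ):

```python
from atomsci.ddm.utils import model_retrain as mr  # hypothetical import path

# Reuse the seed saved with the original model, for a reproducible retrain:
mr.train_model_from_tar('model.tar.gz', 'retrain_output', keep_seed=True)

# Draw a fresh seed each time (the default), e.g. to retrain 10 times and
# measure run-to-run variability of the predictions:
for i in range(10):
    mr.train_model_from_tar('model.tar.gz', f'retrain_output_{i}', keep_seed=False)
```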

self.save_split_metadata()
if self.data.params.prediction_type == 'classification':
self.data._validate_classification_dataset()
# We now create transformers after splitting, to allow for the case where the transformer
@@ -282,6 +305,8 @@

if self.run_mode == 'training':
for i, (train, valid) in enumerate(self.data.train_valid_dsets):
if self.data.params.prediction_type == 'classification' and self.params.sampling_method is not None:
train = sample.apply_sampling_method(train, params, random_state=self.random_state, seed=self.seed)
train = self.model_wrapper.transform_dataset(train)
valid = self.model_wrapper.transform_dataset(valid)
self.data.train_valid_dsets[i] = (train, valid)
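
`sampling.apply_sampling_method` is not shown in this diff. Since the PR adds `imblearn` to the pip requirements, a sketch of what it plausibly does for a binary classification training set (assuming `sampling_ratio` has already been parsed to a float or strategy name):

```python
import deepchem as dc
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler

def apply_sampling_method(train, params, random_state=None, seed=None):
    """Hypothetical sketch: resample an imbalanced classification training set."""
    # SMOTE is single-task only; one of the commits above adds a test asserting
    # that multitask problems are rejected.
    X, y = train.X, train.y.ravel()
    if params.sampling_method == 'SMOTE':
        sampler = SMOTE(sampling_strategy=params.sampling_ratio,
                        k_neighbors=params.sampling_k_neighbors,
                        random_state=seed)
    else:  # 'undersampling'
        sampler = RandomUnderSampler(sampling_strategy=params.sampling_ratio,
                                     random_state=seed)
    X_res, y_res = sampler.fit_resample(X, y)
    # Rebuild a DeepChem dataset; SMOTE rows are synthetic, so the original
    # compound IDs and weights cannot be carried over.
    return dc.data.NumpyDataset(X_res, y_res)
```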
@@ -342,6 +367,13 @@ def create_model_metadata(self):
hyperparam_uuid=self.params.hyperparam_uuid,
ampl_version=mu.get_ampl_version()
)
# add in sampling method parameters for documentation/reproducibility
if self.params.sampling_method is not None:
model_params['sampling_method'] = self.params.sampling_method
if self.params.sampling_ratio is not None:
model_params['sampling_ratio'] = self.params.sampling_ratio
if self.params.sampling_k_neighbors is not None:
model_params['sampling_k_neighbors'] = self.params.sampling_k_neighbors

splitting_metadata = self.data.get_split_metadata()
model_metadata = dict(
Expand All @@ -360,6 +392,8 @@ def create_model_metadata(self):
model_metadata[key] = data
for key, data in trans.get_transformer_specific_metadata(self.params).items():
model_metadata[key] = data

model_metadata['seed'] = self.seed

self.model_metadata = model_metadata

@@ -413,6 +447,28 @@ def save_model_metadata(self, retries=5, sleep_sec=60):
trkr.save_model_tarball(self.output_dir, self.params.model_tarball_path)
self.model_wrapper._clean_up_excess_files(self.model_wrapper.model_dir)

# ****************************************************************************************
def create_split_metadata(self):
"""Creates metadata for each split dataset.
It will save the seed used to create the split dataset and relevant parameters."""
self.split_data = dict(
dataset_key = self.params.dataset_key,
id_col = self.params.id_col,
smiles_col = self.params.smiles_col,
response_cols = self.params.response_cols,
seed = self.seed
)
self.splitting_metadata = self.data.get_split_metadata()
self.split_data['splitting_metadata'] = self.splitting_metadata

# ****************************************************************************************
def save_split_metadata(self):
out_file = os.path.join(self.output_dir, 'split_metadata.json')

with open(out_file, 'w') as out:
json.dump(self.split_data, out, sort_keys=True, indent=4, separators=(',', ': '))
out.write("\n")

# ****************************************************************************************
def create_prediction_metadata(self, prediction_results):
"""Initializes a data structure to hold performance metrics from a model run on a new dataset,
@@ -540,7 +596,7 @@ def split_dataset(self, featurization=None):

# ****************************************************************************************

def train_model(self, featurization=None):
def train_model(self, featurization=None, random_state=None, seed=None):
"""Build model described by self.params on the training dataset described by self.params.

Generate predictions for the training, validation, and test datasets, and save the predictions and
@@ -574,7 +630,7 @@ def train_model(self, featurization=None):

## create model wrapper if not split_only
if not self.params.split_only:
self.model_wrapper = model_wrapper.create_model_wrapper(self.params, self.featurization, self.ds_client)
self.model_wrapper = model_wrapper.create_model_wrapper(self.params, self.featurization, self.ds_client, random_state=self.random_state, seed=self.seed)
self.model_wrapper.setup_model_dirs()

self.load_featurize_data()