From 38df9167e3fa21e66064a620c2988eb078031a2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 20:55:58 +0000 Subject: [PATCH] [pre-commit.ci] Add auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- odyssey/data/processor.py | 12 ++++++++---- odyssey/models/ehr_mamba/mamba_utils.py | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/odyssey/data/processor.py b/odyssey/data/processor.py index 9564b1e..e4c7d50 100644 --- a/odyssey/data/processor.py +++ b/odyssey/data/processor.py @@ -529,13 +529,17 @@ def get_pretrain_test_split( test_size=test_size, ) else: - test_patients = dataset.sample(n=int(test_size * len(dataset)), random_state=seed) + test_patients = dataset.sample( + n=int(test_size * len(dataset)), random_state=seed + ) test_ids = test_patients["patient_id"].tolist() - pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)]["patient_id"].tolist() - + pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)][ + "patient_id" + ].tolist() + random.seed(seed) random.shuffle(pretrain_ids) - + return pretrain_ids, test_ids diff --git a/odyssey/models/ehr_mamba/mamba_utils.py b/odyssey/models/ehr_mamba/mamba_utils.py index de227f6..bcd3645 100644 --- a/odyssey/models/ehr_mamba/mamba_utils.py +++ b/odyssey/models/ehr_mamba/mamba_utils.py @@ -112,7 +112,8 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - Returns: + Returns + ------- """ if inputs_embeds is not None: sequence_outputs = self.backbone( @@ -222,7 +223,8 @@ def forward( task_indices (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Task indices to specify which classification head to use for each example in the batch. - Returns: + Returns + ------- """ if inputs_embeds is not None: sequence_outputs = self.backbone(