From 9a71776d68ef168e0e8375b46ee66498cd748f43 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:25:32 -0400 Subject: [PATCH 1/9] Split `ci.yml` into separate actions Split the "unit-tests" action into multiple actions, currently one for each package contained within the `flepiMoP` repo. Also updated checkout from v3 to v4 to address node16 deprecation warnings and swapped ubuntu 20.04 for ubuntu latest. Changed the gempyor ci to not print stdout and exit on first failure. --- .github/workflows/ci.yml | 70 ---------------------------- .github/workflows/flepicommon-ci.yml | 40 ++++++++++++++++ .github/workflows/gempyor-ci.yml | 45 ++++++++++++++++++ .github/workflows/inference-ci.yml | 40 ++++++++++++++++ 4 files changed, 125 insertions(+), 70 deletions(-) delete mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/flepicommon-ci.yml create mode 100644 .github/workflows/gempyor-ci.yml create mode 100644 .github/workflows/inference-ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 29e9a186d..000000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: unit-tests - -on: - workflow_dispatch: - push: - branches: - - main - - dev - pull_request: - branches: - - main - - dev - - breaking-improvments - -jobs: - unit-tests: - runs-on: ubuntu-20.04 - container: - image: hopkinsidd/flepimop:latest-dev - options: --user root - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - lfs: true - - name: Set up Rprofile - run: | - cp build/docker/Docker.Rprofile $HOME/.Rprofile - cp /home/app/.bashrc $HOME/.bashrc - shell: bash - - name: Install the gempyor package - run: | - source /var/python/3.10/virtualenv/bin/activate - python -m pip install --upgrade pip - python -m pip install "flepimop/gempyor_pkg[test]" - shell: bash - - name: Install local R packages - run: Rscript build/local_install.R - shell: bash - - name: Run gempyor tests - run: | - source /var/python/3.10/virtualenv/bin/activate - cd flepimop/gempyor_pkg - pytest -s - shell: bash - - name: Run gempyor-cli integration tests from examples - run: | - source /var/python/3.10/virtualenv/bin/activate - cd examples - pytest -s - shell: bash - - name: Run flepicommon tests - run: | - setwd("flepimop/R_packages/flepicommon") - devtools::test(stop_on_failure=TRUE) - shell: Rscript {0} - - name: Run inference tests - run: | - setwd("flepimop/R_packages/inference") - devtools::test(stop_on_failure=TRUE) - shell: Rscript {0} -# - name: Run integration tests -# env: -# CENSUS_API_KEY: ${{ secrets.CENSUS_API_KEY }} -# run: | -# Rscript build/local_install.R -# cd test -# source /var/python/3.10/virtualenv/bin/activate -# pytest run_tests.py -# shell: bash diff --git a/.github/workflows/flepicommon-ci.yml b/.github/workflows/flepicommon-ci.yml new file mode 100644 index 000000000..b7c0cc5f3 --- /dev/null +++ b/.github/workflows/flepicommon-ci.yml @@ -0,0 +1,40 @@ +name: flepicommon-ci + +on: + workflow_dispatch: + push: + paths: + - flepimop/R_packages/flepicommon/**/* + branches: + - main + - dev + pull_request: + branches: + - main + - dev + - breaking-improvements + +jobs: + unit-tests: + runs-on: ubuntu-latest + container: + image: hopkinsidd/flepimop:latest-dev + options: --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Rprofile + run: | + cp build/docker/Docker.Rprofile $HOME/.Rprofile + cp /home/app/.bashrc $HOME/.bashrc + 
shell: bash + - name: Install local R packages + run: Rscript build/local_install.R + shell: bash + - name: Run flepicommon tests + run: | + setwd("flepimop/R_packages/flepicommon") + devtools::test(stop_on_failure=TRUE) + shell: Rscript {0} diff --git a/.github/workflows/gempyor-ci.yml b/.github/workflows/gempyor-ci.yml new file mode 100644 index 000000000..d9c9b1133 --- /dev/null +++ b/.github/workflows/gempyor-ci.yml @@ -0,0 +1,45 @@ +name: gempyor-ci + +on: + workflow_dispatch: + push: + paths: + - examples/**/* + - flepimop/gempyor_pkg/**/* + branches: + - main + - dev + pull_request: + branches: + - main + - dev + - breaking-improvements + +jobs: + unit-tests: + runs-on: ubuntu-latest + container: + image: hopkinsidd/flepimop:latest-dev + options: --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + - name: Install the gempyor package + run: | + source /var/python/3.10/virtualenv/bin/activate + python -m pip install --upgrade pip + python -m pip install "flepimop/gempyor_pkg[test]" + shell: bash + - name: Run gempyor tests + run: | + source /var/python/3.10/virtualenv/bin/activate + cd flepimop/gempyor_pkg + pytest --exitfirst + shell: bash + - name: Run gempyor-cli integration tests from examples + run: | + source /var/python/3.10/virtualenv/bin/activate + cd examples + pytest --exitfirst diff --git a/.github/workflows/inference-ci.yml b/.github/workflows/inference-ci.yml new file mode 100644 index 000000000..f04e34594 --- /dev/null +++ b/.github/workflows/inference-ci.yml @@ -0,0 +1,40 @@ +name: inference-ci + +on: + workflow_dispatch: + push: + paths: + - flepimop/R_packages/inference/**/* + branches: + - main + - dev + pull_request: + branches: + - main + - dev + - breaking-improvements + +jobs: + unit-tests: + runs-on: ubuntu-latest + container: + image: hopkinsidd/flepimop:latest-dev + options: --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Rprofile + run: | + cp build/docker/Docker.Rprofile $HOME/.Rprofile + cp /home/app/.bashrc $HOME/.bashrc + shell: bash + - name: Install local R packages + run: Rscript build/local_install.R + shell: bash + - name: Run inference tests + run: | + setwd("inference/R_packages/inference") + devtools::test(stop_on_failure=TRUE) + shell: Rscript {0} From e3adcbab3416a4aaf98fc24eb63564f78d8f4a1b Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:44:43 -0400 Subject: [PATCH 2/9] Correct working dir in `inference-ci.yml` Typo in `setwd` call causes error about not being able to change to directory that doesn't exist. --- .github/workflows/inference-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/inference-ci.yml b/.github/workflows/inference-ci.yml index f04e34594..ed6698dde 100644 --- a/.github/workflows/inference-ci.yml +++ b/.github/workflows/inference-ci.yml @@ -35,6 +35,6 @@ jobs: shell: bash - name: Run inference tests run: | - setwd("inference/R_packages/inference") + setwd("flepimop/R_packages/inference") devtools::test(stop_on_failure=TRUE) shell: Rscript {0} From 240b25773a1205400e836811df9cbe7a84a286c3 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:52:17 -0400 Subject: [PATCH 3/9] Set `gempyor` integration tests shell Set the shell to bash so the `source` function is available. 
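`source` is a bash builtin (the POSIX `sh` equivalent is `. <file>`), so steps that activate
the virtualenv need an explicit `shell: bash`. For reference, a sketch of the affected step as
it reads once this patch is applied:

    - name: Run gempyor-cli integration tests from examples
      run: |
        source /var/python/3.10/virtualenv/bin/activate
        cd examples
        pytest --exitfirst
      shell: bash
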
--- .github/workflows/gempyor-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/gempyor-ci.yml b/.github/workflows/gempyor-ci.yml index d9c9b1133..ce26a5750 100644 --- a/.github/workflows/gempyor-ci.yml +++ b/.github/workflows/gempyor-ci.yml @@ -43,3 +43,4 @@ jobs: source /var/python/3.10/virtualenv/bin/activate cd examples pytest --exitfirst + shell: bash From e49109acdd6cfa99fdadaa260c34bfaee6bd6ea9 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:58:29 -0400 Subject: [PATCH 4/9] Limit paths for PRs in actions Add the same path related limits from the on push to on pull_requests as well. --- .github/workflows/flepicommon-ci.yml | 2 ++ .github/workflows/gempyor-ci.yml | 3 +++ .github/workflows/inference-ci.yml | 2 ++ 3 files changed, 7 insertions(+) diff --git a/.github/workflows/flepicommon-ci.yml b/.github/workflows/flepicommon-ci.yml index b7c0cc5f3..da1f07ba6 100644 --- a/.github/workflows/flepicommon-ci.yml +++ b/.github/workflows/flepicommon-ci.yml @@ -9,6 +9,8 @@ on: - main - dev pull_request: + paths: + - flepimop/R_packages/flepicommon/**/* branches: - main - dev diff --git a/.github/workflows/gempyor-ci.yml b/.github/workflows/gempyor-ci.yml index ce26a5750..4f93d7250 100644 --- a/.github/workflows/gempyor-ci.yml +++ b/.github/workflows/gempyor-ci.yml @@ -10,6 +10,9 @@ on: - main - dev pull_request: + paths: + - examples/**/* + - flepimop/gempyor_pkg/**/* branches: - main - dev diff --git a/.github/workflows/inference-ci.yml b/.github/workflows/inference-ci.yml index ed6698dde..c708b8a4b 100644 --- a/.github/workflows/inference-ci.yml +++ b/.github/workflows/inference-ci.yml @@ -9,6 +9,8 @@ on: - main - dev pull_request: + paths: + - flepimop/R_packages/inference/**/* branches: - main - dev From 49475afa8627b49adf5fe36275b2aaf34d11fc4d Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Tue, 6 Aug 2024 17:07:29 -0400 Subject: [PATCH 5/9] Remove 'breaking-improvements' branch Removed the `breaking-improvements` branch from special consideration in GitHub actions. --- .github/workflows/flepicommon-ci.yml | 1 - .github/workflows/gempyor-ci.yml | 1 - .github/workflows/inference-ci.yml | 1 - 3 files changed, 3 deletions(-) diff --git a/.github/workflows/flepicommon-ci.yml b/.github/workflows/flepicommon-ci.yml index da1f07ba6..5314c1b4f 100644 --- a/.github/workflows/flepicommon-ci.yml +++ b/.github/workflows/flepicommon-ci.yml @@ -14,7 +14,6 @@ on: branches: - main - dev - - breaking-improvements jobs: unit-tests: diff --git a/.github/workflows/gempyor-ci.yml b/.github/workflows/gempyor-ci.yml index 4f93d7250..a2cb6e313 100644 --- a/.github/workflows/gempyor-ci.yml +++ b/.github/workflows/gempyor-ci.yml @@ -16,7 +16,6 @@ on: branches: - main - dev - - breaking-improvements jobs: unit-tests: diff --git a/.github/workflows/inference-ci.yml b/.github/workflows/inference-ci.yml index c708b8a4b..2ca3d4897 100644 --- a/.github/workflows/inference-ci.yml +++ b/.github/workflows/inference-ci.yml @@ -14,7 +14,6 @@ on: branches: - main - dev - - breaking-improvements jobs: unit-tests: From 01b2ba4c964debf9afe6739585aba108595b39b4 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 7 Aug 2024 08:41:06 -0400 Subject: [PATCH 6/9] Add code linting GitHub action Added black python formatter to check all python files in `flepiMoP` repo on push/PR to main/dev that edits a `.py` file. 
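For reference, the psf/black action documents its inputs (`src`, `options`) nested under a
`with:` block; a minimal sketch of such a check step, reusing the values from the workflow
below, would be:

    - name: Black Formatter Check
      uses: psf/black@stable
      with:
        options: "--check --quiet"
        src: "."
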
Left the action general enough to be updated for future linting additions. --- .github/workflows/code-linting-ci.yml | 32 +++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/code-linting-ci.yml diff --git a/.github/workflows/code-linting-ci.yml b/.github/workflows/code-linting-ci.yml new file mode 100644 index 000000000..26c3f51be --- /dev/null +++ b/.github/workflows/code-linting-ci.yml @@ -0,0 +1,32 @@ +name: Code Linting + +on: + workflow_dispatch: + push: + paths: + - '**/*.py' + branches: + - main + - dev + pull_request: + paths: + - '**/*.py' + branches: + - main + - dev + +jobs: + unit-tests: + runs-on: ubuntu-latest + container: + image: hopkinsidd/flepimop:latest-dev + options: --user root + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + - name: Black Formatter Check + uses: psf/black@stable + src: "." + options: "--check --quiet" From da76660b8a258c428987010e61e2795a7b3cdbd0 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 7 Aug 2024 08:55:56 -0400 Subject: [PATCH 7/9] Bulk format python files with black Corrected formatting in all `.py` files using black formatter with command `black .` and now passes `black --check .`. --- batch/inference_job_launcher.py | 227 ++++-- batch/scenario_job.py | 15 +- examples/test_cli.py | 69 +- .../src/gempyor/NPI/MultiPeriodModifier.py | 72 +- .../src/gempyor/NPI/SinglePeriodModifier.py | 79 +- .../src/gempyor/NPI/StackedModifier.py | 29 +- .../gempyor_pkg/src/gempyor/NPI/helpers.py | 16 +- flepimop/gempyor_pkg/src/gempyor/calibrate.py | 20 +- .../gempyor_pkg/src/gempyor/compartments.py | 371 ++++++--- .../src/gempyor/config_validator.py | 362 ++++++--- .../gempyor_pkg/src/gempyor/dev/dev_seir.py | 15 +- flepimop/gempyor_pkg/src/gempyor/dev/steps.py | 744 +++++++++++++----- flepimop/gempyor_pkg/src/gempyor/inference.py | 207 +++-- .../src/gempyor/inference_parameter.py | 22 +- .../src/gempyor/initial_conditions.py | 90 ++- flepimop/gempyor_pkg/src/gempyor/logloss.py | 64 +- .../gempyor_pkg/src/gempyor/model_info.py | 94 ++- flepimop/gempyor_pkg/src/gempyor/outcomes.py | 210 +++-- .../gempyor_pkg/src/gempyor/parameters.py | 68 +- .../src/gempyor/postprocess_inference.py | 57 +- flepimop/gempyor_pkg/src/gempyor/seeding.py | 44 +- flepimop/gempyor_pkg/src/gempyor/seir.py | 103 ++- flepimop/gempyor_pkg/src/gempyor/simulate.py | 20 +- .../gempyor_pkg/src/gempyor/statistics.py | 22 +- flepimop/gempyor_pkg/src/gempyor/steps_rk4.py | 157 ++-- .../gempyor_pkg/src/gempyor/steps_source.py | 101 ++- .../src/gempyor/subpopulation_structure.py | 39 +- flepimop/gempyor_pkg/src/gempyor/utils.py | 77 +- .../tests/npi/test_SinglePeriodModifier.py | 10 +- flepimop/gempyor_pkg/tests/npi/test_npis.py | 131 ++- .../tests/outcomes/make_seir_test_file.py | 8 +- .../tests/outcomes/test_outcomes.py | 721 ++++++++++++----- .../gempyor_pkg/tests/seir/dev_new_test0.py | 4 +- .../tests/seir/test_compartments.py | 24 +- flepimop/gempyor_pkg/tests/seir/test_ic.py | 16 +- .../gempyor_pkg/tests/seir/test_parameters.py | 32 +- flepimop/gempyor_pkg/tests/seir/test_seir.py | 231 ++++-- .../tests/seir/test_subpopulationstructure.py | 62 +- .../tests/utils/test_get_log_normal.py | 14 +- .../tests/utils/test_get_truncated_normal.py | 10 +- .../gempyor_pkg/tests/utils/test_utils.py | 64 +- .../gempyor_pkg/tests/utils/test_utils2.py | 8 +- postprocessing/postprocess_auto.py | 92 ++- utilities/clean_s3.py | 4 +- utilities/copy_for_continuation.py | 8 +- 
utilities/prune_by_llik.py | 24 +- utilities/prune_by_llik_and_proj.py | 60 +- 47 files changed, 3634 insertions(+), 1283 deletions(-) diff --git a/batch/inference_job_launcher.py b/batch/inference_job_launcher.py index a8b17ffb0..7b7e2908e 100755 --- a/batch/inference_job_launcher.py +++ b/batch/inference_job_launcher.py @@ -322,7 +322,9 @@ def launch_batch( # TODO: does this really save the config file? if "inference" in config: config["inference"]["iterations_per_slot"] = sims_per_job - if not os.path.exists(pathlib.Path(data_path, config["inference"]["gt_data_path"])): + if not os.path.exists( + pathlib.Path(data_path, config["inference"]["gt_data_path"]) + ): print( f"ERROR: inference.data_path path {pathlib.Path(data_path, config['inference']['gt_data_path'])} does not exist!" ) @@ -403,23 +405,37 @@ def launch_batch( if "scenarios" in config["outcome_modifiers"]: outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"] - handler.launch(job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios) + handler.launch( + job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios + ) # Set job_name as environmental variable so it can be pulled for pushing to git os.environ["job_name"] = job_name # Set run_id as environmental variable so it can be pulled for pushing to git TODO - (rc, txt) = subprocess.getstatusoutput(f"git checkout -b run_{job_name}") # TODO: cd ... + (rc, txt) = subprocess.getstatusoutput( + f"git checkout -b run_{job_name}" + ) # TODO: cd ... print(txt) return rc -def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, num_blocks=None, batch_system=None): +def autodetect_params( + config, + data_path, + *, + num_jobs=None, + sims_per_job=None, + num_blocks=None, + batch_system=None, +): if num_jobs and sims_per_job and num_blocks: return (num_jobs, sims_per_job, num_blocks) if "inference" not in config or "iterations_per_slot" not in config["inference"]: - raise click.UsageError("inference::iterations_per_slot undefined in config, can't autodetect parameters") + raise click.UsageError( + "inference::iterations_per_slot undefined in config, can't autodetect parameters" + ) iterations_per_slot = int(config["inference"]["iterations_per_slot"]) if num_jobs is None: @@ -429,11 +445,17 @@ def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, nu if sims_per_job is None: if num_blocks is not None: sims_per_job = int(math.ceil(iterations_per_slot / num_blocks)) - print(f"Setting number of blocks to {num_blocks} [via num_blocks (-k) argument]") - print(f"Setting sims per job to {sims_per_job} [via {iterations_per_slot} iterations_per_slot in config]") + print( + f"Setting number of blocks to {num_blocks} [via num_blocks (-k) argument]" + ) + print( + f"Setting sims per job to {sims_per_job} [via {iterations_per_slot} iterations_per_slot in config]" + ) else: if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." 
+ ) geodata_fname = pathlib.Path(data_path) / config["subpop_setup"]["geodata"] with open(geodata_fname) as geodata_fp: num_subpops = sum(1 for line in geodata_fp) @@ -458,7 +480,9 @@ def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, nu if num_blocks is None: num_blocks = int(math.ceil(iterations_per_slot / sims_per_job)) - print(f"Setting number of blocks to {num_blocks} [via {iterations_per_slot} iterations_per_slot in config]") + print( + f"Setting number of blocks to {num_blocks} [via {iterations_per_slot} iterations_per_slot in config]" + ) return (num_jobs, sims_per_job, num_blocks) @@ -472,13 +496,17 @@ def get_aws_job_queues(job_queue_prefix): for q in resp["jobQueues"]: queue_name = q["jobQueueName"] if queue_name.startswith(job_queue_prefix): - job_list_resp = batch_client.list_jobs(jobQueue=queue_name, jobStatus="PENDING") + job_list_resp = batch_client.list_jobs( + jobQueue=queue_name, jobStatus="PENDING" + ) queues_with_jobs[queue_name] = len(job_list_resp["jobSummaryList"]) # Return the least-loaded queues first return sorted(queues_with_jobs, key=queues_with_jobs.get) -def aws_countfiles_autodetect_runid(s3_bucket, restart_from_location, restart_from_run_id, num_jobs, strict=False): +def aws_countfiles_autodetect_runid( + s3_bucket, restart_from_location, restart_from_run_id, num_jobs, strict=False +): import boto3 s3 = boto3.resource("s3") @@ -487,15 +515,24 @@ def aws_countfiles_autodetect_runid(s3_bucket, restart_from_location, restart_fr all_files = list(bucket.objects.filter(Prefix=prefix)) all_files = [f.key for f in all_files] if restart_from_run_id is None: - print("WARNING: no --restart_from_run_id specified, autodetecting... please wait querying S3 👀🔎...") + print( + "WARNING: no --restart_from_run_id specified, autodetecting... please wait querying S3 👀🔎..." + ) restart_from_run_id = all_files[0].split("/")[3] - if user_confirmation(question=f"Auto-detected run_id {restart_from_run_id}. Correct ?", default=True): + if user_confirmation( + question=f"Auto-detected run_id {restart_from_run_id}. Correct ?", + default=True, + ): print(f"great, continuing with run_id {restart_from_run_id}...") else: - raise ValueError(f"Abording, please specify --restart_from_run_id manually.") + raise ValueError( + f"Abording, please specify --restart_from_run_id manually." + ) final_llik = [f for f in all_files if ("llik" in f) and ("final" in f)] - if len(final_llik) == 0: # hacky: there might be a bucket with no llik files, e.g if init. + if ( + len(final_llik) == 0 + ): # hacky: there might be a bucket with no llik files, e.g if init. final_llik = [f for f in all_files if ("init" in f) and ("final" in f)] if len(final_llik) != num_jobs: @@ -583,8 +620,12 @@ def build_job_metadata(self, job_name): manifest = {} manifest["cmd"] = " ".join(sys.argv[:]) manifest["job_name"] = job_name - manifest["data_sha"] = subprocess.getoutput("cd {self.data_path}; git rev-parse HEAD") - manifest["flepimop_sha"] = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse HEAD") + manifest["data_sha"] = subprocess.getoutput( + "cd {self.data_path}; git rev-parse HEAD" + ) + manifest["flepimop_sha"] = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse HEAD" + ) # Save the manifest file to S3 with open("manifest.json", "w") as f: @@ -594,17 +635,27 @@ def build_job_metadata(self, job_name): # need these to be uploaded so they can be executed. 
this_file_path = os.path.dirname(os.path.realpath(__file__)) self.save_file( - source=os.path.join(this_file_path, "AWS_inference_runner.sh"), destination=f"{job_name}-runner.sh" + source=os.path.join(this_file_path, "AWS_inference_runner.sh"), + destination=f"{job_name}-runner.sh", ) self.save_file( - source=os.path.join(this_file_path, "AWS_inference_copy.sh"), destination=f"{job_name}-copy.sh" + source=os.path.join(this_file_path, "AWS_inference_copy.sh"), + destination=f"{job_name}-copy.sh", ) tarfile_name = f"{job_name}.tar.gz" self.tar_working_dir(tarfile_name=tarfile_name) - self.save_file(source=tarfile_name, destination=f"{job_name}.tar.gz", remove_source=True) + self.save_file( + source=tarfile_name, + destination=f"{job_name}.tar.gz", + remove_source=True, + ) - self.save_file(source="manifest.json", destination=f"{job_name}/manifest.json", remove_source=True) + self.save_file( + source="manifest.json", + destination=f"{job_name}/manifest.json", + remove_source=True, + ) def tar_working_dir(self, tarfile_name): # this tar file always has the structure: @@ -616,10 +667,14 @@ def tar_working_dir(self, tarfile_name): or q == "covid-dashboard-app" or q == "renv.cache" or q == "sample_data" - or q == "renv" # joseph: I added this to fix a bug, hopefully it doesn't break anything + or q + == "renv" # joseph: I added this to fix a bug, hopefully it doesn't break anything or q.startswith(".") ): - tar.add(os.path.join(self.flepi_path, q), arcname=os.path.join("flepiMoP", q)) + tar.add( + os.path.join(self.flepi_path, q), + arcname=os.path.join("flepiMoP", q), + ) elif q == "sample_data": for r in os.listdir(os.path.join(self.flepi_path, "sample_data")): if r != "united-states-commutes": @@ -629,10 +684,17 @@ def tar_working_dir(self, tarfile_name): ) # tar.add(os.path.join("flepiMoP", "sample_data", r)) for p in os.listdir(self.data_path): - if not (p.startswith(".") or p.endswith("tar.gz") or p in self.outputs or p == "flepiMoP"): + if not ( + p.startswith(".") + or p.endswith("tar.gz") + or p in self.outputs + or p == "flepiMoP" + ): tar.add( p, - filter=lambda x: None if os.path.basename(x.name).startswith(".") else x, + filter=lambda x: ( + None if os.path.basename(x.name).startswith(".") else x + ), ) tar.close() @@ -644,7 +706,9 @@ def save_file(self, source, destination, remove_source=False, prefix=""): import boto3 s3_client = boto3.client("s3") - s3_client.upload_file(source, self.s3_bucket, os.path.join(prefix, destination)) + s3_client.upload_file( + source, self.s3_bucket, os.path.join(prefix, destination) + ) if self.batch_system == "slurm": import shutil @@ -656,7 +720,13 @@ def save_file(self, source, destination, remove_source=False, prefix=""): if remove_source: os.remove(source) - def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios): + def launch( + self, + job_name, + config_filepath, + seir_modifiers_scenarios, + outcome_modifiers_scenarios, + ): s3_results_path = f"s3://{self.s3_bucket}/{job_name}" if self.batch_system == "slurm": @@ -676,7 +746,10 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo ## TODO: check how each of these variables are used downstream base_env_vars = [ {"name": "BATCH_SYSTEM", "value": self.batch_system}, - {"name": "S3_MODEL_PROJECT_PATH", "value": f"s3://{self.s3_bucket}/{job_name}.tar.gz"}, + { + "name": "S3_MODEL_PROJECT_PATH", + "value": f"s3://{self.s3_bucket}/{job_name}.tar.gz", + }, {"name": "DVC_OUTPUTS", "value": " ".join(self.outputs)}, {"name": 
"S3_RESULTS_PATH", "value": s3_results_path}, {"name": "FS_RESULTS_PATH", "value": fs_results_path}, @@ -700,14 +773,22 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo }, {"name": "FLEPI_STOCHASTIC_RUN", "value": str(self.stochastic)}, {"name": "FLEPI_RESET_CHIMERICS", "value": str(self.reset_chimerics)}, - {"name": "FLEPI_MEM_PROFILE", "value": str(os.getenv("FLEPI_MEM_PROFILE", default="FALSE"))}, - {"name": "FLEPI_MEM_PROF_ITERS", "value": str(os.getenv("FLEPI_MEM_PROF_ITERS", default="50"))}, + { + "name": "FLEPI_MEM_PROFILE", + "value": str(os.getenv("FLEPI_MEM_PROFILE", default="FALSE")), + }, + { + "name": "FLEPI_MEM_PROF_ITERS", + "value": str(os.getenv("FLEPI_MEM_PROF_ITERS", default="50")), + }, {"name": "SLACK_CHANNEL", "value": str(self.slack_channel)}, ] with open(config_filepath) as f: config = yaml.full_load(f) - for ctr, (s, d) in enumerate(itertools.product(seir_modifiers_scenarios, outcome_modifiers_scenarios)): + for ctr, (s, d) in enumerate( + itertools.product(seir_modifiers_scenarios, outcome_modifiers_scenarios) + ): cur_job_name = f"{job_name}_{s}_{d}" # Create first job cur_env_vars = base_env_vars.copy() @@ -719,7 +800,12 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": "1"}) cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) if not (self.restart_from_location is None): - cur_env_vars.append({"name": "LAST_JOB_OUTPUT", "value": f"{self.restart_from_location}"}) + cur_env_vars.append( + { + "name": "LAST_JOB_OUTPUT", + "value": f"{self.restart_from_location}", + } + ) cur_env_vars.append( { "name": "OLD_FLEPI_RUN_INDEX", @@ -732,8 +818,18 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo if self.continuation: cur_env_vars.append({"name": "FLEPI_CONTINUATION", "value": f"TRUE"}) - cur_env_vars.append({"name": "FLEPI_CONTINUATION_RUN_ID", "value": f"{self.continuation_run_id}"}) - cur_env_vars.append({"name": "FLEPI_CONTINUATION_LOCATION", "value": f"{self.continuation_location}"}) + cur_env_vars.append( + { + "name": "FLEPI_CONTINUATION_RUN_ID", + "value": f"{self.continuation_run_id}", + } + ) + cur_env_vars.append( + { + "name": "FLEPI_CONTINUATION_LOCATION", + "value": f"{self.continuation_location}", + } + ) cur_env_vars.append( { "name": "FLEPI_CONTINUATION_FTYPE", @@ -743,7 +839,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo # First job: if self.batch_system == "aws": - cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}_block0"}) + cur_env_vars.append( + {"name": "JOB_NAME", "value": f"{cur_job_name}_block0"} + ) runner_script_path = f"s3://{self.s3_bucket}/{job_name}-runner.sh" s3_cp_run_script = f"aws s3 cp {runner_script_path} $PWD/run-flepimop-inference" # line to copy the runner script in wd as ./run-covid-pipeline command = [ @@ -814,7 +912,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo postprod_command, command_name="sbatch postprod", fail_on_fail=True ) postprod_job_id = stdout.decode().split(" ")[-1][:-1] - print(f">>> SUCCESS SCHEDULING POST-PROCESSING JOB. Slurm job id is {postprod_job_id}") + print( + f">>> SUCCESS SCHEDULING POST-PROCESSING JOB. 
Slurm job id is {postprod_job_id}" + ) elif self.batch_system == "local": cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}"}) @@ -831,12 +931,27 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo cur_env_vars = base_env_vars.copy() cur_env_vars.append({"name": "FLEPI_SEIR_SCENARIOS", "value": s}) cur_env_vars.append({"name": "FLEPI_OUTCOME_SCENARIOS", "value": d}) - cur_env_vars.append({"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"}) - cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": f"{block_idx+1}"}) - cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) - cur_env_vars.append({"name": "OLD_FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) - cur_env_vars.append({"name": "LAST_JOB_OUTPUT", "value": f"{s3_results_path}/"}) - cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}_block{block_idx}"}) + cur_env_vars.append( + {"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"} + ) + cur_env_vars.append( + {"name": "FLEPI_BLOCK_INDEX", "value": f"{block_idx+1}"} + ) + cur_env_vars.append( + {"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"} + ) + cur_env_vars.append( + {"name": "OLD_FLEPI_RUN_INDEX", "value": f"{self.run_id}"} + ) + cur_env_vars.append( + {"name": "LAST_JOB_OUTPUT", "value": f"{s3_results_path}/"} + ) + cur_env_vars.append( + { + "name": "JOB_NAME", + "value": f"{cur_job_name}_block{block_idx}", + } + ) cur_job = batch_client.submit_job( jobName=f"{cur_job_name}_block{block_idx}", jobQueue=cur_job_queue, @@ -862,7 +977,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo ] copy_script_path = f"s3://{self.s3_bucket}/{job_name}-copy.sh" - s3_cp_run_script = f"aws s3 cp {copy_script_path} $PWD/run-flepimop-copy" + s3_cp_run_script = ( + f"aws s3 cp {copy_script_path} $PWD/run-flepimop-copy" + ) cp_command = [ "sh", "-c", @@ -895,21 +1012,33 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo em = "" if self.resume_discard_seeding: em = f", discarding seeding results." 
- print(f" >> Resuming from run id is {self.restart_from_run_id} located in {self.restart_from_location}{em}") + print( + f" >> Resuming from run id is {self.restart_from_run_id} located in {self.restart_from_location}{em}" + ) if self.batch_system == "aws": print(f" >> Final output will be: {s3_results_path}/model_output/") elif self.batch_system == "slurm": print(f" >> Final output will be: {fs_results_path}/model_output/") if self.s3_upload: - print(f" >> Final output will be uploaded to {s3_results_path}/model_output/") + print( + f" >> Final output will be uploaded to {s3_results_path}/model_output/" + ) if self.continuation: - print(f" >> Continuing from run id is {self.continuation_run_id} located in {self.continuation_location}") + print( + f" >> Continuing from run id is {self.continuation_run_id} located in {self.continuation_location}" + ) print(f" >> Run id is {self.run_id}") print(f" >> config is {config_filepath.split('/')[-1]}") - flepimop_branch = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse --abbrev-ref HEAD") - data_branch = subprocess.getoutput(f"cd {self.data_path}; git rev-parse --abbrev-ref HEAD") + flepimop_branch = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse --abbrev-ref HEAD" + ) + data_branch = subprocess.getoutput( + f"cd {self.data_path}; git rev-parse --abbrev-ref HEAD" + ) data_hash = subprocess.getoutput(f"cd {self.data_path}; git rev-parse HEAD") - flepimop_hash = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse HEAD") + flepimop_hash = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse HEAD" + ) print(f""" >> FLEPIMOP branch is {flepimop_branch} with hash {flepimop_hash}""") print(f""" >> DATA branch is {data_branch} with hash {data_hash}""") print(f" ------------------------- END -------------------------") diff --git a/batch/scenario_job.py b/batch/scenario_job.py index 1961974eb..4d94460b4 100755 --- a/batch/scenario_job.py +++ b/batch/scenario_job.py @@ -196,13 +196,20 @@ def launch_job_inner( tarfile_name = f"{job_name}.tar.gz" tar = tarfile.open(tarfile_name, "w:gz") for p in os.listdir("."): - if not (p.startswith(".") or p.endswith("tar.gz") or p in dvc_outputs or p == "batch"): + if not ( + p.startswith(".") + or p.endswith("tar.gz") + or p in dvc_outputs + or p == "batch" + ): tar.add(p, filter=lambda x: None if x.name.startswith(".") else x) tar.close() # Upload the tar'd contents of this directory and the runner script to S3 runner_script_name = f"{job_name}-runner.sh" - local_runner_script = os.path.join(os.path.dirname(os.path.realpath(__file__)), "AWS_scenario_runner.sh") + local_runner_script = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "AWS_scenario_runner.sh" + ) s3_client = boto3.client("s3") s3_client.upload_file(local_runner_script, s3_input_bucket, runner_script_name) s3_client.upload_file(tarfile_name, s3_input_bucket, tarfile_name) @@ -246,7 +253,9 @@ def launch_job_inner( containerOverrides=container_overrides, ) - print(f"Batch job with id {resp['jobId']} launched; output will be written to {results_path}") + print( + f"Batch job with id {resp['jobId']} launched; output will be written to {results_path}" + ) def get_dvc_outputs(): diff --git a/examples/test_cli.py b/examples/test_cli.py index 8b1d02982..1a7e726f3 100644 --- a/examples/test_cli.py +++ b/examples/test_cli.py @@ -1,4 +1,3 @@ - from click.testing import CliRunner from gempyor.simulate import simulate import os @@ -6,44 +5,48 @@ # See here to test click application 
https://click.palletsprojects.com/en/8.1.x/testing/ # would be useful to also call the command directly + def test_config_sample_2pop(): - os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'config_sample_2pop.yml']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "config_sample_2pop.yml"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output def test_sample_2pop_interventions_test(): - os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'config_sample_2pop_interventions_test.yml']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") + runner = CliRunner() + result = runner.invoke( + simulate, ["-c", "config_sample_2pop_interventions_test.yml"] + ) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output def test_simple_usa_statelevel(): - os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'simple_usa_statelevel.yml', '-n', '1']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "simple_usa_statelevel.yml", "-n", "1"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output + def test_simple_usa_statelevel(): - os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'simple_usa_statelevel.yml', '-n', '1']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output \ No newline at end of file + os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "simple_usa_statelevel.yml", "-n", "1"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py index ecbbba962..117f0c36b 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py +++ 
b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py @@ -21,7 +21,8 @@ def __init__( name=getattr( npi_config, "key", - (npi_config["scenario"].exists() and npi_config["scenario"].get()) or "unknown", + (npi_config["scenario"].exists() and npi_config["scenario"].get()) + or "unknown", ) ) @@ -32,7 +33,9 @@ def __init__( self.subpops = subpops self.pnames_overlap_operation_sum = pnames_overlap_operation_sum - self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod + self.pnames_overlap_operation_reductionprod = ( + pnames_overlap_operation_reductionprod + ) self.param_name = npi_config["parameter"].as_str().lower() @@ -68,14 +71,22 @@ def __init__( # if parameters are exceeding global start/end dates, index of parameter df will be out of range so check first if self.sanitize: - too_early = min([min(i) for i in self.parameters["start_date"]]) < self.start_date - too_late = max([max(i) for i in self.parameters["end_date"]]) > self.end_date + too_early = ( + min([min(i) for i in self.parameters["start_date"]]) < self.start_date + ) + too_late = ( + max([max(i) for i in self.parameters["end_date"]]) > self.end_date + ) if too_early or too_late: - raise ValueError("at least one period start or end date is not between global dates") + raise ValueError( + "at least one period start or end date is not between global dates" + ) for grp_config in npi_config["groups"]: affected_subpops_grp = self.__get_affected_subpops_grp(grp_config) - for sub_index in range(len(self.parameters["start_date"][affected_subpops_grp[0]])): + for sub_index in range( + len(self.parameters["start_date"][affected_subpops_grp[0]]) + ): period_range = pd.date_range( self.parameters["start_date"][affected_subpops_grp[0]][sub_index], self.parameters["end_date"][affected_subpops_grp[0]][sub_index], @@ -101,7 +112,9 @@ def __checkErrors(self): max_start_date = max([max(i) for i in self.parameters["start_date"]]) min_end_date = min([min(i) for i in self.parameters["end_date"]]) max_end_date = max([max(i) for i in self.parameters["end_date"]]) - if not ((self.start_date <= min_start_date) & (max_start_date <= self.end_date)): + if not ( + (self.start_date <= min_start_date) & (max_start_date <= self.end_date) + ): raise ValueError( f"at least one period_start_date [{min_start_date}, {max_start_date}] is not between global dates [{self.start_date}, {self.end_date}]" ) @@ -111,7 +124,9 @@ def __checkErrors(self): ) if not (self.parameters["start_date"] <= self.parameters["end_date"]).all(): - raise ValueError(f"at least one period_start_date is greater than the corresponding period end date") + raise ValueError( + f"at least one period_start_date is greater than the corresponding period end date" + ) for n in self.affected_subpops: if n not in self.subpops: @@ -135,7 +150,9 @@ def __createFromConfig(self, npi_config): self.affected_subpops = self.__get_affected_subpops(npi_config) - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] dist = npi_config["value"].as_random_distribution() self.parameters["modifier_name"] = self.name self.parameters["parameter"] = self.param_name @@ -153,7 +170,9 @@ def __createFromConfig(self, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups(grp_config, affected_subpops_grp) + this_spatial_group = helpers.get_spatial_groups( + grp_config, affected_subpops_grp + ) 
self.spatial_groups.append(this_spatial_group) # print(self.name, this_spatial_groups) @@ -182,7 +201,9 @@ def __createFromDf(self, loaded_df, npi_config): loaded_df = loaded_df[loaded_df["modifier_name"] == self.name] self.affected_subpops = self.__get_affected_subpops(npi_config) - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] self.parameters["modifier_name"] = self.name self.parameters["parameter"] = self.param_name @@ -194,7 +215,9 @@ def __createFromDf(self, loaded_df, npi_config): if self.sanitize: if len(self.affected_subpops) != len(self.parameters): print(f"loading {self.name} and we got {len(self.parameters)} subpops") - print(f"getting from config that it affects {len(self.affected_subpops)}") + print( + f"getting from config that it affects {len(self.affected_subpops)}" + ) self.spatial_groups = [] for grp_config in npi_config["groups"]: @@ -209,7 +232,9 @@ def __createFromDf(self, loaded_df, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups(grp_config, affected_subpops_grp) + this_spatial_group = helpers.get_spatial_groups( + grp_config, affected_subpops_grp + ) self.spatial_groups.append(this_spatial_group) for subpop in this_spatial_group["ungrouped"]: @@ -227,7 +252,9 @@ def __createFromDf(self, loaded_df, npi_config): for subpop in group: self.parameters.at[subpop, "start_date"] = start_dates self.parameters.at[subpop, "end_date"] = end_dates - self.parameters.at[subpop, "value"] = loaded_df.at[",".join(group), "value"] + self.parameters.at[subpop, "value"] = loaded_df.at[ + ",".join(group), "value" + ] else: dist = npi_config["value"].as_random_distribution() drawn_value = dist(size=1) @@ -258,11 +285,16 @@ def __get_affected_subpops(self, npi_config): affected_subpops_grp += [str(n.get()) for n in grp_config["subpop"]] affected_subpops = set(affected_subpops_grp) if len(affected_subpops) != len(affected_subpops_grp): - raise ValueError(f"In NPI {self.name}, some subpops belong to several groups. This is unsupported.") + raise ValueError( + f"In NPI {self.name}, some subpops belong to several groups. This is unsupported." 
+ ) return affected_subpops def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 @@ -278,7 +310,9 @@ def getReductionToWrite(self): # self.parameters.index is a list of subpops for this_spatial_groups in self.spatial_groups: # spatially ungrouped dataframe - df_ungroup = self.parameters[self.parameters.index.isin(this_spatial_groups["ungrouped"])].copy() + df_ungroup = self.parameters[ + self.parameters.index.isin(this_spatial_groups["ungrouped"]) + ].copy() df_ungroup.index.name = "subpop" df_ungroup["start_date"] = df_ungroup["start_date"].apply( lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) @@ -301,7 +335,9 @@ def getReductionToWrite(self): "start_date": df_group["start_date"].apply( lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) ), - "end_date": df_group["end_date"].apply(lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l])), + "end_date": df_group["end_date"].apply( + lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) + ), "value": df_group["value"], } ).set_index("subpop") diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py index e078ddeba..f2438321c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py @@ -21,7 +21,8 @@ def __init__( name=getattr( npi_config, "key", - (npi_config["scenario"].exists() and npi_config["scenario"].get()) or "unknown", + (npi_config["scenario"].exists() and npi_config["scenario"].get()) + or "unknown", ) ) @@ -29,7 +30,9 @@ def __init__( self.end_date = modinf.tf self.pnames_overlap_operation_sum = pnames_overlap_operation_sum - self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod + self.pnames_overlap_operation_reductionprod = ( + pnames_overlap_operation_reductionprod + ) self.subpops = subpops @@ -60,15 +63,22 @@ def __init__( self.__createFromConfig(npi_config) # if parameters are exceeding global start/end dates, index of parameter df will be out of range so check first - if self.parameters["start_date"].min() < self.start_date or self.parameters["end_date"].max() > self.end_date: - raise ValueError(f"""{self.name} : at least one period start or end date is not between global dates""") + if ( + self.parameters["start_date"].min() < self.start_date + or self.parameters["end_date"].max() > self.end_date + ): + raise ValueError( + f"""{self.name} : at least one period start or end date is not between global dates""" + ) # for index in self.parameters.index: # period_range = pd.date_range(self.parameters["start_date"][index], self.parameters["end_date"][index]) ## This the line that does the work # self.npi_old.loc[index, period_range] = np.tile(self.parameters["value"][index], (len(period_range), 1)).T - period_range = pd.date_range(self.parameters["start_date"].iloc[0], self.parameters["end_date"].iloc[0]) + period_range = pd.date_range( + self.parameters["start_date"].iloc[0], self.parameters["end_date"].iloc[0] + ) self.npi.loc[self.parameters.index, period_range] = np.tile( self.parameters["value"][:], (len(period_range), 1) ).T @@ -80,7 +90,9 @@ def __checkErrors(self): max_start_date = self.parameters["start_date"].max() min_end_date = self.parameters["end_date"].min() max_end_date = 
self.parameters["end_date"].max() - if not ((self.start_date <= min_start_date) & (max_start_date <= self.end_date)): + if not ( + (self.start_date <= min_start_date) & (max_start_date <= self.end_date) + ): raise ValueError( f"at least one period_start_date [{min_start_date}, {max_start_date}] is not between global dates [{self.start_date}, {self.end_date}]" ) @@ -90,7 +102,9 @@ def __checkErrors(self): ) if not (self.parameters["start_date"] <= self.parameters["end_date"]).all(): - raise ValueError(f"at least one period_start_date is greater than the corresponding period end date") + raise ValueError( + f"at least one period_start_date is greater than the corresponding period end date" + ) for n in self.affected_subpops: if n not in self.subpops: @@ -116,19 +130,27 @@ def __createFromConfig(self, npi_config): if npi_config["subpop"].exists() and npi_config["subpop"].get() != "all": self.affected_subpops = {str(n.get()) for n in npi_config["subpop"]} - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] # Create reduction self.dist = npi_config["value"].as_random_distribution() self.parameters["modifier_name"] = self.name self.parameters["start_date"] = ( - npi_config["period_start_date"].as_date() if npi_config["period_start_date"].exists() else self.start_date + npi_config["period_start_date"].as_date() + if npi_config["period_start_date"].exists() + else self.start_date ) self.parameters["end_date"] = ( - npi_config["period_end_date"].as_date() if npi_config["period_end_date"].exists() else self.end_date + npi_config["period_end_date"].as_date() + if npi_config["period_end_date"].exists() + else self.end_date ) self.parameters["parameter"] = self.param_name - self.spatial_groups = helpers.get_spatial_groups(npi_config, list(self.affected_subpops)) + self.spatial_groups = helpers.get_spatial_groups( + npi_config, list(self.affected_subpops) + ) if self.spatial_groups["ungrouped"]: self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = self.dist( size=len(self.spatial_groups["ungrouped"]) @@ -146,17 +168,23 @@ def __createFromDf(self, loaded_df, npi_config): if npi_config["subpop"].exists() and npi_config["subpop"].get() != "all": self.affected_subpops = {str(n.get()) for n in npi_config["subpop"]} - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] self.parameters["modifier_name"] = self.name self.parameters["parameter"] = self.param_name # self.parameters = loaded_df[["modifier_name", "start_date", "end_date", "parameter", "value"]].copy() # dates are picked from config self.parameters["start_date"] = ( - npi_config["period_start_date"].as_date() if npi_config["period_start_date"].exists() else self.start_date + npi_config["period_start_date"].as_date() + if npi_config["period_start_date"].exists() + else self.start_date ) self.parameters["end_date"] = ( - npi_config["period_end_date"].as_date() if npi_config["period_end_date"].exists() else self.end_date + npi_config["period_end_date"].as_date() + if npi_config["period_end_date"].exists() + else self.end_date ) ## This is more legible to me, but if we change it here, we should change it in __createFromConfig as well # if npi_config["period_start_date"].exists(): @@ -175,17 +203,24 @@ def __createFromDf(self, loaded_df, npi_config): # TODO: to be consistent with MTR, we want 
to also draw the values for the subpops # that are not in the loaded_df. - self.spatial_groups = helpers.get_spatial_groups(npi_config, list(self.affected_subpops)) + self.spatial_groups = helpers.get_spatial_groups( + npi_config, list(self.affected_subpops) + ) if self.spatial_groups["ungrouped"]: - self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = loaded_df.loc[ - self.spatial_groups["ungrouped"], "value" - ] + self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = ( + loaded_df.loc[self.spatial_groups["ungrouped"], "value"] + ) if self.spatial_groups["grouped"]: for group in self.spatial_groups["grouped"]: - self.parameters.loc[group, "value"] = loaded_df.loc[",".join(group), "value"] + self.parameters.loc[group, "value"] = loaded_df.loc[ + ",".join(group), "value" + ] def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 @@ -198,7 +233,9 @@ def getReduction(self, param): def getReductionToWrite(self): # spatially ungrouped dataframe - df = self.parameters[self.parameters.index.isin(self.spatial_groups["ungrouped"])].copy() + df = self.parameters[ + self.parameters.index.isin(self.spatial_groups["ungrouped"]) + ].copy() df.index.name = "subpop" df["start_date"] = df["start_date"].astype("str") df["end_date"] = df["end_date"].astype("str") diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py index 489a48fbb..21165b0c8 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py @@ -32,7 +32,9 @@ def __init__( self.end_date = modinf.tf self.pnames_overlap_operation_sum = pnames_overlap_operation_sum - self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod + self.pnames_overlap_operation_reductionprod = ( + pnames_overlap_operation_reductionprod + ) self.subpops = subpops self.param_name = [] @@ -47,7 +49,9 @@ def __init__( if isinstance(scenario, str): settings = modifiers_library.get(scenario) if settings is None: - raise RuntimeError(f"couldn't find scenario in config file [got: {scenario}]") + raise RuntimeError( + f"couldn't find scenario in config file [got: {scenario}]" + ) # via profiling: faster to recreate the confuse view than to fetch+resolve due to confuse isinstance # checks scenario_npi_config = confuse.RootView([settings]) @@ -68,12 +72,16 @@ def __init__( ) new_params = sub_npi.param_name # either a list (if stacked) or a string - new_params = [new_params] if isinstance(new_params, str) else new_params # convert to list + new_params = ( + [new_params] if isinstance(new_params, str) else new_params + ) # convert to list # Add each parameter at first encounter, with a neutral start for new_p in new_params: if new_p not in self.param_name: self.param_name.append(new_p) - if new_p in pnames_overlap_operation_sum: # re.match("^transition_rate [1234567890]+$",new_p): + if ( + new_p in pnames_overlap_operation_sum + ): # re.match("^transition_rate [1234567890]+$",new_p): self.reductions[new_p] = 0 else: # for the reductionprod and product method, the initial neutral is 1 ) self.reductions[new_p] = 1 @@ -81,7 +89,9 @@ def __init__( for param in self.param_name: # Get reduction return a neutral value for this overlap operation if no parameeter exists reduction = 
sub_npi.getReduction(param) - if param in pnames_overlap_operation_sum: # re.match("^transition_rate [1234567890]+$",param): + if ( + param in pnames_overlap_operation_sum + ): # re.match("^transition_rate [1234567890]+$",param): self.reductions[param] += reduction elif param in pnames_overlap_operation_reductionprod: self.reductions[param] *= 1 - reduction @@ -104,7 +114,9 @@ def __init__( self.reduction_params.clear() for param in self.param_name: - if param in pnames_overlap_operation_reductionprod: # re.match("^transition_rate \d+$",param): + if ( + param in pnames_overlap_operation_reductionprod + ): # re.match("^transition_rate \d+$",param): self.reductions[param] = 1 - self.reductions[param] # check that no NPI is called several times, and retourn them @@ -124,7 +136,10 @@ def __checkErrors(self): # ) def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py index f964d4c6e..7a2b4ccc6 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py @@ -12,7 +12,9 @@ def reduce_parameter( if isinstance(modification, pd.DataFrame): modification = modification.T modification.index = pd.to_datetime(modification.index.astype(str)) - modification = modification.resample("1D").ffill().to_numpy() # Type consistency: + modification = ( + modification.resample("1D").ffill().to_numpy() + ) # Type consistency: if method == "reduction_product": return parameter * (1 - modification) elif method == "sum": @@ -43,16 +45,22 @@ def get_spatial_groups(grp_config, affected_subpops: list) -> dict: else: spatial_groups["grouped"] = grp_config["subpop_groups"].get() spatial_groups["ungrouped"] = list( - set(affected_subpops) - set(flatten_list_of_lists(spatial_groups["grouped"])) + set(affected_subpops) + - set(flatten_list_of_lists(spatial_groups["grouped"])) ) # flatten the list of lists of grouped subpops, so we can do some checks flat_grouped_list = flatten_list_of_lists(spatial_groups["grouped"]) # check that all subpops are either grouped or ungrouped if set(flat_grouped_list + spatial_groups["ungrouped"]) != set(affected_subpops): - print("set of grouped and ungrouped subpops", set(flat_grouped_list + spatial_groups["ungrouped"])) + print( + "set of grouped and ungrouped subpops", + set(flat_grouped_list + spatial_groups["ungrouped"]), + ) print("set of affected subpops ", set(affected_subpops)) - raise ValueError(f"The two above sets are differs for for intervention with config \n {grp_config}") + raise ValueError( + f"The two above sets are differs for for intervention with config \n {grp_config}" + ) if len(set(flat_grouped_list + spatial_groups["ungrouped"])) != len( flat_grouped_list + spatial_groups["ungrouped"] ): diff --git a/flepimop/gempyor_pkg/src/gempyor/calibrate.py b/flepimop/gempyor_pkg/src/gempyor/calibrate.py index e5cb287aa..41de8b673 100644 --- a/flepimop/gempyor_pkg/src/gempyor/calibrate.py +++ b/flepimop/gempyor_pkg/src/gempyor/calibrate.py @@ -158,7 +158,9 @@ def calibrate( # TODO here for resume if resume or resume_location is not None: - print("Doing a resume, this only work with the same number of slot and parameters right now") + print( + "Doing a resume, this only work with the same number of slot 
and parameters right now" + ) p0 = None if resume_location is not None: backend = emcee.backends.HDFBackend(resume_location) @@ -188,14 +190,19 @@ def calibrate( backend=backend, moves=moves, ) - state = sampler.run_mcmc(p0, niter, progress=True, skip_initial_state_check=True) + state = sampler.run_mcmc( + p0, niter, progress=True, skip_initial_state_check=True + ) print(f"Done, mean acceptance fraction: {np.mean(sampler.acceptance_fraction):.3f}") # plotting the chain sampler = emcee.backends.HDFBackend(filename, read_only=True) gempyor.postprocess_inference.plot_chains( - inferpar=gempyor_inference.inferpar, sampler_output=sampler, sampled_slots=None, save_to=f"{run_id}_chains.pdf" + inferpar=gempyor_inference.inferpar, + sampler_output=sampler, + sampled_slots=None, + save_to=f"{run_id}_chains.pdf", ) print("EMCEE Run done, doing sampling") @@ -203,11 +210,14 @@ def calibrate( shutil.rmtree(project_path + "model_output/", ignore_errors=True) max_indices = np.argsort(sampler.get_log_prob()[-1, :])[-nsamples:] - samples = sampler.get_chain()[-1, max_indices, :] # the last iteration, for selected slots + samples = sampler.get_chain()[ + -1, max_indices, : + ] # the last iteration, for selected slots gempyor_inference.set_save(True) with multiprocessing.Pool(ncpu) as pool: results = pool.starmap( - gempyor_inference.get_logloss_as_single_number, [(samples[i, :],) for i in range(len(max_indices))] + gempyor_inference.get_logloss_as_single_number, + [(samples[i, :],) for i in range(len(max_indices))], ) # results = [] # for fn in gempyor.utils.list_filenames(folder="model_output/", filters=[run_id, "hosp.parquet"]): diff --git a/flepimop/gempyor_pkg/src/gempyor/compartments.py b/flepimop/gempyor_pkg/src/gempyor/compartments.py index ec87cf7e5..cfbce6b50 100644 --- a/flepimop/gempyor_pkg/src/gempyor/compartments.py +++ b/flepimop/gempyor_pkg/src/gempyor/compartments.py @@ -13,7 +13,13 @@ class Compartments: # Minimal object to be easily picklable for // runs - def __init__(self, seir_config=None, compartments_config=None, compartments_file=None, transitions_file=None): + def __init__( + self, + seir_config=None, + compartments_config=None, + compartments_file=None, + transitions_file=None, + ): self.times_set = 0 ## Something like this is needed for check script: @@ -29,7 +35,7 @@ def __init__(self, seir_config=None, compartments_config=None, compartments_file return def constructFromConfig(self, seir_config, compartment_config): - """ + """ This method is called by the constructor if the compartments are not loaded from a file. It will parse the compartments and transitions from the configuration files. It will populate self.compartments and self.transitions. 
@@ -43,7 +49,7 @@ def __eq__(self, other): ).all().all() def parse_compartments(self, seir_config, compartment_config): - """ Parse the compartments from the configuration file: + """Parse the compartments from the configuration file: seir_config: the configuration file for the SEIR model compartment_config: the configuration file for the compartments Example: if config says: @@ -75,40 +81,57 @@ def parse_compartments(self, seir_config, compartment_config): else: compartment_df = pd.merge(compartment_df, tmp, on="key") compartment_df = compartment_df.drop(["key"], axis=1) - compartment_df["name"] = compartment_df.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1) + compartment_df["name"] = compartment_df.apply( + lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1 + ) return compartment_df def parse_transitions(self, seir_config, fake_config=False): rc = reduce( - lambda a, b: pd.concat([a, self.parse_single_transition(seir_config, b, fake_config)]), + lambda a, b: pd.concat( + [a, self.parse_single_transition(seir_config, b, fake_config)] + ), seir_config["transitions"], pd.DataFrame(), ) rc = rc.reset_index(drop=True) return rc - def check_transition_element(self, single_transition_config, problem_dimension=None): + def check_transition_element( + self, single_transition_config, problem_dimension=None + ): return True def check_transition_elements(self, single_transition_config, problem_dimension): return True - def access_original_config_by_multi_index(self, config_piece, index, dimension=None, encapsulate_as_list=False): + def access_original_config_by_multi_index( + self, config_piece, index, dimension=None, encapsulate_as_list=False + ): if dimension is None: dimension = [None for i in index] tmp = [y for y in zip(index, range(len(index)), dimension)] tmp = zip(index, range(len(index)), dimension) - tmp = [list_access_element_safe(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp] + tmp = [ + list_access_element_safe( + config_piece[x[1]], x[0], x[2], encapsulate_as_list + ) + for x in tmp + ] return tmp def expand_transition_elements(self, single_transition_config, problem_dimension): - proportion_size = get_list_dimension(single_transition_config["proportional_to"]) + proportion_size = get_list_dimension( + single_transition_config["proportional_to"] + ) new_transition_config = single_transition_config.copy() # replace "source" by the actual source from the config for p_idx in range(proportion_size): if new_transition_config["proportional_to"][p_idx] == "source": - new_transition_config["proportional_to"][p_idx] = new_transition_config["source"] + new_transition_config["proportional_to"][p_idx] = new_transition_config[ + "source" + ] temp_array = np.zeros(problem_dimension) @@ -116,44 +139,80 @@ def expand_transition_elements(self, single_transition_config, problem_dimension new_transition_config["destination"] = np.zeros(problem_dimension, dtype=object) new_transition_config["rate"] = np.zeros(problem_dimension, dtype=object) - new_transition_config["proportional_to"] = np.zeros(problem_dimension, dtype=object) - new_transition_config["proportion_exponent"] = np.zeros(problem_dimension, dtype=object) + new_transition_config["proportional_to"] = np.zeros( + problem_dimension, dtype=object + ) + new_transition_config["proportion_exponent"] = np.zeros( + problem_dimension, dtype=object + ) - it = np.nditer(temp_array, flags=["multi_index"]) # it is an iterator that will go through all the indexes of the array + it = np.nditer( + temp_array, 
flags=["multi_index"] + ) # it is an iterator that will go through all the indexes of the array for x in it: try: - new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index) + new_transition_config["source"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["source"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `source:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `source:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e try: - new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index) + new_transition_config["destination"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["destination"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `destination:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `destination:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e - + try: - new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index) + new_transition_config["rate"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["rate"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `rate:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `rate:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e @@ -173,43 +232,68 @@ def expand_transition_elements(self, single_transition_config, problem_dimension ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition 
destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e - - if "proportion_exponent" in single_transition_config: # if proportion_exponent is not defined, it is set to 1 + + if ( + "proportion_exponent" in single_transition_config + ): # if proportion_exponent is not defined, it is set to 1 try: self.access_original_config_by_multi_index( single_transition_config["proportion_exponent"][0], it.multi_index, problem_dimension, ) - new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string( - [ - self.access_original_config_by_multi_index( - single_transition_config["proportion_exponent"][p_idx], - it.multi_index, - problem_dimension, - ) - for p_idx in range(proportion_size) - ] + new_transition_config["proportion_exponent"][it.multi_index] = ( + list_recursive_convert_to_string( + [ + self.access_original_config_by_multi_index( + single_transition_config["proportion_exponent"][ + p_idx + ], + it.multi_index, + problem_dimension, + ) + for p_idx in range(proportion_size) + ] + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e else: - new_transition_config["proportion_exponent"][it.multi_index] = ["1"] * proportion_size + new_transition_config["proportion_exponent"][it.multi_index] = [ + "1" + ] * proportion_size return new_transition_config def format_source(self, source_column): - rc = [y for y in map(lambda x: reduce(lambda a, b: str(a) + "_" + str(b), x), source_column)] + rc = [ + y + for y in map( + lambda x: reduce(lambda a, b: str(a) + "_" + str(b), x), source_column + ) + ] return rc def unformat_source(self, source_column): @@ -231,7 +315,12 @@ def unformat_destination(self, destination_column): return rc def format_rate(self, rate_column): - rc = [y for y in map(lambda x: reduce(lambda a, b: str(a) + "%*%" + str(b), x), rate_column)] + rc = [ + y + for y in map( + lambda x: reduce(lambda a, b: str(a) + "%*%" + str(b), x), rate_column + ) + ] return rc def unformat_rate(self, rate_column, compartment_dimension): @@ -251,7 +340,9 @@ def format_proportional_to(self, proportional_to_column): lambda x: reduce( lambda a, b: str(a) + "_" + str(b), map( - lambda x: reduce(lambda a, b: str(a) + "+" + str(b), as_list(x)), + lambda x: reduce( + lambda a, b: str(a) + "+" + str(b), as_list(x) + ), x, ), ), @@ -284,27 +375,41 @@ def format_proportion_exponent(self, proportion_exponent_column): ] return rc - def unformat_proportion_exponent(self, proportion_exponent_column, compartment_dimension): + def unformat_proportion_exponent( + self, proportion_exponent_column, compartment_dimension + ): rc = [x.split("%*%") for 
x in proportion_exponent_column] for row in range(len(rc)): - rc[row] = [x.split("*", maxsplit=compartment_dimension - 1) for x in rc[row]] + rc[row] = [ + x.split("*", maxsplit=compartment_dimension - 1) for x in rc[row] + ] for elem in rc[row]: while len(elem) < compartment_dimension: elem.append(1) return rc - def parse_single_transition(self, seir_config, single_transition_config, fake_config=False): + def parse_single_transition( + self, seir_config, single_transition_config, fake_config=False + ): ## This method relies on having run parse_compartments if not fake_config: single_transition_config = single_transition_config.get() self.check_transition_element(single_transition_config["source"]) self.check_transition_element(single_transition_config["destination"]) - source_dimension = [get_list_dimension(x) for x in single_transition_config["source"]] - destination_dimension = [get_list_dimension(x) for x in single_transition_config["destination"]] - problem_dimension = reduce(lambda x, y: max(x, y), (source_dimension, destination_dimension)) + source_dimension = [ + get_list_dimension(x) for x in single_transition_config["source"] + ] + destination_dimension = [ + get_list_dimension(x) for x in single_transition_config["destination"] + ] + problem_dimension = reduce( + lambda x, y: max(x, y), (source_dimension, destination_dimension) + ) self.check_transition_elements(single_transition_config, problem_dimension) - transitions = self.expand_transition_elements(single_transition_config, problem_dimension) + transitions = self.expand_transition_elements( + single_transition_config, problem_dimension + ) tmp_array = np.zeros(problem_dimension) it = np.nditer(tmp_array, flags=["multi_index"]) @@ -316,8 +421,12 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co "source": [transitions["source"][it.multi_index]], "destination": [transitions["destination"][it.multi_index]], "rate": [transitions["rate"][it.multi_index]], - "proportional_to": [transitions["proportional_to"][it.multi_index]], - "proportion_exponent": [transitions["proportion_exponent"][it.multi_index]], + "proportional_to": [ + transitions["proportional_to"][it.multi_index] + ], + "proportion_exponent": [ + transitions["proportion_exponent"][it.multi_index] + ], }, index=[0], ) @@ -328,7 +437,10 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co return rc def toFile( - self, compartments_file="compartments.parquet", transitions_file="transitions.parquet", write_parquet=True + self, + compartments_file="compartments.parquet", + transitions_file="transitions.parquet", + write_parquet=True, ): out_df = self.compartments.copy() if write_parquet: @@ -341,8 +453,12 @@ def toFile( out_df["source"] = self.format_source(out_df["source"]) out_df["destination"] = self.format_destination(out_df["destination"]) out_df["rate"] = self.format_rate(out_df["rate"]) - out_df["proportional_to"] = self.format_proportional_to(out_df["proportional_to"]) - out_df["proportion_exponent"] = self.format_proportion_exponent(out_df["proportion_exponent"]) + out_df["proportional_to"] = self.format_proportional_to( + out_df["proportional_to"] + ) + out_df["proportion_exponent"] = self.format_proportion_exponent( + out_df["proportion_exponent"] + ) if write_parquet: pa_df = pa.Table.from_pandas(out_df, preserve_index=False) pa.parquet.write_table(pa_df, transitions_file) @@ -355,9 +471,15 @@ def fromFile(self, compartments_file, transitions_file): self.transitions = 
pq.read_table(transitions_file).to_pandas() compartment_dimension = self.compartments.shape[1] - 1 self.transitions["source"] = self.unformat_source(self.transitions["source"]) - self.transitions["destination"] = self.unformat_destination(self.transitions["destination"]) - self.transitions["rate"] = self.unformat_rate(self.transitions["rate"], compartment_dimension) - self.transitions["proportional_to"] = self.unformat_proportional_to(self.transitions["proportional_to"]) + self.transitions["destination"] = self.unformat_destination( + self.transitions["destination"] + ) + self.transitions["rate"] = self.unformat_rate( + self.transitions["rate"], compartment_dimension + ) + self.transitions["proportional_to"] = self.unformat_proportional_to( + self.transitions["proportional_to"] + ) self.transitions["proportion_exponent"] = self.unformat_proportion_exponent( self.transitions["proportion_exponent"], compartment_dimension ) @@ -371,7 +493,9 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i :param comp_dict: :return: """ - mask = pd.concat([self.compartments[k] == v for k, v in comp_dict.items()], axis=1).all(axis=1) + mask = pd.concat( + [self.compartments[k] == v for k, v in comp_dict.items()], axis=1 + ).all(axis=1) comp_idx = self.compartments[mask].index.values if len(comp_idx) != 1: raise ValueError( @@ -382,10 +506,11 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i def get_ncomp(self) -> int: return len(self.compartments) - def get_transition_array(self): with Timer("SEIR.compartments"): - transition_array = np.zeros((self.transitions.shape[1], self.transitions.shape[0]), dtype="int64") + transition_array = np.zeros( + (self.transitions.shape[1], self.transitions.shape[0]), dtype="int64" + ) for cit, colname in enumerate(("source", "destination")): for it, elem in enumerate(self.transitions[colname]): elem = reduce(lambda a, b: a + "_" + b, elem) @@ -395,7 +520,9 @@ def get_transition_array(self): rc = compartment if rc == -1: print(self.compartments) - raise ValueError(f"Could not find {colname} defined by {elem} in compartments") + raise ValueError( + f"Could not find {colname} defined by {elem} in compartments" + ) transition_array[cit, it] = rc unique_strings = [] @@ -417,8 +544,12 @@ def get_transition_array(self): # parenthesis are now supported # assert reduce(lambda a, b: a and b, [(x.find("(") == -1) for x in unique_strings]) # assert reduce(lambda a, b: a and b, [(x.find(")") == -1) for x in unique_strings]) - assert reduce(lambda a, b: a and b, [(x.find("%") == -1) for x in unique_strings]) - assert reduce(lambda a, b: a and b, [(x.find(" ") == -1) for x in unique_strings]) + assert reduce( + lambda a, b: a and b, [(x.find("%") == -1) for x in unique_strings] + ) + assert reduce( + lambda a, b: a and b, [(x.find(" ") == -1) for x in unique_strings] + ) for it, elem in enumerate(self.transitions["rate"]): candidate = reduce(lambda a, b: a + "*" + b, elem) @@ -454,8 +585,12 @@ def get_transition_array(self): # rc = compartment # if rc == -1: # raise ValueError(f"Could not find match for {elem3} in compartments") - proportion_info[0][current_proportion_sum_it] = current_proportion_sum_start - proportion_info[1][current_proportion_sum_it] = current_proportion_sum_start + len(elem_tmp) + proportion_info[0][ + current_proportion_sum_it + ] = current_proportion_sum_start + proportion_info[1][current_proportion_sum_it] = ( + current_proportion_sum_start + len(elem_tmp) + ) current_proportion_sum_it += 1 
current_proportion_sum_start += len(elem_tmp) proportion_compartment_index = 0 @@ -466,7 +601,9 @@ def get_transition_array(self): # candidate = candidate.replace("*1", "") if not candidate in unique_strings: raise ValueError("Something went wrong") - rc = [it for it, x in enumerate(unique_strings) if x == candidate][0] + rc = [it for it, x in enumerate(unique_strings) if x == candidate][ + 0 + ] proportion_info[2][proportion_compartment_index] = rc proportion_compartment_index += 1 @@ -490,7 +627,9 @@ def get_transition_array(self): if self.compartments["name"][compartment] == elem3: rc = compartment if rc == -1: - raise ValueError(f"Could not find proportional_to {elem3} in compartments") + raise ValueError( + f"Could not find proportional_to {elem3} in compartments" + ) proportion_array[proportion_index] = rc proportion_index += 1 @@ -528,21 +667,29 @@ def get_transition_array(self): def parse_parameters(self, parameters, parameter_names, unique_strings): # parsed_parameters_old = self.parse_parameter_strings_to_numpy_arrays(parameters, parameter_names, unique_strings) - parsed_parameters = self.parse_parameter_strings_to_numpy_arrays_v2(parameters, parameter_names, unique_strings) + parsed_parameters = self.parse_parameter_strings_to_numpy_arrays_v2( + parameters, parameter_names, unique_strings + ) # for i in range(len(unique_strings)): # print(unique_strings[i], (parsed_parameters[i]==parsed_parameters_old[i]).all()) return parsed_parameters - def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names, string_list): + def parse_parameter_strings_to_numpy_arrays_v2( + self, parameters, parameter_names, string_list + ): # is using eval a better way ??? import sympy as sp # Validate input lengths if len(parameters) != len(parameter_names): - raise ValueError("Number of parameter values does not match the number of parameter names.") + raise ValueError( + "Number of parameter values does not match the number of parameter names." + ) # Define the symbols used in the formulas - symbolic_parameters_namespace = {name: sp.symbols(name) for name in parameter_names} + symbolic_parameters_namespace = { + name: sp.symbols(name) for name in parameter_names + } symbolic_parameters = [sp.symbols(name) for name in parameter_names] @@ -554,12 +701,18 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names f = sp.sympify(formula, locals=symbolic_parameters_namespace) parsed_formulas.append(f) except Exception as e: - print(f"Cannot parse formula: '{formula}' from parameters {parameter_names}") + print( + f"Cannot parse formula: '{formula}' from parameters {parameter_names}" + ) raise (e) # Print the error message for debugging # the list order needs to be right. 
- parameter_values = {param: value for param, value in zip(symbolic_parameters, parameters)} - parameter_values_list = [parameter_values[param] for param in symbolic_parameters] + parameter_values = { + param: value for param, value in zip(symbolic_parameters, parameters) + } + parameter_values_list = [ + parameter_values[param] for param in symbolic_parameters + ] # Create a lambdify function for substitution substitution_function = sp.lambdify(symbolic_parameters, parsed_formulas) @@ -573,7 +726,9 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names if not isinstance(substituted_formulas[i], np.ndarray): for k in range(len(substituted_formulas)): if isinstance(substituted_formulas[k], np.ndarray): - substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k]) + substituted_formulas[i] = substituted_formulas[ + i + ] * np.ones_like(substituted_formulas[k]) return np.array(substituted_formulas) @@ -621,19 +776,29 @@ def parse_parameter_strings_to_numpy_arrays( is_resolvable = [x[0] or x[1] for x in zip(is_numeric, is_parameter)] is_totally_resolvable = reduce(lambda a, b: a and b, is_resolvable) if not is_totally_resolvable: - not_resolvable_indices = [it for it, x in enumerate(is_resolvable) if not x] - - tmp_rc[not_resolvable_indices] = self.parse_parameter_strings_to_numpy_arrays( - parameters, - parameter_names, - [string[not is_resolvable]], - operator_reduce_lambdas, - operators[1:], + not_resolvable_indices = [ + it for it, x in enumerate(is_resolvable) if not x + ] + + tmp_rc[not_resolvable_indices] = ( + self.parse_parameter_strings_to_numpy_arrays( + parameters, + parameter_names, + [string[not is_resolvable]], + operator_reduce_lambdas, + operators[1:], + ) ) for numeric_index in [x for x in range(len(is_numeric)) if is_numeric[x]]: tmp_rc[numeric_index] = parameters[0] * 0 + float(string[numeric_index]) - for parameter_index in [x for x in range(len(is_parameter)) if is_parameter[x]]: - parameter_name_index = [it for it, x in enumerate(parameter_names) if x == string[parameter_index]] + for parameter_index in [ + x for x in range(len(is_parameter)) if is_parameter[x] + ]: + parameter_name_index = [ + it + for it, x in enumerate(parameter_names) + if x == string[parameter_index] + ] tmp_rc[parameter_index] = parameters[parameter_name_index] rc[sit] = reduce(operator_reduce_lambdas[operators[0]], tmp_rc) @@ -648,7 +813,9 @@ def get_compartments_explicitDF(self): df = df.rename(columns=rename_dict) return df - def plot(self, output_file="transition_graph", source_filters=[], destination_filters=[]): + def plot( + self, output_file="transition_graph", source_filters=[], destination_filters=[] + ): """ if source_filters is [["age0to17"], ["OMICRON", "WILD"]], it means filter all transitions that have as source age0to17 AND (OMICRON OR WILD). @@ -712,8 +879,12 @@ def list_access_element_safe(thing, idx, dimension=None, encapsulate_as_list=Fal except Exception as e: print(f"Error {e}:") print(f">>> in list_access_element_safe for {thing} at index {idx}") - print(">>> This is often, but not always because the object above is a list (there are brackets around it).") - print(">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it.") + print( + ">>> This is often, but not always because the object above is a list (there are brackets around it)." 
+ ) + print( + ">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it." + ) print(f"dimension: {dimension}") raise e @@ -755,7 +926,9 @@ def compartments(): def plot(): assert config["compartments"].exists() assert config["seir"].exists() - comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + comp = Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) # TODO: this should be a command like build compartments. ( @@ -774,7 +947,9 @@ def plot(): def export(): assert config["compartments"].exists() assert config["seir"].exists() - comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + comp = Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) ( unique_strings, transition_array, diff --git a/flepimop/gempyor_pkg/src/gempyor/config_validator.py b/flepimop/gempyor_pkg/src/gempyor/config_validator.py index 95f6e3029..40f9423ff 100644 --- a/flepimop/gempyor_pkg/src/gempyor/config_validator.py +++ b/flepimop/gempyor_pkg/src/gempyor/config_validator.py @@ -1,33 +1,60 @@ import yaml -from pydantic import BaseModel, ValidationError, model_validator, Field, AfterValidator, validator +from pydantic import ( + BaseModel, + ValidationError, + model_validator, + Field, + AfterValidator, + validator, +) from datetime import date from typing import Dict, List, Union, Literal, Optional, Annotated, Any from functools import partial from gempyor import compartments + def read_yaml(file_path: str) -> dict: - with open(file_path, 'r') as stream: + with open(file_path, "r") as stream: config = yaml.safe_load(stream) - + return CheckConfig(**config).model_dump() - + + def allowed_values(v, values): assert v in values return v + # def parse_value(cls, values): # value = values.get('value') # parsed_val = compartments.Compartments.parse_parameter_strings_to_numpy_arrays_v2(value) # return parsed_val - + + class SubpopSetupConfig(BaseModel): geodata: str mobility: Optional[str] selected: List[str] = Field(default_factory=list) # state_level: Optional[bool] = False # pretty sure this doesn't exist anymore + class InitialConditionsConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile', 'plugin']))] = 'Default' + method: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=[ + "Default", + "SetInitialConditions", + "SetInitialConditionsFolderDraw", + "InitialConditionsFolderDraw", + "FromFile", + "plugin", + ], + ) + ), + ] = "Default" initial_file_type: Optional[str] = None initial_conditions_file: Optional[str] = None proportional: Optional[bool] = None @@ -36,105 +63,163 @@ class InitialConditionsConfig(BaseModel): ignore_population_checks: Optional[bool] = None plugin_file_path: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def validate_initial_file_check(cls, values): - method = values.get('method') - initial_conditions_file = values.get('initial_conditions_file') - initial_file_type = values.get('initial_file_type') - if method in {'FromFile', 'SetInitialConditions'} and not initial_conditions_file: - raise ValueError(f'Error in InitialConditions: An initial_conditions_file is required when method is {method}') - if method in {'InitialConditionsFolderDraw','SetInitialConditionsFolderDraw'} 
and not initial_file_type: - raise ValueError(f'Error in InitialConditions: initial_file_type is required when method is {method}') + method = values.get("method") + initial_conditions_file = values.get("initial_conditions_file") + initial_file_type = values.get("initial_file_type") + if ( + method in {"FromFile", "SetInitialConditions"} + and not initial_conditions_file + ): + raise ValueError( + f"Error in InitialConditions: An initial_conditions_file is required when method is {method}" + ) + if ( + method in {"InitialConditionsFolderDraw", "SetInitialConditionsFolderDraw"} + and not initial_file_type + ): + raise ValueError( + f"Error in InitialConditions: initial_file_type is required when method is {method}" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def plugin_filecheck(cls, values): - method = values.get('method') - plugin_file_path = values.get('plugin_file_path') - if method == 'plugin' and not plugin_file_path: - raise ValueError('Error in InitialConditions: a plugin file path is required when method is plugin') + method = values.get("method") + plugin_file_path = values.get("plugin_file_path") + if method == "plugin" and not plugin_file_path: + raise ValueError( + "Error in InitialConditions: a plugin file path is required when method is plugin" + ) return values class SeedingConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['NoSeeding', 'PoissonDistributed', 'FolderDraw', 'FromFile', 'plugin']))] = 'NoSeeding' # note: removed NegativeBinomialDistributed because no longer supported + method: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=[ + "NoSeeding", + "PoissonDistributed", + "FolderDraw", + "FromFile", + "plugin", + ], + ) + ), + ] = "NoSeeding" # note: removed NegativeBinomialDistributed because no longer supported lambda_file: Optional[str] = None seeding_file_type: Optional[str] = None seeding_file: Optional[str] = None plugin_file_path: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def validate_seedingfile(cls, values): - method = values.get('method') - lambda_file = values.get('lambda_file') - seeding_file_type = values.get('seeding_file_type') - seeding_file = values.get('seeding_file') - if method == 'PoissonDistributed' and not lambda_file: - raise ValueError(f'Error in Seeding: A lambda_file is required when method is {method}') - if method == 'FolderDraw' and not seeding_file_type: - raise ValueError('Error in Seeding: A seeding_file_type is required when method is FolderDraw') - if method == 'FromFile' and not seeding_file: - raise ValueError('Error in Seeding: A seeding_file is required when method is FromFile') + method = values.get("method") + lambda_file = values.get("lambda_file") + seeding_file_type = values.get("seeding_file_type") + seeding_file = values.get("seeding_file") + if method == "PoissonDistributed" and not lambda_file: + raise ValueError( + f"Error in Seeding: A lambda_file is required when method is {method}" + ) + if method == "FolderDraw" and not seeding_file_type: + raise ValueError( + "Error in Seeding: A seeding_file_type is required when method is FolderDraw" + ) + if method == "FromFile" and not seeding_file: + raise ValueError( + "Error in Seeding: A seeding_file is required when method is FromFile" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def plugin_filecheck(cls, values): - method = values.get('method') - plugin_file_path = 
values.get('plugin_file_path') - if method == 'plugin' and not plugin_file_path: - raise ValueError('Error in Seeding: a plugin file path is required when method is plugin') + method = values.get("method") + plugin_file_path = values.get("plugin_file_path") + if method == "plugin" and not plugin_file_path: + raise ValueError( + "Error in Seeding: a plugin file path is required when method is plugin" + ) return values - + + class IntegrationConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['rk4', 'rk4.jit', 'best.current', 'legacy']))] = 'rk4' + method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["rk4", "rk4.jit", "best.current", "legacy"]) + ), + ] = "rk4" dt: float = 2.0 + class ValueConfig(BaseModel): - distribution: str = 'fixed' - value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS + distribution: str = "fixed" + value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS mean: Optional[float] = None sd: Optional[float] = None a: Optional[float] = None b: Optional[float] = None - @model_validator(mode='before') + @model_validator(mode="before") def check_distr(cls, values): - distr = values.get('distribution') - value = values.get('value') - mean = values.get('mean') - sd = values.get('sd') - a = values.get('a') - b = values.get('b') - if distr != 'fixed': + distr = values.get("distribution") + value = values.get("value") + mean = values.get("mean") + sd = values.get("sd") + a = values.get("a") + b = values.get("b") + if distr != "fixed": if not mean and not sd: - raise ValueError('Error in value: mean and sd must be provided for non-fixed distributions') - if distr == 'truncnorm' and not a and not b: - raise ValueError('Error in value: a and b must be provided for truncated normal distributions') - if distr == 'fixed' and not value: - raise ValueError('Error in value: value must be provided for fixed distributions') + raise ValueError( + "Error in value: mean and sd must be provided for non-fixed distributions" + ) + if distr == "truncnorm" and not a and not b: + raise ValueError( + "Error in value: a and b must be provided for truncated normal distributions" + ) + if distr == "fixed" and not value: + raise ValueError( + "Error in value: value must be provided for fixed distributions" + ) return values + class BaseParameterConfig(BaseModel): value: Optional[ValueConfig] = None modifier_parameter: Optional[str] = None - name: Optional[str] = None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) + name: Optional[str] = ( + None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) 
+ ) + class SeirParameterConfig(BaseParameterConfig): value: Optional[ValueConfig] = None - stacked_modifier_method: Annotated[str, AfterValidator(partial(allowed_values, values=['sum', 'product', 'reduction_product']))] = None + stacked_modifier_method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["sum", "product", "reduction_product"]) + ), + ] = None rolling_mean_windows: Optional[float] = None timeseries: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def which_value(cls, values): - value = values.get('value') is not None - timeseries = values.get('timeseries') is not None + value = values.get("value") is not None + timeseries = values.get("timeseries") is not None if value and timeseries: - raise ValueError('Error in seir::parameters: your parameter is both a timeseries and a value, please choose one') + raise ValueError( + "Error in seir::parameters: your parameter is both a timeseries and a value, please choose one" + ) return values - - -class TransitionConfig(BaseModel): + + +class TransitionConfig(BaseModel): # !! sometimes these are lists of lists and sometimes they are lists... how to deal with this? source: List[List[str]] destination: List[List[str]] @@ -142,11 +227,15 @@ class TransitionConfig(BaseModel): proportion_exponent: List[List[str]] proportional_to: List[str] + class SeirConfig(BaseModel): - integration: IntegrationConfig # is this Optional? - parameters: Dict[str, SeirParameterConfig] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? + integration: IntegrationConfig # is this Optional? + parameters: Dict[ + str, SeirParameterConfig + ] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? transitions: List[TransitionConfig] + class SinglePeriodModifierConfig(BaseModel): method: Literal["SinglePeriodModifier"] parameter: str @@ -157,15 +246,18 @@ class SinglePeriodModifierConfig(BaseModel): value: ValueConfig perturbation: Optional[ValueConfig] = None + class MultiPeriodDatesConfig(BaseModel): start_date: date end_date: date - + + class MultiPeriodGroupsConfig(BaseModel): subpop: List[str] subpop_groups: Optional[str] = None periods: List[MultiPeriodDatesConfig] + class MultiPeriodModifierConfig(BaseModel): method: Literal["MultiPeriodModifier"] parameter: str @@ -173,37 +265,47 @@ class MultiPeriodModifierConfig(BaseModel): value: ValueConfig perturbation: Optional[ValueConfig] = None + class StackedModifierConfig(BaseModel): method: Literal["StackedModifier"] modifiers: List[str] + class ModifiersConfig(BaseModel): scenarios: List[str] modifiers: Dict[str, Any] - + @field_validator("modifiers") def validate_data_dict(cls, value: Dict[str, Any]) -> Dict[str, Any]: errors = [] for key, entry in value.items(): method = entry.get("method") - if method not in {"SinglePeriodModifier", "MultiPeriodModifier", "StackedModifier"}: + if method not in { + "SinglePeriodModifier", + "MultiPeriodModifier", + "StackedModifier", + }: errors.append(f"Invalid modifier method: {method}") if errors: raise ValueError("Errors in modifiers:\n" + "\n".join(errors)) return value -class SourceConfig(BaseModel): # set up only for incidence or prevalence. Can this be any name? i don't think so atm +class SourceConfig( + BaseModel +): # set up only for incidence or prevalence. Can this be any name? 
i don't think so atm incidence: Dict[str, str] = None - prevalence: Dict[str, str] = None + prevalence: Dict[str, str] = None # note: these dictionaries have to have compartment names... more complicated to set this up - @model_validator(mode='before') + @model_validator(mode="before") def which_source(cls, values): - incidence = values.get('incidence') - prevalence = values.get('prevalence') + incidence = values.get("incidence") + prevalence = values.get("prevalence") if incidence and prevalence: - raise ValueError('Error in outcomes::source. Can only be incidence or prevalence, not both.') + raise ValueError( + "Error in outcomes::source. Can only be incidence or prevalence, not both." + ) return values # @model_validator(mode='before') # DOES NOT WORK @@ -214,12 +316,13 @@ def which_source(cls, values): # source_names.append(key) # return source_names # Access keys using a loop + class DelayFrameConfig(BaseModel): source: Optional[SourceConfig] = None probability: Optional[BaseParameterConfig] = None delay: Optional[BaseParameterConfig] = None duration: Optional[BaseParameterConfig] = None - sum: Optional[List[str]] = None # only for sums of other outcomes + sum: Optional[List[str]] = None # only for sums of other outcomes # @validator("sum") # def validate_sum_elements(cls, value: Optional[List[str]]) -> Optional[List[str]]: @@ -233,65 +336,96 @@ class DelayFrameConfig(BaseModel): # return value # note: ^^ this doesn't work yet because it needs to somehow be a level above? to access all OTHER source names - @model_validator(mode='before') + @model_validator(mode="before") def check_outcome_type(cls, values): - sum_present = values.get('sum') is not None - source_present = values.get('source') is not None + sum_present = values.get("sum") is not None + source_present = values.get("source") is not None if sum_present and source_present: - raise ValueError(f"Error in outcome: Both 'sum' and 'source' are present. Choose one.") + raise ValueError( + f"Error in outcome: Both 'sum' and 'source' are present. Choose one." + ) elif not sum_present and not source_present: - raise ValueError(f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one.") + raise ValueError( + f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one." + ) return values + class OutcomesConfig(BaseModel): - method: Literal["delayframe"] # Is this required? I don't see it anywhere in the gempyor code + method: Literal[ + "delayframe" + ] # Is this required? I don't see it anywhere in the gempyor code param_from_file: Optional[bool] = None param_subpop_file: Optional[str] = None outcomes: Dict[str, DelayFrameConfig] - - @model_validator(mode='before') + + @model_validator(mode="before") def check_paramfromfile_type(cls, values): - param_from_file = values.get('param_from_file') is not None - param_subpop_file = values.get('param_subpop_file') is not None + param_from_file = values.get("param_from_file") is not None + param_subpop_file = values.get("param_subpop_file") is not None if param_from_file and not param_subpop_file: - raise ValueError(f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True") + raise ValueError( + f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True" + ) return values + class ResampleConfig(BaseModel): aggregator: Optional[str] = None freq: Optional[str] = None skipna: Optional[bool] = False + class LikelihoodParams(BaseModel): scale: float # are there other options here? 
+ class LikelihoodReg(BaseModel): - name: str + name: str + class LikelihoodConfig(BaseModel): - dist: Annotated[str, AfterValidator(partial(allowed_values, values=['pois', 'norm', 'norm_cov', 'nbinom', 'rmse', 'absolute_error']))] = None + dist: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=["pois", "norm", "norm_cov", "nbinom", "rmse", "absolute_error"], + ) + ), + ] = None params: Optional[LikelihoodParams] = None + class StatisticsConfig(BaseModel): name: str sim_var: str data_var: str regularize: Optional[LikelihoodReg] = None resample: Optional[ResampleConfig] = None - scale: Optional[float] = None # is scale here or at likelihood level? - zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? + scale: Optional[float] = None # is scale here or at likelihood level? + zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? likelihood: LikelihoodConfig + class InferenceConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['emcee', 'default', 'classical']))] = 'default' # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options - iterations_per_slot: Optional[int] # i think this is optional because it is also set in command line?? - do_inference: bool + method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["emcee", "default", "classical"]) + ), + ] = "default" # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options + iterations_per_slot: Optional[ + int + ] # i think this is optional because it is also set in command line?? + do_inference: bool gt_data_path: str statistics: Dict[str, StatisticsConfig] + class CheckConfig(BaseModel): name: str setup_name: Optional[str] = None @@ -312,32 +446,36 @@ class CheckConfig(BaseModel): outcome_modifiers: Optional[ModifiersConfig] = None inference: Optional[InferenceConfig] = None -# add validator for if modifiers exist but seir/outcomes do not - -# there is an error in the one below - @model_validator(mode='before') + # add validator for if modifiers exist but seir/outcomes do not + + # there is an error in the one below + @model_validator(mode="before") def verify_inference(cls, values): - inference_present = values.get('inference') is not None - start_date_groundtruth = values.get('start_date_groundtruth') is not None + inference_present = values.get("inference") is not None + start_date_groundtruth = values.get("start_date_groundtruth") is not None if inference_present and not start_date_groundtruth: - raise ValueError('Inference mode is enabled but no groundtruth dates are provided') + raise ValueError( + "Inference mode is enabled but no groundtruth dates are provided" + ) elif start_date_groundtruth and not inference_present: - raise ValueError('Groundtruth dates are provided but inference mode is not enabled') + raise ValueError( + "Groundtruth dates are provided but inference mode is not enabled" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def check_dates(cls, values): - start_date = values.get('start_date') - end_date = values.get('end_date') + start_date = values.get("start_date") + end_date = values.get("end_date") if start_date and end_date: if end_date <= start_date: - raise ValueError('end_date must be greater than start_date') + raise ValueError("end_date must be greater than start_date") return values - - @model_validator(mode='before') + + 
@model_validator(mode="before") def init_or_seed(cls, values): - init = values.get('initial_conditions') - seed = values.get('seeding') + init = values.get("initial_conditions") + seed = values.get("seeding") if not init or seed: - raise ValueError('either initial_conditions or seeding must be provided') + raise ValueError("either initial_conditions or seeding must be provided") return values diff --git a/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py b/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py index a847a5f90..f9ecba633 100644 --- a/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py @@ -36,7 +36,9 @@ ) seeding_data = modinf.seeding.get_from_config(sim_id=100, modinf=modinf) -initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) +initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf +) mobility_subpop_indices = modinf.mobility.indices mobility_data_indices = modinf.mobility.indptr @@ -48,7 +50,9 @@ modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -81,7 +85,12 @@ True, ) df = seir.states2Df(modinf, states) -assert df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] > 1 +assert ( + df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[ + str(modinf.tf), "20002" + ] + > 1 +) print(df) ts = df cp = "R" diff --git a/flepimop/gempyor_pkg/src/gempyor/dev/steps.py b/flepimop/gempyor_pkg/src/gempyor/dev/steps.py index 43066e5ee..0213fb912 100644 --- a/flepimop/gempyor_pkg/src/gempyor/dev/steps.py +++ b/flepimop/gempyor_pkg/src/gempyor/dev/steps.py @@ -53,7 +53,11 @@ def ode_integration( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -65,7 +69,9 @@ def ode_integration( def rhs(t, x, today): print("rhs.t", t) states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -80,52 +86,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -142,9 +168,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -168,18 +200,24 @@ def rhs(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -257,7 +295,11 @@ def rk4_integration1( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -267,7 +309,9 @@ def rk4_integration1( def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -282,52 +326,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -344,9 +408,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -380,18 +450,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -447,7 +523,11 @@ def rk4_integration2( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -458,7 +538,9 @@ def rk4_integration2( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -473,52 +555,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -535,9 +637,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -572,18 +680,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -644,7 +758,11 @@ def rk4_integration3( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -655,7 +773,9 @@ def rk4_integration3( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -670,52 +790,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -732,9 +872,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -753,16 +899,18 @@ def rk4_integrate(t, x, today): @jit(nopython=True) def day_wrapper_rk4(today, states_next): x_ = np.zeros((2, ncompartments, nspatial_nodes)) - for seeding_instance_idx in range(day_start_idx_dict[today], day_start_idx_dict[today + 1]): + for seeding_instance_idx in range( + day_start_idx_dict[today], day_start_idx_dict[today + 1] + ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_subpops_dict[seeding_instance_idx] seeding_sources = seeding_sources_dict[seeding_instance_idx] seeding_destinations = seeding_destinations_dict[seeding_instance_idx] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts # ADD TO cumulative, this is debatable, @@ -838,7 +986,11 @@ def rk4_integration4( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -849,7 +1001,9 @@ def rk4_integration4( @jit(nopython=True) # , fastmath=True, parallel=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -864,52 +1018,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - 
transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -926,9 +1100,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
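# In this ODE right-hand side the flow is source_number * total_rate directly; the
# commented-out compound_adjusted_rate = 1 - exp(-dt * total_rate) would instead give a
# per-step transition probability (as in a discrete-time update) and appears to be kept
# only for reference here.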
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -963,18 +1143,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1031,7 +1217,11 @@ def rk4_integration5( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1058,18 +1248,24 @@ def rk4_integration5( this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1092,7 +1288,9 @@ def rk4_integration5( x = x_ + kx[i - 1] * rk_coefs[i] states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1104,48 +1302,72 @@ def rk4_integration5( relevant_number_in_comp = np.zeros((nspatial_nodes)) relevant_exponent = np.ones((nspatial_nodes)) for proportion_sum_index in range( - proportion_info[proportion_sum_starts_col][proportion_index], + proportion_info[proportion_sum_starts_col][ + proportion_index + ], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] - # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][ - today + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] ] + # exponents should not be a proportion, since we don't sum them over sum compartments + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][ + proportion_index + ] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 - ) == transitions[transition_proportion_stop_col][transition_index] + transitions[transition_proportion_start_col][ + transition_index + ] + + 1 + ) == transitions[transition_proportion_stop_col][ + transition_index + ] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 + - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + 
mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][ - spatial_node - ] + * parameters[ + transitions[transition_rate_col][ + transition_index + ] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] rate_change_compartment = proportion_change_compartment @@ -1153,11 +1375,16 @@ def rk4_integration5( relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] ) - rate_change_compartment /= population[visiting_compartment] + rate_change_compartment /= population[ + visiting_compartment + ] rate_change_compartment *= parameters[ transitions[transition_rate_col][transition_index] ][today][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + total_rate[spatial_node] *= ( + rate_keep_compartment + + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1169,13 +1396,19 @@ def rk4_integration5( # ) # else: if True: - number_move = source_number * total_rate # * compound_adjusted_rate + number_move = ( + source_number * total_rate + ) # * compound_adjusted_rate # for spatial_node in range(nspatial_nodes): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
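# total_rate for each subpopulation combines the two contributions computed above:
# rate_keep_compartment, for the share of residents assumed to stay home
# (1 - percent_day_away * percent_who_move), and rate_change_compartment summed over
# visited subpopulations, which uses the visited subpopulations' compartment counts,
# populations, and rate parameters, weighted by percent_day_away times the mobility row
# divided by the home population.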
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move states_diff[ 1, transitions[transition_destination_col][transition_index], : ] += number_move # Cumumlative @@ -1234,7 +1467,11 @@ def rk4_integration2_smart( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1248,7 +1485,9 @@ def rhs(t, x): if (today) > ndays: today = ndays - 1 states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1263,52 +1502,72 @@ def rhs(t, x): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * 
parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1325,9 +1584,15 @@ def rhs(t, x): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -1372,20 +1637,34 @@ def rk4_integrate(today, x): seeding_data["day_start_idx"][today + 1], ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] - seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] - seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_subpops = seeding_data["seeding_subpops"][ + seeding_instance_idx + ] + seeding_sources = seeding_data["seeding_sources"][ + seeding_instance_idx + ] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] - states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( + states_next[seeding_sources][ + seeding_subpops + ] -= this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * ( states_next[seeding_sources][seeding_subpops] > 0 ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1450,9 +1729,12 @@ def rk4_integrate(today, x): ## Return "UniTuple(float64[:, :, :], 2) (" ## return states and cumlative states, both [ ndays x ncompartments x nspatial_nodes ] ## Dimensions - "int32," "int32," "int32," ## ncompartments ## nspatial_nodes ## Number of days + "int32," + "int32," + "int32," ## ncompartments ## nspatial_nodes ## Number of days ## Parameters - "float64[:, :, :]," "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt + "float64[:, :, :]," + "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt ## Transitions "int64[:, :]," ## transitions [ [source, destination, proportion_start, proportion_stop, rate] x ntransitions ] "int64[:, :]," ## proportions_info [ [sum_starts, sum_stops, exponent] x ntransition_proportions ] @@ -1504,7 +1786,11 @@ def rk4_integration_aot( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1515,7 +1801,9 @@ def rk4_integration_aot( def rhs(t, x, today): # states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] states_current = x[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1530,52 +1818,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if 
source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1592,9 +1900,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
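# In this variant (rk4_integration_aot) the rhs receives x already shaped as
# (2, ncompartments, nspatial_nodes) (states_current = x[0], no reshape), and the
# return statement a few lines below hands states_diff back in that same shape, so the
# trailing "# return a 1D vector" comment is stale here.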
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return states_diff # return a 1D vector @@ -1628,18 +1942,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape diff --git a/flepimop/gempyor_pkg/src/gempyor/inference.py b/flepimop/gempyor_pkg/src/gempyor/inference.py index 81b1c22d0..5cc6bb79c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference.py @@ -82,11 +82,21 @@ def simulation_atomic( np.random.seed(int.from_bytes(os.urandom(4), byteorder="little")) random_id = np.random.randint(0, 1e8) - npi_seir = seir.build_npi_SEIR(modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=snpi_df_in) + npi_seir = seir.build_npi_SEIR( + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + bypass_DF=snpi_df_in, + ) if modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( - modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=hnpi_df_in + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + bypass_DF=hnpi_df_in, ) else: npi_outcomes = None @@ -94,10 +104,14 @@ def simulation_atomic( # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) # Convert the seeding data dictionnary to a numba dictionnary - seeding_data_nbdict = nb.typed.Dict.empty(key_type=nb.types.unicode_type, value_type=nb.types.int64[:]) + 
seeding_data_nbdict = nb.typed.Dict.empty( + key_type=nb.types.unicode_type, value_type=nb.types.int64[:] + ) for k, v in seeding_data.items(): seeding_data_nbdict[k] = np.array(v, dtype=np.int64) @@ -151,7 +165,9 @@ def get_static_arguments(modinf): ) = modinf.compartments.get_transition_array() outcomes_parameters = outcomes.read_parameters_from_config(modinf) - npi_seir = seir.build_npi_SEIR(modinf=modinf, load_ID=False, sim_id2load=None, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=modinf, load_ID=False, sim_id2load=None, config=config + ) if modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( modinf=modinf, @@ -162,15 +178,23 @@ def get_static_arguments(modinf): else: npi_outcomes = None - p_draw = modinf.parameters.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_draw = modinf.parameters.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=0, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id=0, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=0, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_config( + sim_id=0, modinf=modinf + ) # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) if real_simulation: states = seir.steps_SEIR( @@ -190,7 +214,10 @@ def get_static_arguments(modinf): compartment_df = modinf.compartments.get_compartments_explicitDF() # Iterate over columns of the DataFrame and populate the dictionary for column in compartment_df.columns: - compartment_coords[column] = ("compartment", compartment_df[column].tolist()) + compartment_coords[column] = ( + "compartment", + compartment_df[column].tolist(), + ) coords = dict( date=pd.date_range(modinf.ti, modinf.tf, freq="D"), @@ -198,7 +225,9 @@ def get_static_arguments(modinf): subpop=modinf.subpop_struct.subpop_names, ) - zeros = np.zeros((len(coords["date"]), len(coords["mc_name"][1]), len(coords["subpop"]))) + zeros = np.zeros( + (len(coords["date"]), len(coords["mc_name"][1]), len(coords["subpop"])) + ) states = xr.Dataset( data_vars=dict( prevalence=(["date", "compartment", "subpop"], zeros), @@ -256,12 +285,16 @@ def autodetect_scenarios(config): seir_modifiers_scenarios = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - seir_modifiers_scenarios = config["seir_modifiers"]["scenarios"].as_str_seq() + seir_modifiers_scenarios = config["seir_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = None if config["outcomes"].exists() and config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"].as_str_seq() + outcome_modifiers_scenarios = config["outcome_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = as_list(outcome_modifiers_scenarios) seir_modifiers_scenarios = as_list(seir_modifiers_scenarios) @@ -275,41 +308,41 @@ def autodetect_scenarios(config): return seir_modifiers_scenarios[0], outcome_modifiers_scenarios[0] + # rewrite the get log loss functions as single functions, not in a class. 
This is not faster # def get_logloss(proposal, inferpar, logloss, static_sim_arguments, modinf, silent=True, save=False): # if not inferpar.check_in_bound(proposal=proposal): # if not silent: # print("OUT OF BOUND!!") # return -np.inf, -np.inf, -np.inf -# +# # snpi_df_mod, hnpi_df_mod = inferpar.inject_proposal( # proposal=proposal, # snpi_df=static_sim_arguments["snpi_df_ref"], # hnpi_df=static_sim_arguments["hnpi_df_ref"], # ) -# +# # ss = copy.deepcopy(static_sim_arguments) # ss["snpi_df_in"] = snpi_df_mod # ss["hnpi_df_in"] = hnpi_df_mod # del ss["snpi_df_ref"] # del ss["hnpi_df_ref"] -# +# # outcomes_df = simulation_atomic(**ss, modinf=modinf, save=save) -# +# # ll_total, logloss, regularizations = logloss.compute_logloss( # model_df=outcomes_df, subpop_names=modinf.subpop_struct.subpop_names # ) # if not silent: # print(f"llik is {ll_total}") -# +# # return ll_total, logloss, regularizations -# +# # def get_logloss_as_single_number(proposal, inferpar, logloss, static_sim_arguments, modinf, silent=True, save=False): # ll_total, logloss, regularizations = get_logloss(proposal, inferpar, logloss, static_sim_arguments, modinf, silent, save) # return ll_total - class GempyorInference: def __init__( self, @@ -333,12 +366,20 @@ def __init__( config.set_file(os.path.join(path_prefix, config_filepath)) - self.seir_modifiers_scenario, self.outcome_modifiers_scenario = autodetect_scenarios(config) + self.seir_modifiers_scenario, self.outcome_modifiers_scenario = ( + autodetect_scenarios(config) + ) if run_id is None: run_id = file_paths.run_id() if prefix is None: - prefix = config["name"].get() + f"_{self.seir_modifiers_scenario}_{self.outcome_modifiers_scenario}" + "/" + run_id + "/" + prefix = ( + config["name"].get() + + f"_{self.seir_modifiers_scenario}_{self.outcome_modifiers_scenario}" + + "/" + + run_id + + "/" + ) in_run_id = run_id if out_run_id is None: out_run_id = in_run_id @@ -387,7 +428,8 @@ def __init__( self.do_inference = True self.inference_method = "emcee" self.inferpar = inference_parameter.InferenceParameters( - global_config=config, subpop_names=self.modinf.subpop_struct.subpop_names + global_config=config, + subpop_names=self.modinf.subpop_struct.subpop_names, ) self.logloss = logloss.LogLoss( inference_config=config["inference"], @@ -412,7 +454,14 @@ def set_save(self, save): def get_all_sim_arguments(self): # inferpar, logloss, static_sim_arguments, modinf, proposal, silent, save - return [self.inferpar, self.logloss, self.static_sim_arguments, self.modinf, self.silent, self.save] + return [ + self.inferpar, + self.logloss, + self.static_sim_arguments, + self.modinf, + self.silent, + self.save, + ] def get_logloss(self, proposal): if not self.inferpar.check_in_bound(proposal=proposal): @@ -479,11 +528,15 @@ def update_run_id(self, new_run_id, new_out_run_id=None): else: self.modinf.out_run_id = new_out_run_id - def one_simulation_legacy(self, sim_id2write: int, load_ID: bool = False, sim_id2load: int = None): + def one_simulation_legacy( + self, sim_id2write: int, load_ID: bool = False, sim_id2load: int = None + ): sim_id2write = int(sim_id2write) if load_ID: sim_id2load = int(sim_id2load) - with Timer(f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}"): + with Timer( + f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}" + ): with Timer("onerun_SEIR"): seir.onerun_SEIR( sim_id2write=sim_id2write, @@ -533,16 +586,31 @@ def one_simulation( sim_id2load = int(sim_id2load) self.lastsim_sim_id2load = sim_id2load - with Timer(f">>> 
GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}"): + with Timer( + f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}" + ): if not self.already_built and self.modinf.outcomes_config is not None: - self.outcomes_parameters = outcomes.read_parameters_from_config(self.modinf) + self.outcomes_parameters = outcomes.read_parameters_from_config( + self.modinf + ) npi_outcomes = None if parallel: with Timer("//things"): - with ProcessPoolExecutor(max_workers=max(mp.cpu_count(), 3)) as executor: - ret_seir = executor.submit(seir.build_npi_SEIR, self.modinf, load_ID, sim_id2load, config) - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + with ProcessPoolExecutor( + max_workers=max(mp.cpu_count(), 3) + ) as executor: + ret_seir = executor.submit( + seir.build_npi_SEIR, + self.modinf, + load_ID, + sim_id2load, + config, + ) + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): ret_outcomes = executor.submit( outcomes.build_outcome_modifiers, self.modinf, @@ -551,7 +619,9 @@ def one_simulation( config, ) if not self.already_built: - ret_comparments = executor.submit(self.modinf.compartments.get_transition_array) + ret_comparments = executor.submit( + self.modinf.compartments.get_transition_array + ) # print("expections:", ret_seir.exception(), ret_outcomes.exception(), ret_comparments.exception()) @@ -564,15 +634,24 @@ def one_simulation( ) = ret_comparments.result() self.already_built = True npi_seir = ret_seir.result() - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): npi_outcomes = ret_outcomes.result() else: if not self.already_built: self.build_structure() npi_seir = seir.build_npi_SEIR( - modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + modinf=self.modinf, + load_ID=load_ID, + sim_id2load=sim_id2load, + config=config, ) - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): npi_outcomes = outcomes.build_outcome_modifiers( modinf=self.modinf, load_ID=load_ID, @@ -585,7 +664,9 @@ def one_simulation( with Timer("SEIR.parameters"): # Draw or load parameters - p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) + p_draw = self.get_seir_parameters( + load_ID=load_ID, sim_id2load=sim_id2load + ) # reduce them parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them @@ -598,8 +679,12 @@ def one_simulation( with Timer("onerun_SEIR.seeding"): if load_ID: - initial_conditions = self.modinf.initial_conditions.get_from_file(sim_id2load, modinf=self.modinf) - seeding_data, seeding_amounts = self.modinf.seeding.get_from_file(sim_id2load, modinf=self.modinf) + initial_conditions = self.modinf.initial_conditions.get_from_file( + sim_id2load, modinf=self.modinf + ) + seeding_data, seeding_amounts = self.modinf.seeding.get_from_file( + sim_id2load, modinf=self.modinf + ) else: initial_conditions = self.modinf.initial_conditions.get_from_config( sim_id2write, modinf=self.modinf @@ -645,7 +730,7 @@ def one_simulation( parameters=self.outcomes_parameters, loaded_values=loaded_values, npi=npi_outcomes, - bypass_seir_xr=states + bypass_seir_xr=states, ) self.lastsim_outcomes_df = outcomes_df self.lastsim_hpar_df = hpar_df @@ -660,14 +745,18 @@ def one_simulation( ) return 0 - def 
plot_transition_graph(self, output_file="transition_graph", source_filters=[], destination_filters=[]): + def plot_transition_graph( + self, output_file="transition_graph", source_filters=[], destination_filters=[] + ): self.modinf.compartments.plot( output_file=output_file, source_filters=source_filters, destination_filters=destination_filters, ) - def get_outcome_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_outcome_npi( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): npi_outcomes = None if self.modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( @@ -680,7 +769,9 @@ def get_outcome_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypas ) return npi_outcomes - def get_seir_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_npi( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): npi_seir = seir.build_npi_SEIR( modinf=self.modinf, load_ID=load_ID, @@ -691,7 +782,9 @@ def get_seir_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_F ) return npi_seir - def get_seir_parameters(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_parameters( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): param_df = None if bypass_DF is not None: param_df = bypass_DF @@ -712,7 +805,9 @@ def get_seir_parameters(self, load_ID=False, sim_id2load=None, bypass_DF=None, b ) return p_draw - def get_seir_parametersDF(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_parametersDF( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): p_draw = self.get_seir_parameters( load_ID=load_ID, sim_id2load=sim_id2load, @@ -767,7 +862,9 @@ def get_parsed_parameters_seir( if not self.already_built: self.build_structure() - npi_seir = seir.build_npi_SEIR(modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) @@ -784,7 +881,9 @@ def get_reduced_parameters_seir( # bypass_DF=None, # bypass_FN=None, ): - npi_seir = seir.build_npi_SEIR(modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) @@ -805,7 +904,9 @@ def paramred_parallel(run_spec, snpi_fn): seir_modifiers_scenario="inference", # NPIs scenario to use outcome_modifiers_scenario="med", # Outcome scenario to use stoch_traj_flag=False, - path_prefix=run_spec["geodata"], # prefix where to find the folder indicated in subpop_setup$ + path_prefix=run_spec[ + "geodata" + ], # prefix where to find the folder indicated in subpop_setup$ ) snpi = pq.read_table(snpi_fn).to_pandas() @@ -816,7 +917,9 @@ def paramred_parallel(run_spec, snpi_fn): params_draw_arr = gempyor_inference.get_seir_parameters( bypass_FN=snpi_fn.replace("snpi", "spar") ) # could also accept (load_ID=True, sim_id2load=XXX) or (bypass_DF=) or (bypass_FN=) - param_reduc_from = gempyor_inference.get_seir_parameter_reduced(npi_seir=npi_seir, 
p_draw=params_draw_arr) + param_reduc_from = gempyor_inference.get_seir_parameter_reduced( + npi_seir=npi_seir, p_draw=params_draw_arr + ) return param_reduc_from @@ -831,7 +934,9 @@ def paramred_parallel_config(run_spec, dummy): seir_modifiers_scenario="inference", # NPIs scenario to use outcome_modifiers_scenario="med", # Outcome scenario to use stoch_traj_flag=False, - path_prefix=run_spec["geodata"], # prefix where to find the folder indicated in subpop_setup$ + path_prefix=run_spec[ + "geodata" + ], # prefix where to find the folder indicated in subpop_setup$ ) npi_seir = gempyor_inference.get_seir_npi() @@ -839,6 +944,8 @@ def paramred_parallel_config(run_spec, dummy): params_draw_arr = ( gempyor_inference.get_seir_parameters() ) # could also accept (load_ID=True, sim_id2load=XXX) or (bypass_DF=) or (bypass_FN=) - param_reduc_from = gempyor_inference.get_seir_parameter_reduced(npi_seir=npi_seir, p_draw=params_draw_arr) + param_reduc_from = gempyor_inference.get_seir_parameter_reduced( + npi_seir=npi_seir, p_draw=params_draw_arr + ) return param_reduc_from diff --git a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py index 8f1b0fc53..30cab0bba 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py @@ -36,9 +36,14 @@ def add_modifier(self, pname, ptype, parameter_config, subpops): """ # identify spatial group affected_subpops = set(subpops) - if parameter_config["subpop"].exists() and parameter_config["subpop"].get() != "all": + if ( + parameter_config["subpop"].exists() + and parameter_config["subpop"].get() != "all" + ): affected_subpops = {str(n.get()) for n in parameter_config["subpop"]} - spatial_groups = NPI.helpers.get_spatial_groups(parameter_config, list(affected_subpops)) + spatial_groups = NPI.helpers.get_spatial_groups( + parameter_config, list(affected_subpops) + ) # ungrouped subpop (all affected subpop by default) have one parameter per subpop if spatial_groups["ungrouped"]: @@ -87,11 +92,15 @@ def build_from_config(self, global_config, subpop_names): for config_part in ["seir_modifiers", "outcome_modifiers"]: if global_config[config_part].exists(): for npi in global_config[config_part]["modifiers"].get(): - if global_config[config_part]["modifiers"][npi]["perturbation"].exists(): + if global_config[config_part]["modifiers"][npi][ + "perturbation" + ].exists(): self.add_modifier( pname=npi, ptype=config_part, - parameter_config=global_config[config_part]["modifiers"][npi], + parameter_config=global_config[config_part]["modifiers"][ + npi + ], subpops=subpop_names, ) @@ -156,7 +165,10 @@ def check_in_bound(self, proposal) -> bool: Returns: bool: True if the proposal is within bounds, False otherwise. 
""" - if self.hit_lbs(proposal=proposal).any() or self.hit_ubs(proposal=proposal).any(): + if ( + self.hit_lbs(proposal=proposal).any() + or self.hit_ubs(proposal=proposal).any() + ): return False return True diff --git a/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py b/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py index 554f3910c..eef32531f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py +++ b/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py @@ -45,19 +45,30 @@ def __init__( if self.initial_conditions_config is not None: if "ignore_population_checks" in self.initial_conditions_config.keys(): - self.ignore_population_checks = self.initial_conditions_config["ignore_population_checks"].get(bool) + self.ignore_population_checks = self.initial_conditions_config[ + "ignore_population_checks" + ].get(bool) if "allow_missing_subpops" in self.initial_conditions_config.keys(): - self.allow_missing_subpops = self.initial_conditions_config["allow_missing_subpops"].get(bool) + self.allow_missing_subpops = self.initial_conditions_config[ + "allow_missing_subpops" + ].get(bool) if "allow_missing_compartments" in self.initial_conditions_config.keys(): - self.allow_missing_compartments = self.initial_conditions_config["allow_missing_compartments"].get(bool) + self.allow_missing_compartments = self.initial_conditions_config[ + "allow_missing_compartments" + ].get(bool) # TODO: add check, this option onlywork with tidy dataframe if "proportional" in self.initial_conditions_config.keys(): - self.proportional_ic = self.initial_conditions_config["proportional"].get(bool) + self.proportional_ic = self.initial_conditions_config[ + "proportional" + ].get(bool) def get_from_config(self, sim_id: int, modinf) -> np.ndarray: method = "Default" - if self.initial_conditions_config is not None and "method" in self.initial_conditions_config.keys(): + if ( + self.initial_conditions_config is not None + and "method" in self.initial_conditions_config.keys() + ): method = self.initial_conditions_config["method"].as_str() if method == "Default": @@ -66,13 +77,20 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: y0[0, :] = modinf.subpop_pop return y0 # we finish here: no rest and not proportionallity applies - if method == "SetInitialConditions" or method == "SetInitialConditionsFolderDraw": + if ( + method == "SetInitialConditions" + or method == "SetInitialConditionsFolderDraw" + ): # TODO Think about - Does not support the new way of doing compartment indexing if method == "SetInitialConditionsFolderDraw": - ic_df = modinf.read_simID(ftype=self.initial_conditions_config["initial_file_type"], sim_id=sim_id) + ic_df = modinf.read_simID( + ftype=self.initial_conditions_config["initial_file_type"], + sim_id=sim_id, + ) else: ic_df = read_df( - self.path_prefix / self.initial_conditions_config["initial_conditions_file"].get(), + self.path_prefix + / self.initial_conditions_config["initial_conditions_file"].get(), ) y0 = read_initial_condition_from_tidydataframe( ic_df=ic_df, @@ -85,11 +103,13 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: elif method == "InitialConditionsFolderDraw" or method == "FromFile": if method == "InitialConditionsFolderDraw": ic_df = modinf.read_simID( - ftype=self.initial_conditions_config["initial_file_type"].get(), sim_id=sim_id + ftype=self.initial_conditions_config["initial_file_type"].get(), + sim_id=sim_id, ) elif method == "FromFile": ic_df = read_df( - self.path_prefix / 
self.initial_conditions_config["initial_conditions_file"].get(), + self.path_prefix + / self.initial_conditions_config["initial_conditions_file"].get(), ) y0 = read_initial_condition_from_seir_output( @@ -99,10 +119,14 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: allow_missing_subpops=self.allow_missing_subpops, ) else: - raise NotImplementedError(f"unknown initial conditions method [got: {method}]") + raise NotImplementedError( + f"unknown initial conditions method [got: {method}]" + ) # check that the inputed values sums to the subpop population: - check_population(y0=y0, modinf=modinf, ignore_population_checks=self.ignore_population_checks) + check_population( + y0=y0, modinf=modinf, ignore_population_checks=self.ignore_population_checks + ) return y0 @@ -136,7 +160,11 @@ def check_population(y0, modinf, ignore_population_checks=False): def read_initial_condition_from_tidydataframe( - ic_df, modinf, allow_missing_subpops, allow_missing_compartments, proportional_ic=False + ic_df, + modinf, + allow_missing_subpops, + allow_missing_compartments, + proportional_ic=False, ): rests = [] # Places to allocate the rest of the population y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops)) @@ -145,9 +173,13 @@ def read_initial_condition_from_tidydataframe( states_pl = ic_df[ic_df["subpop"] == pl] for comp_idx, comp_name in modinf.compartments.compartments["name"].items(): if "mc_name" in states_pl.columns: - ic_df_compartment_val = states_pl[states_pl["mc_name"] == comp_name]["amount"] + ic_df_compartment_val = states_pl[ + states_pl["mc_name"] == comp_name + ]["amount"] else: - filters = modinf.compartments.compartments.iloc[comp_idx].drop("name") + filters = modinf.compartments.compartments.iloc[comp_idx].drop( + "name" + ) ic_df_compartment_val = states_pl.copy() for mc_name, mc_value in filters.items(): ic_df_compartment_val = ic_df_compartment_val[ @@ -177,7 +209,9 @@ def read_initial_condition_from_tidydataframe( logger.critical( f"No initial conditions for for subpop {pl}, assuming everyone (n={modinf.subpop_pop[pl_idx]}) in the first metacompartment ({modinf.compartments.compartments['name'].iloc[0]})" ) - raise ValueError("THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy") + raise ValueError( + "THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy" + ) # TODO: this is probably ok but highlighting for consistency if "proportional" in self.initial_conditions_config.keys(): if self.initial_conditions_config["proportional"].get(): @@ -202,7 +236,9 @@ def read_initial_condition_from_tidydataframe( return y0 -def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops, allow_missing_compartments): +def read_initial_condition_from_seir_output( + ic_df, modinf, allow_missing_subpops, allow_missing_compartments +): """ Read the initial conditions from the SEIR output. @@ -227,9 +263,13 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops ic_df["date"] = ic_df["date"].dt.date ic_df["date"] = ic_df["date"].astype(str) - ic_df = ic_df[(ic_df["date"] == str(modinf.ti)) & (ic_df["mc_value_type"] == "prevalence")] + ic_df = ic_df[ + (ic_df["date"] == str(modinf.ti)) & (ic_df["mc_value_type"] == "prevalence") + ] if ic_df.empty: - raise ValueError(f"There is no entry for initial time ti in the provided initial_conditions::states_file.") + raise ValueError( + f"There is no entry for initial time ti in the provided initial_conditions::states_file." 
+ ) y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops)) for comp_idx, comp_name in modinf.compartments.compartments["name"].items(): @@ -239,7 +279,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops filters = modinf.compartments.compartments.iloc[comp_idx].drop("name") ic_df_compartment = ic_df.copy() for mc_name, mc_value in filters.items(): - ic_df_compartment = ic_df_compartment[ic_df_compartment["mc_" + mc_name] == mc_value] + ic_df_compartment = ic_df_compartment[ + ic_df_compartment["mc_" + mc_name] == mc_value + ] if len(ic_df_compartment) > 1: # ic_df_compartment = ic_df_compartment.iloc[0] @@ -248,7 +290,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops ) elif ic_df_compartment.empty: if allow_missing_compartments: - ic_df_compartment = pd.DataFrame(0, columns=ic_df_compartment.columns, index=[0]) + ic_df_compartment = pd.DataFrame( + 0, columns=ic_df_compartment.columns, index=[0] + ) else: raise ValueError( f"Initial Conditions: Could not set compartment {comp_name} (id: {comp_idx}) in subpop {pl} (id: {pl_idx}). The data from the init file is {ic_df_compartment[pl]}." @@ -262,7 +306,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops if pl in ic_df.columns: y0[comp_idx, pl_idx] = float(ic_df_compartment[pl].iloc[0]) elif allow_missing_subpops: - raise ValueError("THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy") + raise ValueError( + "THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy" + ) # TODO this should set the full subpop, not just the 0th commpartment logger.critical( f"No initial conditions for for subpop {pl}, assuming everyone (n={modinf.subpop_pop[pl_idx]}) in the first metacompartments ({modinf.compartments.compartments['name'].iloc[0]})" diff --git a/flepimop/gempyor_pkg/src/gempyor/logloss.py b/flepimop/gempyor_pkg/src/gempyor/logloss.py index 95892dac0..0ed62fad2 100644 --- a/flepimop/gempyor_pkg/src/gempyor/logloss.py +++ b/flepimop/gempyor_pkg/src/gempyor/logloss.py @@ -17,7 +17,13 @@ class LogLoss: - def __init__(self, inference_config: confuse.ConfigView, subpop_struct, time_setup, path_prefix: str = "."): + def __init__( + self, + inference_config: confuse.ConfigView, + subpop_struct, + time_setup, + path_prefix: str = ".", + ): # TODO: bad format for gt because each date must have a value for each column, but if it doesn't and you add NA # then this NA has a meaning that depends on skip NA, which is annoying. # A lot of things can go wrong here, in the previous approach where GT was cast to xarray as @@ -35,20 +41,36 @@ def __init__(self, inference_config: confuse.ConfigView, subpop_struct, time_set # made the controversial choice of storing the gt as an xarray dataset instead of a dictionary # of dataframes - self.gt_xr = xr.Dataset.from_dataframe(self.gt.reset_index().set_index(["date", "subpop"])) + self.gt_xr = xr.Dataset.from_dataframe( + self.gt.reset_index().set_index(["date", "subpop"]) + ) # Very important: subsample the subpop in the population, in the right order, and sort by the date index. - self.gt_xr = self.gt_xr.sortby("date").reindex({"subpop": subpop_struct.subpop_names}) + self.gt_xr = self.gt_xr.sortby("date").reindex( + {"subpop": subpop_struct.subpop_names} + ) # This will force at 0, if skipna is False, data of some variable that don't exist if iother exist # and damn python datetime types are ugly... 
- self.first_date = max(pd.to_datetime(self.gt_xr.date[0].values).date(), time_setup.ti) - self.last_date = min(pd.to_datetime(self.gt_xr.date[-1].values).date(), time_setup.tf) + self.first_date = max( + pd.to_datetime(self.gt_xr.date[0].values).date(), time_setup.ti + ) + self.last_date = min( + pd.to_datetime(self.gt_xr.date[-1].values).date(), time_setup.tf + ) self.statistics = {} for key, value in inference_config["statistics"].items(): self.statistics[key] = statistics.Statistic(key, value) - def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename=None, **kwargs): + def plot_gt( + self, + ax=None, + subpop=None, + statistic=None, + subplot=False, + filename=None, + **kwargs, + ): """Plots ground truth data. Args: @@ -68,7 +90,10 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= fig, axes = plt.subplots( len(self.gt["subpop"].unique()), len(self.gt.columns.drop("subpop")), - figsize=(4 * len(self.gt.columns.drop("subpop")), 3 * len(self.gt["subpop"].unique())), + figsize=( + 4 * len(self.gt.columns.drop("subpop")), + 3 * len(self.gt["subpop"].unique()), + ), dpi=250, sharex=True, ) @@ -81,7 +106,9 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= subpops = [subpop] if statistic is None: - statistics = self.gt.columns.drop("subpop") # Assuming other columns are statistics + statistics = self.gt.columns.drop( + "subpop" + ) # Assuming other columns are statistics else: statistics = [statistic] @@ -89,14 +116,18 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= # One subplot for each subpop/statistic combination for i, subpop in enumerate(subpops): for j, stat in enumerate(statistics): - data_to_plot = self.gt[(self.gt["subpop"] == subpop)][stat].sort_index() + data_to_plot = self.gt[(self.gt["subpop"] == subpop)][ + stat + ].sort_index() axes[i, j].plot(data_to_plot, **kwargs) axes[i, j].set_title(f"{subpop} - {stat}") else: # All lines in a single plot for subpop in subpops: for stat in statistics: - data_to_plot = self.gt[(self.gt["subpop"] == subpop)][stat].sort_index() + data_to_plot = self.gt[(self.gt["subpop"] == subpop)][ + stat + ].sort_index() data_to_plot.plot(ax=ax, **kwargs, label=f"{subpop} - {stat}") if len(statistics) > 1: ax.legend() @@ -107,7 +138,10 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= plt.savefig(filename, **kwargs) # Save the figure if subplot: - return fig, axes # Return figure and subplots for potential further customization + return ( + fig, + axes, + ) # Return figure and subplots for potential further customization else: return ax # Optionally return the axis @@ -121,13 +155,17 @@ def compute_logloss(self, model_df, subpop_names): coords = {"statistic": list(self.statistics.keys()), "subpop": subpop_names} logloss = xr.DataArray( - np.zeros((len(coords["statistic"]), len(coords["subpop"]))), dims=["statistic", "subpop"], coords=coords + np.zeros((len(coords["statistic"]), len(coords["subpop"]))), + dims=["statistic", "subpop"], + coords=coords, ) regularizations = 0 model_xr = ( - xr.Dataset.from_dataframe(model_df.reset_index().set_index(["date", "subpop"])) + xr.Dataset.from_dataframe( + model_df.reset_index().set_index(["date", "subpop"]) + ) .sortby("date") .reindex({"subpop": subpop_names}) ) diff --git a/flepimop/gempyor_pkg/src/gempyor/model_info.py b/flepimop/gempyor_pkg/src/gempyor/model_info.py index 54f981d8f..dc3477f8a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/model_info.py 
+++ b/flepimop/gempyor_pkg/src/gempyor/model_info.py @@ -1,6 +1,13 @@ import pandas as pd import datetime, os, logging, pathlib, confuse -from . import seeding, subpopulation_structure, parameters, compartments, file_paths, initial_conditions +from . import ( + seeding, + subpopulation_structure, + parameters, + compartments, + file_paths, + initial_conditions, +) from .utils import read_df, write_df logger = logging.getLogger(__name__) @@ -11,7 +18,9 @@ def __init__(self, config: confuse.ConfigView): self.ti = config["start_date"].as_date() self.tf = config["end_date"].as_date() if self.tf <= self.ti: - raise ValueError("tf (time to finish) is less than or equal to ti (time to start)") + raise ValueError( + "tf (time to finish) is less than or equal to ti (time to start)" + ) self.n_days = (self.tf - self.ti).days + 1 self.dates = pd.date_range(start=self.ti, end=self.tf, freq="D") @@ -29,7 +38,7 @@ class ModelInfo: seeding # One of seeding or initial_conditions is required when running seir outcomes # Required if running outcomes seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters - outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes + outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes inference # Required if running inference ``` """ @@ -94,7 +103,9 @@ def __init__( # 3. What about subpopulations subpop_config = config["subpop_setup"] if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." + ) self.path_prefix = pathlib.Path(path_prefix) self.subpop_struct = subpopulation_structure.SubpopulationStructure( @@ -112,9 +123,13 @@ def __init__( self.seir_config = config["seir"] self.parameters_config = config["seir"]["parameters"] self.initial_conditions_config = ( - config["initial_conditions"] if config["initial_conditions"].exists() else None + config["initial_conditions"] + if config["initial_conditions"].exists() + else None + ) + self.seeding_config = ( + config["seeding"] if config["seeding"].exists() else None ) - self.seeding_config = config["seeding"] if config["seeding"].exists() else None if self.seeding_config is None and self.initial_conditions_config is None: logging.critical( @@ -130,25 +145,36 @@ def __init__( subpop_names=self.subpop_struct.subpop_names, path_prefix=self.path_prefix, ) - self.seeding = seeding.SeedingFactory(config=self.seeding_config, path_prefix=self.path_prefix) + self.seeding = seeding.SeedingFactory( + config=self.seeding_config, path_prefix=self.path_prefix + ) self.initial_conditions = initial_conditions.InitialConditionsFactory( config=self.initial_conditions_config, path_prefix=self.path_prefix ) # really ugly references to the config globally here. 
if config["compartments"].exists() and self.seir_config is not None: self.compartments = compartments.Compartments( - seir_config=self.seir_config, compartments_config=config["compartments"] + seir_config=self.seir_config, + compartments_config=config["compartments"], ) # SEIR modifiers self.npi_config_seir = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - self.npi_config_seir = config["seir_modifiers"]["modifiers"][seir_modifiers_scenario] - self.seir_modifiers_library = config["seir_modifiers"]["modifiers"].get() + self.npi_config_seir = config["seir_modifiers"]["modifiers"][ + seir_modifiers_scenario + ] + self.seir_modifiers_library = config["seir_modifiers"][ + "modifiers" + ].get() else: - self.seir_modifiers_library = config["seir_modifiers"]["modifiers"].get() - raise ValueError("Not implemented yet") # TODO create a Stacked from all + self.seir_modifiers_library = config["seir_modifiers"][ + "modifiers" + ].get() + raise ValueError( + "Not implemented yet" + ) # TODO create a Stacked from all elif self.seir_modifiers_scenario is not None: raise ValueError( "An seir modifiers scenario was provided to ModelInfo but no 'seir_modifiers' sections in config" @@ -157,21 +183,33 @@ def __init__( logging.info("Running ModelInfo with seir but without SEIR Modifiers") elif self.seir_modifiers_scenario is not None: - raise ValueError("A seir modifiers scenario was provided to ModelInfo but no 'seir:' sections in config") + raise ValueError( + "A seir modifiers scenario was provided to ModelInfo but no 'seir:' sections in config" + ) else: logging.critical("Running ModelInfo without SEIR") # 5. Outcomes - self.outcomes_config = config["outcomes"] if config["outcomes"].exists() else None + self.outcomes_config = ( + config["outcomes"] if config["outcomes"].exists() else None + ) if self.outcomes_config is not None: self.npi_config_outcomes = None if config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - self.npi_config_outcomes = config["outcome_modifiers"]["modifiers"][self.outcome_modifiers_scenario] - self.outcome_modifiers_library = config["outcome_modifiers"]["modifiers"].get() + self.npi_config_outcomes = config["outcome_modifiers"]["modifiers"][ + self.outcome_modifiers_scenario + ] + self.outcome_modifiers_library = config["outcome_modifiers"][ + "modifiers" + ].get() else: - self.outcome_modifiers_library = config["outcome_modifiers"]["modifiers"].get() - raise ValueError("Not implemented yet") # TODO create a Stacked from all + self.outcome_modifiers_library = config["outcome_modifiers"][ + "modifiers" + ].get() + raise ValueError( + "Not implemented yet" + ) # TODO create a Stacked from all ## NEED TO IMPLEMENT THIS -- CURRENTLY CANNOT USE outcome modifiers elif self.outcome_modifiers_scenario is not None: @@ -182,7 +220,9 @@ def __init__( else: self.outcome_modifiers_scenario = None else: - logging.info("Running ModelInfo with outcomes but without Outcomes Modifiers") + logging.info( + "Running ModelInfo with outcomes but without Outcomes Modifiers" + ) elif self.outcome_modifiers_scenario is not None: raise ValueError( "An outcome modifiers scenario was provided to ModelInfo but no 'outcomes:' sections in config" @@ -228,7 +268,9 @@ def __init__( os.makedirs(datadir, exist_ok=True) if self.write_parquet and self.write_csv: - print("Confused between reading .csv or parquet. Assuming input file is .parquet") + print( + "Confused between reading .csv or parquet. 
Assuming input file is .parquet" + ) if self.write_parquet: self.extension = "parquet" elif self.write_csv: @@ -244,7 +286,9 @@ def get_input_filename(self, ftype: str, sim_id: int, extension_override: str = extension_override=extension_override, ) - def get_output_filename(self, ftype: str, sim_id: int, extension_override: str = ""): + def get_output_filename( + self, ftype: str, sim_id: int, extension_override: str = "" + ): return self.path_prefix / self.get_filename( ftype=ftype, sim_id=sim_id, @@ -252,7 +296,9 @@ def get_output_filename(self, ftype: str, sim_id: int, extension_override: str = extension_override=extension_override, ) - def get_filename(self, ftype: str, sim_id: int, input: bool, extension_override: str = ""): + def get_filename( + self, ftype: str, sim_id: int, input: bool, extension_override: str = "" + ): """return a CSP formated filename.""" if extension_override: # empty strings are Falsy @@ -281,7 +327,9 @@ def get_filename(self, ftype: str, sim_id: int, input: bool, extension_override: def get_setup_name(self): return self.setup_name - def read_simID(self, ftype: str, sim_id: int, input: bool = True, extension_override: str = ""): + def read_simID( + self, ftype: str, sim_id: int, input: bool = True, extension_override: str = "" + ): fname = self.get_filename( ftype=ftype, sim_id=sim_id, diff --git a/flepimop/gempyor_pkg/src/gempyor/outcomes.py b/flepimop/gempyor_pkg/src/gempyor/outcomes.py index 5563f4d85..3c693518c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/outcomes.py +++ b/flepimop/gempyor_pkg/src/gempyor/outcomes.py @@ -22,7 +22,9 @@ def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1): sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots) loaded_values = None - if (n_jobs == 1) or (modinf.nslots == 1): # run single process for debugging/profiling purposes + if (n_jobs == 1) or ( + modinf.nslots == 1 + ): # run single process for debugging/profiling purposes for sim_offset in np.arange(nslots): onerun_delayframe_outcomes( sim_id2write=sim_id2writes[sim_offset], @@ -100,7 +102,9 @@ def onerun_delayframe_outcomes( npi_outcomes = None if modinf.npi_config_outcomes: - npi_outcomes = build_outcome_modifiers(modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_outcomes = build_outcome_modifiers( + modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) loaded_values = None if load_ID: @@ -117,7 +121,13 @@ def onerun_delayframe_outcomes( ) with Timer("onerun_delayframe_outcomes.postprocess"): - postprocess_and_write(sim_id=sim_id2write, modinf=modinf, outcomes_df=outcomes_df, hpar=hpar, npi=npi_outcomes) + postprocess_and_write( + sim_id=sim_id2write, + modinf=modinf, + outcomes_df=outcomes_df, + hpar=hpar, + npi=npi_outcomes, + ) def read_parameters_from_config(modinf: model_info.ModelInfo): @@ -129,7 +139,10 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): if modinf.outcomes_config["param_from_file"].exists(): if modinf.outcomes_config["param_from_file"].get(): # Load the actual csv file - branching_file = modinf.path_prefix / modinf.outcomes_config["param_subpop_file"].as_str() + branching_file = ( + modinf.path_prefix + / modinf.outcomes_config["param_subpop_file"].as_str() + ) branching_data = pa.parquet.read_table(branching_file).to_pandas() if "relative_probability" not in list(branching_data["quantity"]): raise ValueError( @@ -142,14 +155,18 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): "", end="", ) - branching_data = 
branching_data[branching_data["subpop"].isin(modinf.subpop_struct.subpop_names)] + branching_data = branching_data[ + branching_data["subpop"].isin(modinf.subpop_struct.subpop_names) + ] print( "Intersect with seir simulation: ", len(branching_data.subpop.unique()), "kept", ) - if len(branching_data.subpop.unique()) != len(modinf.subpop_struct.subpop_names): + if len(branching_data.subpop.unique()) != len( + modinf.subpop_struct.subpop_names + ): raise ValueError( f"Places in seir input files does not correspond to subpops in outcome probability file {branching_file}" ) @@ -162,7 +179,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): src_name = outcomes_config[new_comp]["source"].get() if isinstance(src_name, str): parameters[new_comp]["source"] = src_name - elif ("incidence" in src_name.keys()) or ("prevalence" in src_name.keys()): + elif ("incidence" in src_name.keys()) or ( + "prevalence" in src_name.keys() + ): parameters[new_comp]["source"] = dict(src_name) else: @@ -170,10 +189,16 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"unsure how to read outcome {new_comp}: not a str, nor an incidence or prevalence: {src_name}" ) - parameters[new_comp]["probability"] = outcomes_config[new_comp]["probability"]["value"] - if outcomes_config[new_comp]["probability"]["modifier_parameter"].exists(): + parameters[new_comp]["probability"] = outcomes_config[new_comp][ + "probability" + ]["value"] + if outcomes_config[new_comp]["probability"][ + "modifier_parameter" + ].exists(): parameters[new_comp]["probability::npi_param_name"] = ( - outcomes_config[new_comp]["probability"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["probability"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"probability of outcome {new_comp} is affected by intervention " @@ -181,13 +206,21 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::probability" ) else: - parameters[new_comp]["probability::npi_param_name"] = f"{new_comp}::probability".lower() + parameters[new_comp][ + "probability::npi_param_name" + ] = f"{new_comp}::probability".lower() if outcomes_config[new_comp]["delay"].exists(): - parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"]["value"] - if outcomes_config[new_comp]["delay"]["modifier_parameter"].exists(): + parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"][ + "value" + ] + if outcomes_config[new_comp]["delay"][ + "modifier_parameter" + ].exists(): parameters[new_comp]["delay::npi_param_name"] = ( - outcomes_config[new_comp]["delay"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["delay"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"delay of outcome {new_comp} is affected by intervention " @@ -195,18 +228,32 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::delay" ) else: - parameters[new_comp]["delay::npi_param_name"] = f"{new_comp}::delay".lower() + parameters[new_comp][ + "delay::npi_param_name" + ] = f"{new_comp}::delay".lower() else: - logging.critical(f"No delay for outcome {new_comp}, using a 0 delay") + logging.critical( + f"No delay for outcome {new_comp}, using a 0 delay" + ) outcomes_config[new_comp]["delay"] = {"value": 0} - parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"]["value"] - parameters[new_comp]["delay::npi_param_name"] = f"{new_comp}::delay".lower() + parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"][ + "value" + ] + 
parameters[new_comp][ + "delay::npi_param_name" + ] = f"{new_comp}::delay".lower() if outcomes_config[new_comp]["duration"].exists(): - parameters[new_comp]["duration"] = outcomes_config[new_comp]["duration"]["value"] - if outcomes_config[new_comp]["duration"]["modifier_parameter"].exists(): + parameters[new_comp]["duration"] = outcomes_config[new_comp][ + "duration" + ]["value"] + if outcomes_config[new_comp]["duration"][ + "modifier_parameter" + ].exists(): parameters[new_comp]["duration::npi_param_name"] = ( - outcomes_config[new_comp]["duration"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["duration"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"duration of outcome {new_comp} is affected by intervention " @@ -214,7 +261,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::duration" ) else: - parameters[new_comp]["duration::npi_param_name"] = f"{new_comp}::duration".lower() + parameters[new_comp][ + "duration::npi_param_name" + ] = f"{new_comp}::duration".lower() if outcomes_config[new_comp]["duration"]["name"].exists(): parameters[new_comp]["outcome_prevalence_name"] = ( @@ -223,7 +272,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): ) else: # parameters[class_name]["outcome_prevalence_name"] = new_comp + "_curr" + subclass - parameters[new_comp]["outcome_prevalence_name"] = new_comp + "_curr" + parameters[new_comp]["outcome_prevalence_name"] = ( + new_comp + "_curr" + ) if modinf.outcomes_config["param_from_file"].exists(): if modinf.outcomes_config["param_from_file"].get(): rel_probability = branching_data[ @@ -231,14 +282,22 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): & (branching_data["quantity"] == "relative_probability") ].copy(deep=True) if len(rel_probability) > 0: - logging.debug(f"Using 'param_from_file' for relative probability in outcome {new_comp}") + logging.debug( + f"Using 'param_from_file' for relative probability in outcome {new_comp}" + ) # Sort it in case the relative probablity file is mispecified - rel_probability.subpop = rel_probability.subpop.astype("category") - rel_probability.subpop = rel_probability.subpop.cat.set_categories( - modinf.subpop_struct.subpop_names + rel_probability.subpop = rel_probability.subpop.astype( + "category" + ) + rel_probability.subpop = ( + rel_probability.subpop.cat.set_categories( + modinf.subpop_struct.subpop_names + ) ) rel_probability = rel_probability.sort_values(["subpop"]) - parameters[new_comp]["rel_probability"] = rel_probability["value"].to_numpy() + parameters[new_comp]["rel_probability"] = rel_probability[ + "value" + ].to_numpy() else: logging.debug( f"*NOT* Using 'param_from_file' for relative probability in outcome {new_comp}" @@ -348,7 +407,9 @@ def compute_all_multioutcomes( outcome_name=new_comp, ) else: - raise ValueError(f"Unknown type for seir simulation provided, got f{type(seir_sim)}") + raise ValueError( + f"Unknown type for seir simulation provided, got f{type(seir_sim)}" + ) # we don't keep source in this cases else: # already defined outcomes if source_name in all_data: @@ -358,28 +419,40 @@ def compute_all_multioutcomes( f"ERROR with outcome {new_comp}: the specified source {source_name} is not a dictionnary (for seir outcome) nor an existing pre-identified outcomes." 
) - if (loaded_values is not None) and (new_comp in loaded_values["outcome"].values): + if (loaded_values is not None) and ( + new_comp in loaded_values["outcome"].values + ): ## This may be unnecessary probabilities = loaded_values[ - (loaded_values["quantity"] == "probability") & (loaded_values["outcome"] == new_comp) + (loaded_values["quantity"] == "probability") + & (loaded_values["outcome"] == new_comp) + ]["value"].to_numpy() + delays = loaded_values[ + (loaded_values["quantity"] == "delay") + & (loaded_values["outcome"] == new_comp) ]["value"].to_numpy() - delays = loaded_values[(loaded_values["quantity"] == "delay") & (loaded_values["outcome"] == new_comp)][ - "value" - ].to_numpy() else: - probabilities = parameters[new_comp]["probability"].as_random_distribution()( + probabilities = parameters[new_comp][ + "probability" + ].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop if "rel_probability" in parameters[new_comp]: - probabilities = probabilities * parameters[new_comp]["rel_probability"] + probabilities = ( + probabilities * parameters[new_comp]["rel_probability"] + ) delays = parameters[new_comp]["delay"].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop probabilities[probabilities > 1] = 1 probabilities[probabilities < 0] = 0 - probabilities = np.repeat(probabilities[:, np.newaxis], len(dates), axis=1).T # duplicate in time - delays = np.repeat(delays[:, np.newaxis], len(dates), axis=1).T # duplicate in time + probabilities = np.repeat( + probabilities[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time + delays = np.repeat( + delays[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time delays = np.round(delays).astype(int) # Write hpar before NPI subpop_names_len = len(modinf.subpop_struct.subpop_names) @@ -387,7 +460,7 @@ def compute_all_multioutcomes( { "subpop": 2 * modinf.subpop_struct.subpop_names, "quantity": (subpop_names_len * ["probability"]) - + (subpop_names_len * ["delay"]), + + (subpop_names_len * ["delay"]), "outcome": 2 * subpop_names_len * [new_comp], "value": np.concatenate( ( @@ -402,42 +475,61 @@ def compute_all_multioutcomes( if npi is not None: delays = NPI.reduce_parameter( parameter=delays, - modification=npi.getReduction(parameters[new_comp]["delay::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["delay::npi_param_name"].lower() + ), ) delays = np.round(delays).astype(int) probabilities = NPI.reduce_parameter( parameter=probabilities, - modification=npi.getReduction(parameters[new_comp]["probability::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["probability::npi_param_name"].lower() + ), ) # Create new compartment incidence: all_data[new_comp] = np.empty_like(source_array) # Draw with from source compartment if modinf.stoch_traj_flag: - all_data[new_comp] = np.random.binomial(source_array.astype(np.int32), probabilities) + all_data[new_comp] = np.random.binomial( + source_array.astype(np.int32), probabilities + ) else: - all_data[new_comp] = source_array * (probabilities * np.ones_like(source_array)) + all_data[new_comp] = source_array * ( + probabilities * np.ones_like(source_array) + ) # Shift to account for the delay ## stoch_delay_flag is whether to use stochastic delays or not stoch_delay_flag = False - all_data[new_comp] = multishift(all_data[new_comp], delays, stoch_delay_flag=stoch_delay_flag) + all_data[new_comp] = multishift( + all_data[new_comp], delays, 
stoch_delay_flag=stoch_delay_flag + ) # Produce a dataframe an merge it - df_p = dataframe_from_array(all_data[new_comp], modinf.subpop_struct.subpop_names, dates, new_comp) + df_p = dataframe_from_array( + all_data[new_comp], modinf.subpop_struct.subpop_names, dates, new_comp + ) outcomes = pd.merge(outcomes, df_p) # Make duration if "duration" in parameters[new_comp]: - if (loaded_values is not None) and (new_comp in loaded_values["outcome"].values): + if (loaded_values is not None) and ( + new_comp in loaded_values["outcome"].values + ): durations = loaded_values[ - (loaded_values["quantity"] == "duration") & (loaded_values["outcome"] == new_comp) + (loaded_values["quantity"] == "duration") + & (loaded_values["outcome"] == new_comp) ]["value"].to_numpy() else: - durations = parameters[new_comp]["duration"].as_random_distribution()( + durations = parameters[new_comp][ + "duration" + ].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop - durations = np.repeat(durations[:, np.newaxis], len(dates), axis=1).T # duplicate in time + durations = np.repeat( + durations[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time durations = np.round(durations).astype(int) hpar = pd.DataFrame( data={ @@ -458,7 +550,9 @@ def compute_all_multioutcomes( # print(f"{new_comp}-duration".lower(), npi.getReduction(f"{new_comp}-duration".lower())) durations = NPI.reduce_parameter( parameter=durations, - modification=npi.getReduction(parameters[new_comp]["duration::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["duration::npi_param_name"].lower() + ), ) # npi.getReduction(f"{new_comp}::duration".lower())) durations = np.round(durations).astype(int) # plt.imshow(durations) @@ -492,7 +586,9 @@ def compute_all_multioutcomes( for cmp in parameters[new_comp]["sum"]: sum_outcome += all_data[cmp] all_data[new_comp] = sum_outcome - df_p = dataframe_from_array(sum_outcome, modinf.subpop_struct.subpop_names, dates, new_comp) + df_p = dataframe_from_array( + sum_outcome, modinf.subpop_struct.subpop_names, dates, new_comp + ) outcomes = pd.merge(outcomes, df_p) # Concat our hpar dataframes hpar = ( @@ -525,7 +621,9 @@ def filter_seir_df(diffI, dates, subpops, filters, outcome_name) -> np.ndarray: df = df[df[f"mc_{mc_type}"].isin(mc_value)] for mcn in df["mc_name"].unique(): new_df = df[df["mc_name"] == mcn] - new_df = new_df.drop(["date"] + [c for c in new_df.columns if "mc_" in c], axis=1) + new_df = new_df.drop( + ["date"] + [c for c in new_df.columns if "mc_" in c], axis=1 + ) # new_df = new_df.drop("date", axis=1) incidI_arr = incidI_arr + new_df.to_numpy() return incidI_arr @@ -554,7 +652,9 @@ def filter_seir_xr(diffI, dates, subpops, filters, outcome_name) -> np.ndarray: if isinstance(mc_value, str): mc_value = [mc_value] # Filter data along the specified mc_type dimension - diffI_filtered = diffI_filtered.where(diffI_filtered[f"mc_{mc_type}"].isin(mc_value), drop=True) + diffI_filtered = diffI_filtered.where( + diffI_filtered[f"mc_{mc_type}"].isin(mc_value), drop=True + ) # Sum along the compartment dimension incidI_arr += diffI_filtered[vtype].sum(dim="compartment") @@ -626,7 +726,9 @@ def multishift(arr, shifts, stoch_delay_flag=True): # for k,case in enumerate(cases): # results[i+k][j] = cases[k] else: - for i in range(arr.shape[0]): # numba nopython does not allow iterating over 2D array + for i in range( + arr.shape[0] + ): # numba nopython does not allow iterating over 2D array for j in range(arr.shape[1]): if i + shifts[i, j] 
< arr.shape[0]: result[i + shifts[i, j], j] += arr[i, j] diff --git a/flepimop/gempyor_pkg/src/gempyor/parameters.py b/flepimop/gempyor_pkg/src/gempyor/parameters.py index b01fb52ab..b093322c9 100644 --- a/flepimop/gempyor_pkg/src/gempyor/parameters.py +++ b/flepimop/gempyor_pkg/src/gempyor/parameters.py @@ -41,12 +41,18 @@ def __init__( self.pdata = {} self.pnames2pindex = {} - self.stacked_modifier_method = {"sum": [], "product": [], "reduction_product": []} + self.stacked_modifier_method = { + "sum": [], + "product": [], + "reduction_product": [], + } self.pnames = self.pconfig.keys() self.npar = len(self.pnames) if self.npar != len(set([name.lower() for name in self.pnames])): - raise ValueError("Parameters of the SEIR model have the same name (remember that case is not sufficient!)") + raise ValueError( + "Parameters of the SEIR model have the same name (remember that case is not sufficient!)" + ) # Attributes of dictionary for idx, pn in enumerate(self.pnames): @@ -56,19 +62,29 @@ def __init__( # Parameter characterized by it's distribution if self.pconfig[pn]["value"].exists(): - self.pdata[pn]["dist"] = self.pconfig[pn]["value"].as_random_distribution() + self.pdata[pn]["dist"] = self.pconfig[pn][ + "value" + ].as_random_distribution() # Parameter given as a file elif self.pconfig[pn]["timeseries"].exists(): - fn_name = os.path.join(path_prefix, self.pconfig[pn]["timeseries"].get()) + fn_name = os.path.join( + path_prefix, self.pconfig[pn]["timeseries"].get() + ) df = utils.read_df(fn_name).set_index("date") df.index = pd.to_datetime(df.index) - if len(df.columns) == 1: # if only one ts, assume it applies to all subpops + if ( + len(df.columns) == 1 + ): # if only one ts, assume it applies to all subpops df = pd.DataFrame( - pd.concat([df] * len(subpop_names), axis=1).values, index=df.index, columns=subpop_names + pd.concat([df] * len(subpop_names), axis=1).values, + index=df.index, + columns=subpop_names, ) elif len(df.columns) >= len(subpop_names): # one ts per subpop - df = df[subpop_names] # make sure the order of subpops is the same as the reference + df = df[ + subpop_names + ] # make sure the order of subpops is the same as the reference # (subpop_names from spatial setup) and select the columns else: print("loaded col :", sorted(list(df.columns))) @@ -102,15 +118,23 @@ def __init__( self.pdata[pn]["ts"] = df if self.pconfig[pn]["stacked_modifier_method"].exists(): - self.pdata[pn]["stacked_modifier_method"] = self.pconfig[pn]["stacked_modifier_method"].as_str() + self.pdata[pn]["stacked_modifier_method"] = self.pconfig[pn][ + "stacked_modifier_method" + ].as_str() else: self.pdata[pn]["stacked_modifier_method"] = "product" - logging.debug(f"No 'stacked_modifier_method' for parameter {pn}, assuming multiplicative NPIs") + logging.debug( + f"No 'stacked_modifier_method' for parameter {pn}, assuming multiplicative NPIs" + ) if self.pconfig[pn]["rolling_mean_windows"].exists(): - self.pdata[pn]["rolling_mean_windows"] = self.pconfig[pn]["rolling_mean_windows"].get() + self.pdata[pn]["rolling_mean_windows"] = self.pconfig[pn][ + "rolling_mean_windows" + ].get() - self.stacked_modifier_method[self.pdata[pn]["stacked_modifier_method"]].append(pn.lower()) + self.stacked_modifier_method[ + self.pdata[pn]["stacked_modifier_method"] + ].append(pn.lower()) logging.debug(f"We have {self.npar} parameter: {self.pnames}") logging.debug(f"Data to sample is: {self.pdata}") @@ -146,7 +170,9 @@ def parameters_quick_draw(self, n_days: int, nsubpops: int) -> ndarray: return param_arr # we don't 
store it as a member because this object needs to be small to be pickable - def parameters_load(self, param_df: pd.DataFrame, n_days: int, nsubpops: int) -> ndarray: + def parameters_load( + self, param_df: pd.DataFrame, n_days: int, nsubpops: int + ) -> ndarray: """ drop-in equivalent to param_quick_draw() that take a file as written parameter_write() :param fname: @@ -165,7 +191,9 @@ def parameters_load(self, param_df: pd.DataFrame, n_days: int, nsubpops: int) -> elif "ts" in self.pdata[pn]: param_arr[idx] = self.pdata[pn]["ts"].values else: - print(f"PARAM: parameter {pn} NOT found in loadID file. Drawing from config distribution") + print( + f"PARAM: parameter {pn} NOT found in loadID file. Drawing from config distribution" + ) pval = self.pdata[pn]["dist"]() param_arr[idx] = np.full((n_days, nsubpops), pval) @@ -179,9 +207,15 @@ def getParameterDF(self, p_draw: ndarray) -> pd.DataFrame: """ # we don't write to disk time series parameters. out_df = pd.DataFrame( - [p_draw[idx, 0, 0] for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn]], + [ + p_draw[idx, 0, 0] + for idx, pn in enumerate(self.pnames) + if "dist" in self.pdata[pn] + ], columns=["value"], - index=[pn for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn]], + index=[ + pn for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn] + ], ) out_df["parameter"] = out_df.index @@ -204,6 +238,8 @@ def parameters_reduce(self, p_draw: ndarray, npi: object) -> ndarray: ) p_reduced[idx] = npi_val if "rolling_mean_windows" in self.pdata[pn]: - p_reduced[idx] = utils.rolling_mean_pad(data=npi_val, window=self.pdata[pn]["rolling_mean_windows"]) + p_reduced[idx] = utils.rolling_mean_pad( + data=npi_val, window=self.pdata[pn]["rolling_mean_windows"] + ) return p_reduced diff --git a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py index 2b42e2944..fa5081c0d 100644 --- a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py @@ -21,7 +21,15 @@ import pandas as pd import pyarrow.parquet as pq import xarray as xr -from gempyor import config, model_info, outcomes, seir, inference_parameter, logloss, inference +from gempyor import ( + config, + model_info, + outcomes, + seir, + inference_parameter, + logloss, + inference, +) from gempyor.inference import GempyorInference import tqdm import os @@ -37,7 +45,9 @@ def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin): last_llik = sampler_output.get_log_prob()[-1, :] sampled_slots = last_llik > (last_llik.mean() - 1 * last_llik.std()) - print(f"there are {sampled_slots.sum()}/{len(sampled_slots)} good walkers... keeping these") + print( + f"there are {sampled_slots.sum()}/{len(sampled_slots)} good walkers... 
keeping these" + ) # TODO this function give back good_samples = sampler.get_chain()[:, sampled_slots, :] @@ -46,13 +56,15 @@ def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin): exported_samples = np.empty((nsamples, inferpar.get_dim())) for i in range(nsamples): exported_samples[i, :] = good_samples[ - step_number - thin * (i // (sampled_slots.sum())), i % (sampled_slots.sum()), : + step_number - thin * (i // (sampled_slots.sum())), + i % (sampled_slots.sum()), + :, ] # parentesis around i//(sampled_slots.sum() are very important - - -def plot_chains(inferpar, chains, llik, save_to, sampled_slots=None, param_gt=None, llik_gt=None): +def plot_chains( + inferpar, chains, llik, save_to, sampled_slots=None, param_gt=None, llik_gt=None +): """ Plot the chains of the inference :param inferpar: the inference parameter object @@ -113,24 +125,47 @@ def plot_single_chain(frompt, ax, chain, label, gt=None): for sp in tqdm.tqdm(set(inferpar.subpops)): # find unique supopulation these_pars = inferpar.get_parameters_for_subpop(sp) - fig, axes = plt.subplots(max(len(these_pars), 2), 2, figsize=(6, (len(these_pars) + 1) * 2)) + fig, axes = plt.subplots( + max(len(these_pars), 2), 2, figsize=(6, (len(these_pars) + 1) * 2) + ) for idx, par_id in enumerate(these_pars): - plot_single_chain(first_thresh, axes[idx, 0], chains[:, :, par_id], labels[par_id], gt=param_gt[par_id] if param_gt is not None else None) - plot_single_chain(second_thresh, axes[idx, 1], chains[:, :, par_id], labels[par_id], gt=param_gt[par_id] if param_gt is not None else None) + plot_single_chain( + first_thresh, + axes[idx, 0], + chains[:, :, par_id], + labels[par_id], + gt=param_gt[par_id] if param_gt is not None else None, + ) + plot_single_chain( + second_thresh, + axes[idx, 1], + chains[:, :, par_id], + labels[par_id], + gt=param_gt[par_id] if param_gt is not None else None, + ) fig.tight_layout() pdf.savefig(fig) plt.close(fig) + def plot_fit(modinf, loss): subpop_names = modinf.subpop_struct.subpop_names fig, axes = plt.subplots( - len(subpop_names), len(loss.statistics), figsize=(3 * len(loss.statistics), 3 * len(subpop_names)), sharex=True + len(subpop_names), + len(loss.statistics), + figsize=(3 * len(loss.statistics), 3 * len(subpop_names)), + sharex=True, ) for j, subpop in enumerate(modinf.subpop_struct.subpop_names): gt_s = loss.gt[loss.gt["subpop"] == subpop].sort_index() first_date = max(gt_s.index.min(), results[0].index.min()) last_date = min(gt_s.index.max(), results[0].index.max()) - gt_s = gt_s.loc[first_date:last_date].drop(["subpop"], axis=1).resample("W-SAT").sum() + gt_s = ( + gt_s.loc[first_date:last_date] + .drop(["subpop"], axis=1) + .resample("W-SAT") + .sum() + ) for i, (stat_name, stat) in enumerate(loss.statistics.items()): ax = axes[j, i] diff --git a/flepimop/gempyor_pkg/src/gempyor/seeding.py b/flepimop/gempyor_pkg/src/gempyor/seeding.py index fe58657c0..f49114e5e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seeding.py +++ b/flepimop/gempyor_pkg/src/gempyor/seeding.py @@ -17,9 +17,13 @@ def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict: if not df["date"].is_monotonic_increasing: - raise ValueError("_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense") + raise ValueError( + "_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense" + ) - cmp_grp_names = [col for col in modinf.compartments.compartments.columns if col != "name"] + cmp_grp_names = [ + col for col in modinf.compartments.compartments.columns if col != 
"name" + ] seeding_dict: nb.typed.Dict = nb.typed.Dict.empty( key_type=nb.types.unicode_type, value_type=nb.types.int64[:], @@ -45,16 +49,26 @@ def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict: nb_seed_perday[(row["date"].date() - modinf.ti).days] = ( nb_seed_perday[(row["date"].date() - modinf.ti).days] + 1 ) - source_dict = {grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names} - destination_dict = {grp_name: row[f"destination_{grp_name}"] for grp_name in cmp_grp_names} + source_dict = { + grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names + } + destination_dict = { + grp_name: row[f"destination_{grp_name}"] + for grp_name in cmp_grp_names + } seeding_dict["seeding_sources"][idx] = modinf.compartments.get_comp_idx( - source_dict, error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)" + source_dict, + error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)", + ) + seeding_dict["seeding_destinations"][idx] = ( + modinf.compartments.get_comp_idx( + destination_dict, + error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)", + ) ) - seeding_dict["seeding_destinations"][idx] = modinf.compartments.get_comp_idx( - destination_dict, - error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)", + seeding_dict["seeding_subpops"][idx] = ( + modinf.subpop_struct.subpop_names.index(row["subpop"]) ) - seeding_dict["seeding_subpops"][idx] = modinf.subpop_struct.subpop_names.index(row["subpop"]) seeding_amounts[idx] = amounts[idx] # id_seed+=1 else: @@ -97,7 +111,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: ) dupes = seeding[seeding.duplicated(["subpop", "date"])].index + 1 if not dupes.empty: - raise ValueError(f"Repeated subpop-date in rows {dupes.tolist()} of seeding::lambda_file.") + raise ValueError( + f"Repeated subpop-date in rows {dupes.tolist()} of seeding::lambda_file." + ) elif method == "FolderDraw": seeding = pd.read_csv( self.path_prefix @@ -127,7 +143,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: # print(seeding.shape) seeding = seeding.sort_values(by="date", axis="index").reset_index() # print(seeding) - mask = (seeding["date"].dt.date > modinf.ti) & (seeding["date"].dt.date <= modinf.tf) + mask = (seeding["date"].dt.date > modinf.ti) & ( + seeding["date"].dt.date <= modinf.tf + ) seeding = seeding.loc[mask].reset_index() # print(seeding.shape) # print(seeding) @@ -138,7 +156,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: if method == "PoissonDistributed": amounts = np.random.poisson(seeding["amount"]) elif method == "NegativeBinomialDistributed": - raise ValueError("Seeding method 'NegativeBinomialDistributed' is not supported by flepiMoP anymore.") + raise ValueError( + "Seeding method 'NegativeBinomialDistributed' is not supported by flepiMoP anymore." 
+ ) elif method == "FolderDraw" or method == "FromFile": amounts = seeding["amount"] else: diff --git a/flepimop/gempyor_pkg/src/gempyor/seir.py b/flepimop/gempyor_pkg/src/gempyor/seir.py index 4e59761f2..d374bed25 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/seir.py @@ -41,7 +41,9 @@ def build_step_source_arg( else: integration_method = "rk4.jit" dt = 2.0 - logging.info(f"Integration method not provided, assuming type {integration_method} with dt=2") + logging.info( + f"Integration method not provided, assuming type {integration_method} with dt=2" + ) ## The type is very important for the call to the compiled function, and e.g mixing an int64 for an int32 can ## result in serious error. Note that "In Microsoft C, even on a 64 bit system, the size of the long int data type @@ -58,7 +60,10 @@ def build_step_source_arg( assert type(transition_array[0][0]) == np.int64 assert type(proportion_array[0]) == np.int64 assert type(proportion_info[0][0]) == np.int64 - assert initial_conditions.shape == (modinf.compartments.compartments.shape[0], modinf.nsubpops) + assert initial_conditions.shape == ( + modinf.compartments.compartments.shape[0], + modinf.nsubpops, + ) assert type(initial_conditions[0][0]) == np.float64 # Test of empty seeding: assert len(seeding_data.keys()) == 4 @@ -150,7 +155,9 @@ def steps_SEIR( else: from .dev import steps as steps_experimental - logging.critical("Experimental !!! These methods are not ready for production ! ") + logging.critical( + "Experimental !!! These methods are not ready for production ! " + ) if integration_method in [ "scipy.solve_ivp", "scipy.odeint", @@ -162,7 +169,9 @@ def steps_SEIR( f"with method {integration_method}, only deterministic " f"integration is possible (got stoch_straj_flag={modinf.stoch_traj_flag}" ) - seir_sim = steps_experimental.ode_integration(**fnct_args, integration_method=integration_method) + seir_sim = steps_experimental.ode_integration( + **fnct_args, integration_method=integration_method + ) elif integration_method == "rk4.jit1": seir_sim = steps_experimental.rk4_integration1(**fnct_args) elif integration_method == "rk4.jit2": @@ -200,13 +209,17 @@ def steps_SEIR( **compartment_coords, subpop=modinf.subpop_struct.subpop_names, ), - attrs=dict(description="Dynamical simulation results", run_id=modinf.in_run_id), # TODO add more information + attrs=dict( + description="Dynamical simulation results", run_id=modinf.in_run_id + ), # TODO add more information ) return states -def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None): +def build_npi_SEIR( + modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None +): with Timer("SEIR.NPI"): loaded_df = None if bypass_DF is not None: @@ -223,8 +236,12 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_ modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, loaded_df=loaded_df, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) else: npi = NPI.NPIBase.execute( @@ -232,8 +249,12 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_ modinf=modinf, 
modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) return npi @@ -248,7 +269,9 @@ def onerun_SEIR( np.random.seed() npi = None if modinf.npi_config_seir: - npi = build_npi_SEIR(modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi = build_npi_SEIR( + modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) with Timer("onerun_SEIR.compartments"): ( @@ -260,11 +283,19 @@ def onerun_SEIR( with Timer("onerun_SEIR.seeding"): if load_ID: - initial_conditions = modinf.initial_conditions.get_from_file(sim_id2load, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id2load, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_file( + sim_id2load, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id2load, modinf=modinf + ) else: - initial_conditions = modinf.initial_conditions.get_from_config(sim_id2write, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id2write, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id2write, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_config( + sim_id2write, modinf=modinf + ) with Timer("onerun_SEIR.parameters"): # Draw or load parameters @@ -275,14 +306,18 @@ def onerun_SEIR( nsubpops=modinf.nsubpops, ) else: - p_draw = modinf.parameters.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_draw = modinf.parameters.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi) log_debug_parameters(p_draw, "Parameters without seir_modifiers") log_debug_parameters(parameters, "Parameters with seir_modifiers") # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) log_debug_parameters(parsed_parameters, "Unique Parameters used by transitions") with Timer("onerun_SEIR.compute"): @@ -310,7 +345,13 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1): if n_jobs == 1: # run single process for debugging/profiling purposes for sim_id in tqdm.tqdm(sim_ids): - onerun_SEIR(sim_id2write=sim_id, modinf=modinf, load_ID=False, sim_id2load=None, config=config) + onerun_SEIR( + sim_id2write=sim_id, + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + ) else: tqdm.contrib.concurrent.process_map( onerun_SEIR, @@ -322,7 +363,9 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1): max_workers=n_jobs, ) - logging.info(f""">> {modinf.nslots} seir simulations completed in {time.monotonic() - start:.1f} seconds""") + logging.info( + f""">> {modinf.nslots} seir simulations completed in {time.monotonic() - start:.1f} seconds""" + ) def states2Df(modinf, states): @@ -337,12 +380,17 @@ def states2Df(modinf, states): # states_diff = np.diff(states_diff, axis=0) ts_index = pd.MultiIndex.from_product( - [pd.date_range(modinf.ti, 
modinf.tf, freq="D"), modinf.compartments.compartments["name"]], + [ + pd.date_range(modinf.ti, modinf.tf, freq="D"), + modinf.compartments.compartments["name"], + ], names=["date", "mc_name"], ) # prevalence data, we use multi.index dataframe, sparring us the array manipulation we use to do prev_df = pd.DataFrame( - data=states["prevalence"].to_numpy().reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), + data=states["prevalence"] + .to_numpy() + .reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), index=ts_index, columns=modinf.subpop_struct.subpop_names, ).reset_index() @@ -355,12 +403,17 @@ def states2Df(modinf, states): prev_df.insert(loc=0, column="mc_value_type", value="prevalence") ts_index = pd.MultiIndex.from_product( - [pd.date_range(modinf.ti, modinf.tf, freq="D"), modinf.compartments.compartments["name"]], + [ + pd.date_range(modinf.ti, modinf.tf, freq="D"), + modinf.compartments.compartments["name"], + ], names=["date", "mc_name"], ) incid_df = pd.DataFrame( - data=states["incidence"].to_numpy().reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), + data=states["incidence"] + .to_numpy() + .reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), index=ts_index, columns=modinf.subpop_struct.subpop_names, ).reset_index() @@ -384,7 +437,9 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi): if npi is not None: modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF()) # Parameters - modinf.write_simID(ftype="spar", sim_id=sim_id, df=modinf.parameters.getParameterDF(p_draw=p_draw)) + modinf.write_simID( + ftype="spar", sim_id=sim_id, df=modinf.parameters.getParameterDF(p_draw=p_draw) + ) def write_seir(sim_id, modinf, states): diff --git a/flepimop/gempyor_pkg/src/gempyor/simulate.py b/flepimop/gempyor_pkg/src/gempyor/simulate.py index 34fdd9d4b..d97b94a93 100644 --- a/flepimop/gempyor_pkg/src/gempyor/simulate.py +++ b/flepimop/gempyor_pkg/src/gempyor/simulate.py @@ -299,23 +299,31 @@ def simulate( seir_modifiers_scenarios = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - seir_modifiers_scenarios = config["seir_modifiers"]["scenarios"].as_str_seq() + seir_modifiers_scenarios = config["seir_modifiers"][ + "scenarios" + ].as_str_seq() # Model Info handles the case of the default scneario if not outcome_modifiers_scenarios: outcome_modifiers_scenarios = None if config["outcomes"].exists() and config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"].as_str_seq() + outcome_modifiers_scenarios = config["outcome_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = as_list(outcome_modifiers_scenarios) seir_modifiers_scenarios = as_list(seir_modifiers_scenarios) print(outcome_modifiers_scenarios, seir_modifiers_scenarios) - scenarios_combinations = [[s, d] for s in seir_modifiers_scenarios for d in outcome_modifiers_scenarios] + scenarios_combinations = [ + [s, d] for s in seir_modifiers_scenarios for d in outcome_modifiers_scenarios + ] print("Combination of modifiers scenarios to be run: ") print(scenarios_combinations) for seir_modifiers_scenario, outcome_modifiers_scenario in scenarios_combinations: - print(f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier:{outcome_modifiers_scenario}") + print( + f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier:{outcome_modifiers_scenario}" + ) if not 
nslots: nslots = config["nslots"].as_number() @@ -354,7 +362,9 @@ def simulate( if config["seir"].exists(): seir.run_parallel_SEIR(modinf, config=config, n_jobs=jobs) if config["outcomes"].exists(): - outcomes.run_parallel_outcomes(sim_id2write=first_sim_index, modinf=modinf, nslots=nslots, n_jobs=jobs) + outcomes.run_parallel_outcomes( + sim_id2write=first_sim_index, modinf=modinf, nslots=nslots, n_jobs=jobs + ) print( f">>> {seir_modifiers_scenario}_{outcome_modifiers_scenario} completed in {time.monotonic() - start:.1f} seconds" ) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index ea2cc72a3..a9b8ce8a6 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -48,7 +48,10 @@ def __init__(self, name, statistic_config: confuse.ConfigView): if resample_config["aggregator"].exists(): self.resample_aggregator_name = resample_config["aggregator"].get() self.resample_skipna = False # TODO - if resample_config["aggregator"].exists() and resample_config["skipna"].exists(): + if ( + resample_config["aggregator"].exists() + and resample_config["skipna"].exists() + ): self.resample_skipna = resample_config["skipna"].get() self.scale = False @@ -71,7 +74,10 @@ def _forecast_regularize(self, model_data, gt_data, **kwargs): last_n = kwargs.get("last_n", 4) mult = kwargs.get("mult", 2) - last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None))) + last_n_llik = self.llik( + model_data.isel(date=slice(-last_n, None)), + gt_data.isel(date=slice(-last_n, None)), + ) return mult * last_n_llik.sum().sum().values @@ -89,7 +95,9 @@ def __repr__(self) -> str: def apply_resample(self, data): if self.resample: - aggregator_method = getattr(data.resample(date=self.resample_freq), self.resample_aggregator_name) + aggregator_method = getattr( + data.resample(date=self.resample_freq), self.resample_aggregator_name + ) return aggregator_method(skipna=self.resample_skipna) else: return data @@ -113,7 +121,9 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): "norm_cov": lambda x, loc, scale: scipy.stats.norm.logpdf( x, loc=loc, scale=scale * loc.where(loc > 5, 5) ), # TODO: check, that it's really the loc - "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf(x, n=self.params.get("n"), p=model_data), + "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf( + x, n=self.params.get("n"), p=model_data + ), "rmse": lambda x, y: -np.log(np.nansum(np.sqrt((x - y) ** 2))), "absolute_error": lambda x, y: -np.log(np.nansum(np.abs(x - y))), } @@ -147,6 +157,8 @@ def compute_logloss(self, model_data, gt_data): regularization = 0 for reg_func, reg_config in self.regularizations: - regularization += reg_func(model_data=model_data, gt_data=gt_data, **reg_config) # Pass config parameters + regularization += reg_func( + model_data=model_data, gt_data=gt_data, **reg_config + ) # Pass config parameters return self.llik(model_data, gt_data).sum("date"), regularization diff --git a/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py b/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py index 2cb09c99d..74813be09 100644 --- a/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py +++ b/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py @@ -55,7 +55,11 @@ def rk4_integration( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): proportion_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + 
mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -63,11 +67,19 @@ def rk4_integration( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - st_next = states_current.copy() # this is used to make sure stochastic integration never goes below zero - transition_amounts = np.zeros((ntransitions, nspatial_nodes)) # keep track of the transitions + st_next = ( + states_current.copy() + ) # this is used to make sure stochastic integration never goes below zero + transition_amounts = np.zeros( + (ntransitions, nspatial_nodes) + ) # keep track of the transitions if (x < 0).any(): - print("Integration error: rhs got a negative x (pos, time)", np.where(x < 0), t) + print( + "Integration error: rhs got a negative x (pos, time)", + np.where(x < 0), + t, + ) for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -85,56 +97,76 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] # chadi: i believe what this mean that the first proportion is always the # source compartment. That's why there is nothing with n_spatial node here. # but (TODO) we should enforce that ? if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp # does this mean we need the first to be "source" ??? yes ! 
if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * proportion_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * proportion_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_subpop = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment * ( - relevant_number_in_comp[visiting_subpop] ** relevant_exponent[visiting_subpop] + relevant_number_in_comp[visiting_subpop] + ** relevant_exponent[visiting_subpop] ) rate_change_compartment /= population[visiting_subpop] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_subpop] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_subpop] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compute the number of individual transitioning from source to destination from the total rate # number_move has shape (nspatial_nodes) @@ -143,7 +175,9 @@ def rhs(t, x, today): elif method == "legacy": compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) if stochastic_p: - number_move = source_number * compound_adjusted_rate ## to initialize typ + number_move = ( + source_number * compound_adjusted_rate + ) ## to initialize typ for spatial_node in range(nspatial_nodes): number_move[spatial_node] = np.random.binomial( # number_move[spatial_node] = random.binomial( @@ -162,7 +196,9 @@ def rhs(t, x, today): @jit(nopython=True) def update_states(states, delta_t, transition_amounts): - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum st_next = states.copy() st_next = np.reshape(st_next, (2, ncompartments, nspatial_nodes)) if method == "rk4": @@ -181,23 +217,34 @@ def update_states(states, delta_t, transition_amounts): ) if ( transition_amounts[transition_index][spatial_node] - >= st_next[0][transitions[transition_source_col][transition_index]][spatial_node] - float_tolerance + >= 
st_next[0][transitions[transition_source_col][transition_index]][ + spatial_node + ] + - float_tolerance ): transition_amounts[transition_index][spatial_node] = max( - st_next[0][transitions[transition_source_col][transition_index]][spatial_node] + st_next[0][ + transitions[transition_source_col][transition_index] + ][spatial_node] - float_tolerance, 0, ) - st_next[0][transitions[transition_source_col][transition_index]] -= transition_amounts[transition_index] - st_next[0][transitions[transition_destination_col][transition_index]] += transition_amounts[ - transition_index - ] - - states_diff[0, transitions[transition_source_col][transition_index]] -= transition_amounts[transition_index] - states_diff[0, transitions[transition_destination_col][transition_index]] += transition_amounts[ - transition_index - ] - states_diff[1, transitions[transition_destination_col][transition_index], :] += transition_amounts[ + st_next[0][ + transitions[transition_source_col][transition_index] + ] -= transition_amounts[transition_index] + st_next[0][ + transitions[transition_destination_col][transition_index] + ] += transition_amounts[transition_index] + + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= transition_amounts[transition_index] + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += transition_amounts[transition_index] + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += transition_amounts[ transition_index ] # Cumumlative @@ -214,7 +261,9 @@ def rk4_integrate(t, x, today): yesterday = -1 times = np.arange(0, (ndays - 1) + 1e-7, dt) - for time_index, time in tqdm.tqdm(enumerate(times), disable=silent): # , total=len(times) + for time_index, time in tqdm.tqdm( + enumerate(times), disable=silent + ): # , total=len(times) today = int(np.floor(time)) is_a_new_day = today != yesterday yesterday = today @@ -224,21 +273,31 @@ def rk4_integrate(t, x, today): states[today, :, :] = states_next for seeding_instance_idx in range( seeding_data["day_start_idx"][today], - seeding_data["day_start_idx"][min(today + int(np.ceil(dt)), len(seeding_data["day_start_idx"]) - 1)], + seeding_data["day_start_idx"][ + min( + today + int(np.ceil(dt)), len(seeding_data["day_start_idx"]) - 1 + ) + ], ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts x_ = np.zeros((2, ncompartments, nspatial_nodes)) x_[0] = states_next @@ -271,13 +330,19 @@ def rk4_integrate(t, x, today): error = False ## Perform some checks: if np.isnan(states_daily_incid).any() or np.isnan(states).any(): - logging.critical("Integration error: NaN detected in epidemic integration result. Failing...") + logging.critical( + "Integration error: NaN detected in epidemic integration result. Failing..." + ) error = True if not (np.isfinite(states_daily_incid).all() and np.isfinite(states).all()): - logging.critical("Integration error: Inf detected in epidemic integration result. Failing...") + logging.critical( + "Integration error: Inf detected in epidemic integration result. Failing..." + ) error = True if (states_daily_incid < 0).any() or (states < 0).any(): - logging.critical("Integration error: negative values detected in epidemic integration result. Failing...") + logging.critical( + "Integration error: negative values detected in epidemic integration result. Failing..." 
+ ) # todo: this, but smart so it doesn't fail if empty array # print( # f"STATES: NNZ:{states[states < 0].size}/{states.size}, max:{np.max(states[states < 0])}, min:{np.min(states[states < 0])}, mean:{np.mean(states[states < 0])} median:{np.median(states[states < 0])}" @@ -318,6 +383,8 @@ def rk4_integrate(t, x, today): print( "load the name space with: \nwith open('integration_dump.pkl','rb') as fn_dump:\n states, states_daily_incid, ncompartments, nspatial_nodes, ndays, parameters, dt, transitions, proportion_info, transition_sum_compartments, initial_conditions, seeding_data, seeding_amounts, mobility_data, mobility_row_indices, mobility_data_indices, population, stochastic_p, method = pickle.load(fn_dump)" ) - print("/!\\ Invalid integration, will cause problems for downstream users /!\\ ") + print( + "/!\\ Invalid integration, will cause problems for downstream users /!\\ " + ) # raise ValueError("Invalid Integration...") return states, states_daily_incid diff --git a/flepimop/gempyor_pkg/src/gempyor/steps_source.py b/flepimop/gempyor_pkg/src/gempyor/steps_source.py index b8af1d493..ba46badc7 100644 --- a/flepimop/gempyor_pkg/src/gempyor/steps_source.py +++ b/flepimop/gempyor_pkg/src/gempyor/steps_source.py @@ -30,9 +30,12 @@ ## Return "UniTuple(float64[:, :, :], 2) (" ## return states and cumlative states, both [ ndays x ncompartments x nspatial_nodes ] ## Dimensions - "int32," "int32," "int32," ## ncompartments ## nspatial_nodes ## Number of days + "int32," + "int32," + "int32," ## ncompartments ## nspatial_nodes ## Number of days ## Parameters - "float64[:, :, :]," "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt + "float64[:, :, :]," + "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt ## Transitions "int64[:, :]," ## transitions [ [source, destination, proportion_start, proportion_stop, rate] x ntransitions ] "int64[:, :]," ## proportions_info [ [sum_starts, sum_stops, exponent] x ntransition_proportions ] @@ -84,7 +87,11 @@ def steps_SEIR_nb( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -111,18 +118,24 @@ def steps_SEIR_nb( this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_infected = 0 for transition_index in range(ntransitions): @@ -138,52 +151,72 @@ def steps_SEIR_nb( proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = 
proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -221,14 +254,22 @@ def steps_SEIR_nb( for spatial_node in range(nspatial_nodes): if ( number_move[spatial_node] - > states_next[transitions[transition_source_col][transition_index]][spatial_node] - ): - number_move[spatial_node] = states_next[transitions[transition_source_col][transition_index]][ + > states_next[transitions[transition_source_col][transition_index]][ spatial_node ] - states_next[transitions[transition_source_col][transition_index]] -= number_move - states_next[transitions[transition_destination_col][transition_index]] += number_move - states_daily_incid[today, transitions[transition_destination_col][transition_index], :] += number_move + ): + number_move[spatial_node] = states_next[ + transitions[transition_source_col][transition_index] + ][spatial_node] + states_next[ + transitions[transition_source_col][transition_index] + ] -= number_move + states_next[ + transitions[transition_destination_col][transition_index] + ] += number_move + states_daily_incid[ + today, transitions[transition_destination_col][transition_index], : + ] += number_move states_current = states_next.copy() diff --git a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py index 1e2b9b8de..58192bb9e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py +++ b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py @@ -27,7 +27,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): self.setup_name = setup_name self.data = pd.read_csv( - geodata_file, converters={subpop_names_key: lambda x: str(x).strip()}, skipinitialspace=True + geodata_file, + converters={subpop_names_key: lambda x: str(x).strip()}, + skipinitialspace=True, ) # subpops and populations, strip whitespaces self.nsubpops = len(self.data) # K = # of locations @@ -44,7 +46,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): # subpop_names_key is the name of the column in geodata_file with subpops if subpop_names_key not in self.data: - raise ValueError(f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata.") + raise ValueError( + f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata." 
+ ) self.subpop_names = self.data[subpop_names_key].tolist() if len(self.subpop_names) != len(set(self.subpop_names)): raise ValueError(f"There are duplicate subpop_names in geodata.") @@ -53,7 +57,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): mobility_file = path_prefix / subpop_config["mobility"].get() mobility_file = pathlib.Path(mobility_file) if mobility_file.suffix == ".txt": - print("Mobility files as matrices are not recommended. Please switch soon to long form csv files.") + print( + "Mobility files as matrices are not recommended. Please switch soon to long form csv files." + ) self.mobility = scipy.sparse.csr_matrix( np.loadtxt(mobility_file), dtype=int ) # K x K matrix of people moving @@ -64,17 +70,28 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): ) elif mobility_file.suffix == ".csv": - mobility_data = pd.read_csv(mobility_file, converters={"ori": str, "dest": str}, skipinitialspace=True) + mobility_data = pd.read_csv( + mobility_file, + converters={"ori": str, "dest": str}, + skipinitialspace=True, + ) nn_dict = {v: k for k, v in enumerate(self.subpop_names)} - mobility_data["ori_idx"] = mobility_data["ori"].apply(nn_dict.__getitem__) - mobility_data["dest_idx"] = mobility_data["dest"].apply(nn_dict.__getitem__) + mobility_data["ori_idx"] = mobility_data["ori"].apply( + nn_dict.__getitem__ + ) + mobility_data["dest_idx"] = mobility_data["dest"].apply( + nn_dict.__getitem__ + ) if any(mobility_data["ori_idx"] == mobility_data["dest_idx"]): raise ValueError( f"Mobility fluxes with same origin and destination in long form matrix. This is not supported" ) self.mobility = scipy.sparse.coo_matrix( - (mobility_data.amount, (mobility_data.ori_idx, mobility_data.dest_idx)), + ( + mobility_data.amount, + (mobility_data.ori_idx, mobility_data.dest_idx), + ), shape=(self.nsubpops, self.nsubpops), dtype=int, ).tocsr() @@ -115,7 +132,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): ) else: logging.critical("No mobility matrix specified -- assuming no one moves") - self.mobility = scipy.sparse.csr_matrix(np.zeros((self.nsubpops, self.nsubpops)), dtype=int) + self.mobility = scipy.sparse.csr_matrix( + np.zeros((self.nsubpops, self.nsubpops)), dtype=int + ) if subpop_config["selected"].exists(): selected = subpop_config["selected"].get() @@ -129,4 +148,6 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): self.subpop_names = selected self.nsubpops = len(self.data) # TODO: this needs to be tested - self.mobility = self.mobility[selected_subpop_indices][:, selected_subpop_indices] + self.mobility = self.mobility[selected_subpop_indices][ + :, selected_subpop_indices + ] diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 6131b5c52..873f105d0 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -103,7 +103,9 @@ def command_safe_run(command, command_name="mycommand", fail_on_fail=True): import subprocess import shlex # using shlex to split the command because it's not obvious https://docs.python.org/3/library/subprocess.html#subprocess.Popen - sr = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + sr = subprocess.Popen( + shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) (stdout, stderr) = sr.communicate() if sr.returncode != 0: print(f"{command_name} failed failed with returncode 
{sr.returncode}") @@ -137,7 +139,9 @@ def wrapper(*args, **kwargs): return decorator -def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, class_name: str, **kwargs): +def search_and_import_plugins_class( + plugin_file_path: str, path_prefix: str, class_name: str, **kwargs +): # Look for all possible plugins and import them # https://stackoverflow.com/questions/67631/how-can-i-import-a-module-dynamically-given-the-full-path # unfortunatelly very complicated, this is cpython only ?? @@ -159,7 +163,9 @@ def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, cla from functools import wraps -def profile(output_file=None, sort_by="cumulative", lines_to_print=None, strip_dirs=False): +def profile( + output_file=None, sort_by="cumulative", lines_to_print=None, strip_dirs=False +): """A time profiler decorator. Inspired by and modified the profile decorator of Giampaolo Rodola: http://code.activestate.com/recipes/577817-profile-decorator/ @@ -307,14 +313,20 @@ def as_random_distribution(self): dist = self["distribution"].get() if dist == "fixed": return functools.partial( - np.random.uniform, self["value"].as_evaled_expression(), self["value"].as_evaled_expression(), + np.random.uniform, + self["value"].as_evaled_expression(), + self["value"].as_evaled_expression(), ) elif dist == "uniform": return functools.partial( - np.random.uniform, self["low"].as_evaled_expression(), self["high"].as_evaled_expression(), + np.random.uniform, + self["low"].as_evaled_expression(), + self["high"].as_evaled_expression(), ) elif dist == "poisson": - return functools.partial(np.random.poisson, self["lam"].as_evaled_expression()) + return functools.partial( + np.random.poisson, self["lam"].as_evaled_expression() + ) elif dist == "binomial": p = self["p"].as_evaled_expression() if (p < 0) or (p > 1): @@ -336,13 +348,18 @@ def as_random_distribution(self): ).rvs elif dist == "lognorm": return get_log_normal( - meanlog=self["meanlog"].as_evaled_expression(), sdlog=self["sdlog"].as_evaled_expression(), + meanlog=self["meanlog"].as_evaled_expression(), + sdlog=self["sdlog"].as_evaled_expression(), ).rvs else: raise NotImplementedError(f"unknown distribution [got: {dist}]") else: # we allow a fixed value specified directly: - return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) + return functools.partial( + np.random.uniform, + self.as_evaled_expression(), + self.as_evaled_expression(), + ) def list_filenames( @@ -431,14 +448,14 @@ def rolling_mean_pad( [20.2, 21.2, 22.2, 23.2], [22.6, 23.6, 24.6, 25.6]]) """ - weights = (1. / window) * np.ones(window) + weights = (1.0 / window) * np.ones(window) output = scipy.ndimage.convolve1d(data, weights, axis=0, mode="nearest") if window % 2 == 0: rows, cols = data.shape i = rows - 1 - output[i, :] = 0. + output[i, :] = 0.0 window -= 1 - weight = 1. 
/ window + weight = 1.0 / window for l in range(-((window - 1) // 2), 1 + (window // 2)): i_star = min(max(i + l, 0), i) for j in range(cols): @@ -472,7 +489,12 @@ def bash(command): def create_resume_out_filename( - flepi_run_index: str, flepi_prefix: str, flepi_slot_index: str, flepi_block_index: str, filetype: str, liketype: str + flepi_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + flepi_block_index: str, + filetype: str, + liketype: str, ) -> str: prefix = f"{flepi_prefix}/{flepi_run_index}" inference_filepath_suffix = f"{liketype}/intermediate" @@ -493,7 +515,11 @@ def create_resume_out_filename( def create_resume_input_filename( - resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str + resume_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + filetype: str, + liketype: str, ) -> str: prefix = f"{flepi_prefix}/{resume_run_index}" inference_filepath_suffix = f"{liketype}/final" @@ -511,7 +537,9 @@ def create_resume_input_filename( ) -def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) -> List[str]: +def get_filetype_for_resume( + resume_discard_seeding: str, flepi_block_index: str +) -> List[str]: """ Retrieves a list of parquet file types that are relevant for resuming a process based on specific environment variable settings. This function dynamically determines the list @@ -564,7 +592,8 @@ def create_resume_file_names_map( behavior. """ file_types = get_filetype_for_resume( - resume_discard_seeding=resume_discard_seeding, flepi_block_index=flepi_block_index + resume_discard_seeding=resume_discard_seeding, + flepi_block_index=flepi_block_index, ) resume_file_name_mapping = dict() liketypes = ["global", "chimeric"] @@ -638,10 +667,12 @@ def download_file_from_s3(name_map: Dict[str, str]) -> None: import boto3 from botocore.exceptions import ClientError except ModuleNotFoundError: - raise ModuleNotFoundError(( - "No module named 'boto3', which is required for " - "gempyor.utils.download_file_from_s3. Please install the aws target." - )) + raise ModuleNotFoundError( + ( + "No module named 'boto3', which is required for " + "gempyor.utils.download_file_from_s3. Please install the aws target." + ) + ) s3 = boto3.client("s3") first_output_filename = next(iter(name_map.values())) output_dir = os.path.dirname(first_output_filename) @@ -664,13 +695,13 @@ def move_file_at_local(name_map: Dict[str, str]) -> None: """ Moves files locally according to a given mapping. - This function takes a dictionary where the keys are source file paths and - the values are destination file paths. It ensures that the destination - directories exist and then copies the files from the source paths to the + This function takes a dictionary where the keys are source file paths and + the values are destination file paths. It ensures that the destination + directories exist and then copies the files from the source paths to the destination paths. Parameters: - name_map (Dict[str, str]): A dictionary mapping source file paths to + name_map (Dict[str, str]): A dictionary mapping source file paths to destination file paths. 
Returns: diff --git a/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py b/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py index 7e3c2bc59..ecf2a84e9 100644 --- a/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py @@ -49,7 +49,10 @@ def test_SinglePeriodModifier_start_date_fail(self): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_test.yml") - with pytest.raises(ValueError, match=r".*at least one period start or end date is not between.*"): + with pytest.raises( + ValueError, + match=r".*at least one period start or end date is not between.*", + ): s = model_info.ModelInfo( setup_name="test_seir", config=config, @@ -72,7 +75,10 @@ def test_SinglePeriodModifier_end_date_fail(self): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_test.yml") - with pytest.raises(ValueError, match=r".*at least one period start or end date is not between.*"): + with pytest.raises( + ValueError, + match=r".*at least one period start or end date is not between.*", + ): s = model_info.ModelInfo( setup_name="test_seir", config=config, diff --git a/flepimop/gempyor_pkg/tests/npi/test_npis.py b/flepimop/gempyor_pkg/tests/npi/test_npis.py index bad306b1e..8a3e5bb2f 100644 --- a/flepimop/gempyor_pkg/tests/npi/test_npis.py +++ b/flepimop/gempyor_pkg/tests/npi/test_npis.py @@ -47,12 +47,18 @@ def test_full_npis_read_write(): # inference_simulator.s, load_ID=False, sim_id2load=None, config=config # ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet") + ) import random random.seed(10) @@ -74,10 +80,16 @@ def test_full_npis_read_write(): npi_outcomes = outcomes.build_outcome_modifiers( inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -98,10 +110,16 @@ def test_full_npis_read_write(): npi_outcomes = outcomes.build_outcome_modifiers( inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + 
inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() @@ -116,10 +134,15 @@ def test_spatial_groups(): ) # Test build from config, value of the reduction array - npi = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config) + npi = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config + ) # all independent: r1 - assert len(npi.getReduction("r1")["2021-01-01"].unique()) == inference_simulator.modinf.nsubpops + assert ( + len(npi.getReduction("r1")["2021-01-01"].unique()) + == inference_simulator.modinf.nsubpops + ) assert npi.getReduction("r1").isna().sum().sum() == 0 # all the same: r2 @@ -127,26 +150,41 @@ def test_spatial_groups(): assert npi.getReduction("r2").isna().sum().sum() == 0 # two groups: r3 - assert len(npi.getReduction("r3")["2020-04-15"].unique()) == inference_simulator.modinf.nsubpops - 2 + assert ( + len(npi.getReduction("r3")["2020-04-15"].unique()) + == inference_simulator.modinf.nsubpops - 2 + ) assert npi.getReduction("r3").isna().sum().sum() == 0 - assert len(npi.getReduction("r3").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 - assert len(npi.getReduction("r3").loc[["04000", "06000"], "2020-04-15"].unique()) == 1 + assert ( + len(npi.getReduction("r3").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 + ) + assert ( + len(npi.getReduction("r3").loc[["04000", "06000"], "2020-04-15"].unique()) == 1 + ) # one group: r4 assert ( len(npi.getReduction("r4")["2020-04-15"].unique()) == 4 ) # 0 for these not included, 1 unique for the group, and two for the rest assert npi.getReduction("r4").isna().sum().sum() == 0 - assert len(npi.getReduction("r4").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 - assert len(npi.getReduction("r4").loc[["04000", "06000"], "2020-04-15"].unique()) == 2 + assert ( + len(npi.getReduction("r4").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 + ) + assert ( + len(npi.getReduction("r4").loc[["04000", "06000"], "2020-04-15"].unique()) == 2 + ) assert (npi.getReduction("r4").loc[["05000", "08000"], "2020-04-15"] == 0).all() # mtr group: r5 assert npi.getReduction("r5").isna().sum().sum() == 0 assert len(npi.getReduction("r5")["2020-12-15"].unique()) == 2 assert len(npi.getReduction("r5")["2020-10-15"].unique()) == 4 - assert len(npi.getReduction("r5").loc[["01000", "04000"], "2020-10-15"].unique()) == 1 - assert len(npi.getReduction("r5").loc[["02000", "06000"], "2020-10-15"].unique()) == 2 + assert ( + len(npi.getReduction("r5").loc[["01000", "04000"], "2020-10-15"].unique()) == 1 + ) + assert ( + len(npi.getReduction("r5").loc[["02000", "06000"], "2020-10-15"].unique()) == 2 + ) # test the dataframes that are wrote. 
npi_df = npi.getReductionDF() @@ -160,7 +198,9 @@ def test_spatial_groups(): # all the same: r2 df = npi_df[npi_df["modifier_name"] == "all_together"] assert len(df) == 1 - assert set(df["subpop"].iloc[0].split(",")) == set(inference_simulator.modinf.subpop_struct.subpop_names) + assert set(df["subpop"].iloc[0].split(",")) == set( + inference_simulator.modinf.subpop_struct.subpop_names + ) assert len(df["subpop"].iloc[0].split(",")) == inference_simulator.modinf.nsubpops # two groups: r3 @@ -175,7 +215,10 @@ def test_spatial_groups(): df = npi_df[npi_df["modifier_name"] == "mt_reduce"] assert len(df) == 4 assert df.subpop.to_list() == ["09000,10000", "02000", "06000", "01000,04000"] - assert df[df["subpop"] == "09000,10000"]["start_date"].iloc[0] == "2020-12-01,2021-12-01" + assert ( + df[df["subpop"] == "09000,10000"]["start_date"].iloc[0] + == "2020-12-01,2021-12-01" + ) assert ( df[df["subpop"] == "01000,04000"]["start_date"].iloc[0] == df[df["subpop"] == "06000"]["start_date"].iloc[0] @@ -194,15 +237,21 @@ def test_spatial_groups(): ) # Test build from config, value of the reduction array - npi = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config) + npi = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config + ) npi_df = npi.getReductionDF() inference_simulator.modinf.write_simID(ftype="snpi", sim_id=1, df=npi_df) - snpi_read = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.105.snpi.parquet").to_pandas() + snpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.105.snpi.parquet" + ).to_pandas() snpi_read["value"] = np.random.random(len(snpi_read)) * 2 - 1 out_snpi = pa.Table.from_pandas(snpi_read, preserve_index=False) - pa.parquet.write_table(out_snpi, file_paths.create_file_name(106, "", 1, "snpi", "parquet")) + pa.parquet.write_table( + out_snpi, file_paths.create_file_name(106, "", 1, "snpi", "parquet") + ) inference_simulator = gempyor.GempyorInference( config_filepath=f"{config_filepath_prefix}config_test_spatial_group_npi.yml", @@ -213,22 +262,42 @@ def test_spatial_groups(): out_run_id=107, ) - npi_seir = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config) - inference_simulator.modinf.write_simID(ftype="snpi", sim_id=1, df=npi_seir.getReductionDF()) + npi_seir = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config + ) + inference_simulator.modinf.write_simID( + ftype="snpi", sim_id=1, df=npi_seir.getReductionDF() + ) - snpi_read = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.106.snpi.parquet").to_pandas() - snpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.107.snpi.parquet").to_pandas() + snpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.106.snpi.parquet" + ).to_pandas() + snpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.107.snpi.parquet" + ).to_pandas() # now the order can change, so we need to sort by subpop and start_date - snpi_wrote = snpi_wrote.sort_values(by=["subpop", "start_date"]).reset_index(drop=True) - snpi_read = snpi_read.sort_values(by=["subpop", "start_date"]).reset_index(drop=True) + snpi_wrote = snpi_wrote.sort_values(by=["subpop", "start_date"]).reset_index( + drop=True + ) + snpi_read = snpi_read.sort_values(by=["subpop", "start_date"]).reset_index( + drop=True + ) assert (snpi_read == snpi_wrote).all().all() npi_read = 
seir.build_npi_SEIR( - inference_simulator.modinf, load_ID=False, sim_id2load=1, config=config, bypass_DF=snpi_read + inference_simulator.modinf, + load_ID=False, + sim_id2load=1, + config=config, + bypass_DF=snpi_read, ) npi_wrote = seir.build_npi_SEIR( - inference_simulator.modinf, load_ID=False, sim_id2load=1, config=config, bypass_DF=snpi_wrote + inference_simulator.modinf, + load_ID=False, + sim_id2load=1, + config=config, + bypass_DF=snpi_wrote, ) assert (npi_read.getReductionDF() == npi_wrote.getReductionDF()).all().all() diff --git a/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py b/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py index 56df652cf..aa186010c 100644 --- a/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py +++ b/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py @@ -36,7 +36,9 @@ prefix = "" sim_id = 1 -a = pd.read_parquet(file_paths.create_file_name(run_id, prefix, sim_id, "seir", "parquet")) +a = pd.read_parquet( + file_paths.create_file_name(run_id, prefix, sim_id, "seir", "parquet") +) print(a) # created by running SEIR test_seir.py (comment line 530 to remove file tree) first b = pd.read_parquet("../../SEIR/test/model_output/seir/000000101.test.seir.parquet") @@ -54,7 +56,9 @@ diffI = np.arange(5) * 2 date_data = datetime.date(2020, 4, 15) for i in range(5): - b.loc[(b["mc_value_type"] == "incidence") & (b["date"] == str(date_data)), subpop[i]] = diffI[i] + b.loc[ + (b["mc_value_type"] == "incidence") & (b["date"] == str(date_data)), subpop[i] + ] = diffI[i] pa_df = pa.Table.from_pandas(b, preserve_index=False) pa.parquet.write_table(pa_df, "new_test_no_vacc.parquet") diff --git a/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py b/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py index 35f490920..ed0942663 100644 --- a/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py +++ b/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py @@ -40,87 +40,150 @@ def test_outcome(): stoch_traj_flag=False, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.1.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.1.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + 
== diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet" + ).to_pandas() for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + 
& (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 ) @@ -136,15 +199,23 @@ def test_outcome_modifiers_scenario_with_load(): stoch_traj_flag=False, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hpar_config = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet").to_pandas() - hpar_rel = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet").to_pandas() + hpar_config = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet" + ).to_pandas() + hpar_rel = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet" + ).to_pandas() for out in ["incidH", "incidD", "incidICU"]: for i, place in enumerate(subpop): a = hpar_rel[(hpar_rel["outcome"] == out) & (hpar_rel["subpop"] == place)] - b = hpar_config[(hpar_rel["outcome"] == out) & (hpar_config["subpop"] == place)] + b = hpar_config[ + (hpar_rel["outcome"] == out) & (hpar_config["subpop"] == place) + ] assert len(a) == len(b) for j in range(len(a)): if b.iloc[j]["quantity"] in ["delay", "duration"]: @@ -171,16 +242,30 @@ def test_outcomes_read_write_hpar(): stoch_traj_flag=False, out_run_id=3, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.3.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.3.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.2.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.3.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.2.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.3.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.2.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.3.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.2.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.3.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -201,7 +286,9 @@ def test_multishift_notstochdelays(): [36, 29], ] ) - shifts = np.array([[1, 0], [2, 1], [1, 0], [2, 2], [1, 2], [0, 1], [1, 1], [1, 2], [1, 2], [1, 0]]) + shifts = np.array( + [[1, 0], [2, 1], [1, 0], [2, 2], [1, 2], [0, 1], [1, 1], [1, 2], [1, 2], [1, 0]] + ) expected = np.array( [ [0, 39], @@ -216,7 +303,9 @@ def test_multishift_notstochdelays(): [12, 32], ] ) - assert 
(outcomes.multishift(array, shifts, stoch_delay_flag=False) == expected).all() + assert ( + outcomes.multishift(array, shifts, stoch_delay_flag=False) == expected + ).all() def test_outcomes_npi(): @@ -230,89 +319,152 @@ def test_outcomes_npi(): stoch_traj_flag=False, out_run_id=105, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() # Doubled 
everything from previous config.yaml for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 * 2 ) @@ -330,17 +482,31 @@ def test_outcomes_read_write_hnpi(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() - hosp_wrote = 
pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -356,17 +522,27 @@ def test_outcomes_read_write_hnpi2(): out_run_id=106, ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet") + ) import random random.seed(10) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -378,16 +554,30 @@ def test_outcomes_read_write_hnpi2(): stoch_traj_flag=False, out_run_id=107, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() + hosp_wrote 
= pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -402,89 +592,152 @@ def test_outcomes_npi_custom_pname(): stoch_traj_flag=False, out_run_id=105, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False, sim_id2load=1 + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + 
).to_pandas() # Doubled everything from previous config.yaml for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 * 2 ) @@ -502,16 +755,30 @@ def test_outcomes_read_write_hnpi_custom_pname(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() - 
hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -520,10 +787,14 @@ def test_outcomes_read_write_hnpi2_custom_pname(): prefix = "" - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, prefix, 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, prefix, 1, "hnpi", "parquet") + ) import random random.seed(10) @@ -537,10 +808,16 @@ def test_outcomes_read_write_hnpi2_custom_pname(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -553,16 +830,30 @@ def test_outcomes_read_write_hnpi2_custom_pname(): out_run_id=107, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet").to_pandas() + hosp_read = 
pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -580,7 +871,9 @@ def test_outcomes_pcomp(): ) p_compmult = [1, 3] - seir = pq.read_table(f"{config_filepath_prefix}model_output/seir/000000001.105.seir.parquet").to_pandas() + seir = pq.read_table( + f"{config_filepath_prefix}model_output/seir/000000001.105.seir.parquet" + ).to_pandas() seir2 = seir.copy() seir2["mc_vaccination_stage"] = "first_dose" @@ -591,10 +884,16 @@ def test_outcomes_pcomp(): seir2[pl] = seir2[pl] * p_compmult[1] new_seir = pd.concat([seir, seir2]) out_df = pa.Table.from_pandas(new_seir, preserve_index=False) - pa.parquet.write_table(out_df, file_paths.create_file_name(110, prefix, 1, "seir", "parquet")) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + pa.parquet.write_table( + out_df, file_paths.create_file_name(110, prefix, 1, "seir", "parquet") + ) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hosp_f = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet").to_pandas() + hosp_f = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet" + ).to_pandas() hosp_f.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for k, p_comp in enumerate(["0dose", "1dose"]): @@ -602,42 +901,90 @@ def test_outcomes_pcomp(): for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: - assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] == diffI[i] * p_compmult[k] assert ( - hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt + datetime.timedelta(7)] + hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] + == diffI[i] * p_compmult[k] + ) + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][ + dt + datetime.timedelta(7) + ] - diffI[i] * 0.1 * p_compmult[k] < 1e-8 ) assert ( - hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt + datetime.timedelta(2)] + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt + datetime.timedelta(2) + ] - diffI[i] * 0.01 * p_compmult[k] < 1e-8 ) assert ( - hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt + datetime.timedelta(7)] + hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][ + dt + datetime.timedelta(7) + ] - diffI[i] * 0.1 * 0.4 * p_compmult[k] < 1e-8 ) for j in range(7): assert ( - hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7 + j)] + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7 + j) + ] - diffI[i] * 0.1 * p_compmult[k] < 1e-8 ) - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][ + dt + datetime.timedelta(7) + ] + == 0 + ) assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] == 0 - assert hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt + datetime.timedelta(7)] == 0 - assert 
hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt + datetime.timedelta(2) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt] == 0 - assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][ + dt - datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt - datetime.timedelta(4) + ] + == 0 + ) assert hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt] == 0 - hpar_f = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet").to_pandas() + hpar_f = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet" + ).to_pandas() # Doubled everything from previous config.yaml # for k, p_comp in enumerate(["unvaccinated", "first_dose"]): for k, p_comp in enumerate(["0dose", "1dose"]): @@ -727,16 +1074,30 @@ def test_outcomes_pcomp_read_write(): out_run_id=112, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.112.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.112.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.111.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.112.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.111.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.112.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.112.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.112.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() diff --git a/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py b/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py index 1e0915b82..53c34e039 100644 --- a/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py +++ b/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py @@ -42,7 +42,9 @@ def 
test_parameters_from_timeserie_file(): ) # p = inference_simulator.s.parameters - p_draw = p.parameters_quick_draw(n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes) + p_draw = p.parameters_quick_draw( + n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes + ) p_df = p.getParameterDF(p_draw)["parameter"] diff --git a/flepimop/gempyor_pkg/tests/seir/test_compartments.py b/flepimop/gempyor_pkg/tests/seir/test_compartments.py index 1d4319e3b..0b40b5423 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_compartments.py +++ b/flepimop/gempyor_pkg/tests/seir/test_compartments.py @@ -10,7 +10,14 @@ import pyarrow.parquet as pq import filecmp -from gempyor import compartments, seir, NPI, file_paths, model_info, subpopulation_structure +from gempyor import ( + compartments, + seir, + NPI, + file_paths, + model_info, + subpopulation_structure, +) from gempyor.utils import config @@ -24,7 +31,9 @@ def test_check_transitions_parquet_creation(): config.set_file(f"{DATA_DIR}/config_compartmental_model_format.yml") original_compartments_file = f"{DATA_DIR}/parsed_compartment_compartments.parquet" original_transitions_file = f"{DATA_DIR}/parsed_compartment_transitions.parquet" - lhs = compartments.Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + lhs = compartments.Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) rhs = compartments.Compartments( seir_config=config["seir"], compartments_file=original_compartments_file, @@ -43,10 +52,16 @@ def test_check_transitions_parquet_writing_and_loading(): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_compartmental_model_format.yml") - lhs = compartments.Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + lhs = compartments.Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) temp_compartments_file = f"{DATA_DIR}/parsed_compartment_compartments.test.parquet" temp_transitions_file = f"{DATA_DIR}/parsed_compartment_transitions.test.parquet" - lhs.toFile(compartments_file=temp_compartments_file, transitions_file=temp_transitions_file, write_parquet=True) + lhs.toFile( + compartments_file=temp_compartments_file, + transitions_file=temp_transitions_file, + write_parquet=True, + ) rhs = compartments.Compartments( seir_config=config["seir"], compartments_file=temp_compartments_file, @@ -86,4 +101,3 @@ def test_ModelInfo_has_compartments_component(): ) assert type(s.compartments) == compartments.Compartments assert type(s.compartments) == compartments.Compartments - diff --git a/flepimop/gempyor_pkg/tests/seir/test_ic.py b/flepimop/gempyor_pkg/tests/seir/test_ic.py index b4cd240ee..9e61aafcf 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_ic.py +++ b/flepimop/gempyor_pkg/tests/seir/test_ic.py @@ -21,7 +21,9 @@ def test_IC_success(self): outcome_modifiers_scenario=None, write_csv=False, ) - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) assert sic.initial_conditions_config == s.initial_conditions_config def test_IC_allow_missing_node_compartments_success(self): @@ -40,11 +42,15 @@ def test_IC_allow_missing_node_compartments_success(self): s.initial_conditions_config["allow_missing_nodes"] = True s.initial_conditions_config["allow_missing_compartments"] = True - sic = 
initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) sic.get_from_config(sim_id=100, modinf=s) def test_IC_IC_notImplemented_fail(self): - with pytest.raises(NotImplementedError, match=r".*unknown.*initial.*conditions.*"): + with pytest.raises( + NotImplementedError, match=r".*unknown.*initial.*conditions.*" + ): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config.yml") @@ -58,6 +64,8 @@ def test_IC_IC_notImplemented_fail(self): write_csv=False, ) s.initial_conditions_config["method"] = "unknown" - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) sic.get_from_config(sim_id=100, modinf=s) diff --git a/flepimop/gempyor_pkg/tests/seir/test_parameters.py b/flepimop/gempyor_pkg/tests/seir/test_parameters.py index 9e03bf87d..045e6643b 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_parameters.py +++ b/flepimop/gempyor_pkg/tests/seir/test_parameters.py @@ -10,7 +10,14 @@ import pyarrow.parquet as pq import filecmp -from gempyor import model_info, seir, NPI, file_paths, parameters, subpopulation_structure +from gempyor import ( + model_info, + seir, + NPI, + file_paths, + parameters, + subpopulation_structure, +) from gempyor.utils import config, write_df, read_df @@ -65,7 +72,9 @@ def test_parameters_from_config_plus_read_write(): tf=s.tf, subpop_names=s.subpop_struct.subpop_names, ) - p_load = rhs.parameters_load(param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops) + p_load = rhs.parameters_load( + param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops + ) assert (p_draw == p_load).all() @@ -102,9 +111,13 @@ def test_parameters_quick_draw_old(): assert params.pnames == ["alpha", "sigma", "gamma", "R0s"] assert params.npar == 4 assert params.stacked_modifier_method["sum"] == [] - assert params.stacked_modifier_method["product"] == [pn.lower() for pn in params.pnames] + assert params.stacked_modifier_method["product"] == [ + pn.lower() for pn in params.pnames + ] - p_array = params.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_array = params.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) print(p_array.shape) alpha = p_array[params.pnames2pindex["alpha"]] @@ -122,7 +135,12 @@ def test_parameters_quick_draw_old(): assert ((2 <= R0s) & (R0s <= 3)).all() assert sigma.shape == (modinf.n_days, modinf.nsubpops) - assert (sigma == config["seir"]["parameters"]["sigma"]["value"]["value"].as_evaled_expression()).all() + assert ( + sigma + == config["seir"]["parameters"]["sigma"]["value"][ + "value" + ].as_evaled_expression() + ).all() assert gamma.shape == (modinf.n_days, modinf.nsubpops) assert len(np.unique(gamma)) == 1 @@ -174,6 +192,8 @@ def test_parameters_from_timeseries_file(): tf=s.tf, subpop_names=s.subpop_struct.subpop_names, ) - p_load = rhs.parameters_load(param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops) + p_load = rhs.parameters_load( + param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops + ) assert (p_draw == p_load).all() diff --git a/flepimop/gempyor_pkg/tests/seir/test_seir.py b/flepimop/gempyor_pkg/tests/seir/test_seir.py index 99a4cc236..94a40d7d9 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_seir.py +++ b/flepimop/gempyor_pkg/tests/seir/test_seir.py @@ -73,8 +73,12 @@ 
def test_constant_population_legacy_integration(): ) integration_method = "legacy" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -82,7 +86,9 @@ def test_constant_population_legacy_integration(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -94,7 +100,9 @@ def test_constant_population_legacy_integration(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, @@ -142,16 +150,24 @@ def test_constant_population_rk4jit_integration_fail(): ) modinf.seir_config["integration"]["method"] = "rk4.jit" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, modinf=modinf, modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -163,7 +179,9 @@ def test_constant_population_rk4jit_integration_fail(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, @@ -212,8 +230,12 @@ def test_constant_population_rk4jit_integration(): # s.integration_method = "rk4.jit" assert modinf.seir_config["integration"]["method"].get() == "rk4" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + 
sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -221,7 +243,9 @@ def test_constant_population_rk4jit_integration(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -233,7 +257,9 @@ def test_constant_population_rk4jit_integration(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, parsed_parameters, @@ -280,8 +306,12 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -289,7 +319,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -301,7 +333,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( @@ -316,11 +350,17 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "10001"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "10001"] > 1 ) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] > 1 ) @@ -336,11 +376,26 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] > 1 ) - assert 
df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["20002"] > 0 - assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["10001"] > 0 + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["20002"] + > 0 + ) + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["10001"] + > 0 + ) def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): @@ -366,8 +421,12 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -375,7 +434,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -387,7 +448,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( @@ -402,8 +465,20 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): ) df = seir.states2Df(modinf, states) - assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["20002"] > 0 - assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["10001"] > 0 + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["20002"] + > 0 + ) + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["10001"] + > 0 + ) def test_steps_SEIR_no_spread(): @@ -426,8 +501,12 @@ def test_steps_SEIR_no_spread(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) modinf.mobility.data = modinf.mobility.data * 0 @@ -437,7 +516,9 @@ def test_steps_SEIR_no_spread(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + 
"reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -449,7 +530,9 @@ def test_steps_SEIR_no_spread(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(10): states = seir.steps_SEIR( @@ -464,7 +547,10 @@ def test_steps_SEIR_no_spread(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] == 0.0 ) @@ -480,7 +566,10 @@ def test_steps_SEIR_no_spread(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] == 0.0 ) @@ -515,7 +604,9 @@ def test_continuation_resume(): seir.onerun_SEIR(sim_id2write=int(sim_id2write), modinf=modinf, config=config) states_old = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, 100, "seir", "parquet"), + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, 100, "seir", "parquet" + ), ).to_pandas() states_old = states_old[states_old["date"] == "2020-03-15"].reset_index(drop=True) @@ -547,7 +638,9 @@ def test_continuation_resume(): seir.onerun_SEIR(sim_id2write=sim_id2write, modinf=modinf, config=config) states_new = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write, "seir", "parquet"), + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write, "seir", "parquet" + ), ).to_pandas() states_new = states_new[states_new["date"] == "2020-03-15"].reset_index(drop=True) assert ( @@ -560,10 +653,16 @@ def test_continuation_resume(): ) seir.onerun_SEIR( - sim_id2write=sim_id2write + 1, modinf=modinf, sim_id2load=sim_id2write, load_ID=True, config=config + sim_id2write=sim_id2write + 1, + modinf=modinf, + sim_id2load=sim_id2write, + load_ID=True, + config=config, ) states_new = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "seir", "parquet"), + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "seir", "parquet" + ), ).to_pandas() states_new = states_new[states_new["date"] == "2020-03-15"].reset_index(drop=True) for path in ["model_output/seir", "model_output/snpi", "model_output/spar"]: @@ -587,7 +686,9 @@ def test_inference_resume(): spatial_config = config["subpop_setup"] if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." 
+ ) # spatial_base_path = pathlib.Path(config["data_path"].get()) modinf = model_info.ModelInfo( config=config, @@ -604,7 +705,9 @@ def test_inference_resume(): seir.onerun_SEIR(sim_id2write=int(sim_id2write), modinf=modinf, config=config) npis_old = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write, "snpi", "parquet") + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write, "snpi", "parquet" + ) ).to_pandas() config.clear() @@ -632,14 +735,24 @@ def test_inference_resume(): ) seir.onerun_SEIR( - sim_id2write=sim_id2write + 1, modinf=modinf, sim_id2load=sim_id2write, load_ID=True, config=config + sim_id2write=sim_id2write + 1, + modinf=modinf, + sim_id2load=sim_id2write, + load_ID=True, + config=config, ) npis_new = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "snpi", "parquet") + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "snpi", "parquet" + ) ).to_pandas() assert npis_old["modifier_name"].isin(["None", "Wuhan", "KansasCity"]).all() - assert npis_new["modifier_name"].isin(["None", "Wuhan", "KansasCity", "BrandNew"]).all() + assert ( + npis_new["modifier_name"] + .isin(["None", "Wuhan", "KansasCity", "BrandNew"]) + .all() + ) # assert((['None', 'Wuhan', 'KansasCity']).isin(npis_old["modifier_name"]).all()) # assert((['None', 'Wuhan', 'KansasCity', 'BrandNew']).isin(npis_new["modifier_name"]).all()) assert (npis_old["start_date"] == "2020-04-01").all() @@ -674,8 +787,12 @@ def test_parallel_compartments_with_vacc(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -683,7 +800,9 @@ def test_parallel_compartments_with_vacc(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -695,7 +814,9 @@ def test_parallel_compartments_with_vacc(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( @@ -761,8 +882,12 @@ def test_parallel_compartments_no_vacc(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -770,7 +895,9 @@ 
def test_parallel_compartments_no_vacc(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -782,7 +909,9 @@ def test_parallel_compartments_no_vacc(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( diff --git a/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py b/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py index 34df630c3..9a11531a3 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py +++ b/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py @@ -25,7 +25,9 @@ def test_subpopulation_structure_mobility(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -59,7 +61,9 @@ def test_subpopulation_structure_mobility_txt(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -89,7 +93,9 @@ def test_subpopulation_structure_subpop_population_zero_fail(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -110,7 +116,9 @@ def test_subpopulation_structure_dulpicate_subpop_names_fail(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -130,12 +138,16 @@ def test_subpopulation_structure_mobility_shape_fail(): mobility: {DATA_DIR}/mobility_2x3.txt """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content 
temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path - with pytest.raises(ValueError, match=r"mobility data must have dimensions of length of geodata.*"): + with pytest.raises( + ValueError, match=r"mobility data must have dimensions of length of geodata.*" + ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] ) @@ -150,12 +162,16 @@ def test_subpopulation_structure_mobility_fluxes_same_ori_and_dest_fail(): mobility: {DATA_DIR}/mobility_same_ori_dest.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path - with pytest.raises(ValueError, match=r"Mobility fluxes with same origin and destination.*"): + with pytest.raises( + ValueError, match=r"Mobility fluxes with same origin and destination.*" + ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] ) @@ -170,12 +186,16 @@ def test_subpopulation_structure_mobility_npz_shape_fail(): mobility: {DATA_DIR}/mobility_2x3.npz """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path - with pytest.raises(ValueError, match=r"mobility data must have dimensions of length of geodata.*"): + with pytest.raises( + ValueError, match=r"mobility data must have dimensions of length of geodata.*" + ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] ) @@ -190,7 +210,9 @@ def test_subpopulation_structure_mobility_no_extension_fail(): mobility: {DATA_DIR}/mobility """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -210,13 +232,16 @@ def test_subpopulation_structure_mobility_exceed_source_node_pop_fail(): mobility: {DATA_DIR}/mobility1001.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path with pytest.raises( - ValueError, match=r"The following entries in the mobility data exceed the source subpop populations.*" + ValueError, + match=r"The following entries in the mobility data exceed the source subpop populations.*", ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] @@ -232,13 +257,16 @@ def 
test_subpopulation_structure_mobility_rows_exceed_source_node_pop_fail(): mobility: {DATA_DIR}/mobility_row_exceeed.txt """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path with pytest.raises( - ValueError, match=r"The following entries in the mobility data exceed the source subpop populations.*" + ValueError, + match=r"The following entries in the mobility data exceed the source subpop populations.*", ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] @@ -252,7 +280,9 @@ def test_subpopulation_structure_mobility_no_mobility_matrix_specified(): """ config.clear() config.read(user=False) - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py index 367a7f550..b12989ec4 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py +++ b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py @@ -7,7 +7,7 @@ class TestGetLogNormal: """Unit tests for the `gempyor.utils.get_log_normal` function.""" - + @pytest.mark.parametrize( "meanlog,sdlog", [ @@ -22,15 +22,15 @@ class TestGetLogNormal: ], ) def test_construct_distribution( - self, - meanlog: float | int, - sdlog: float | int, + self, + meanlog: float | int, + sdlog: float | int, ) -> None: """Test the construction of a log normal distribution. - This test checks whether the `get_log_normal` function correctly constructs - a log normal distribution with the specified parameters. It verifies that - the returned object is an instance of `rv_frozen`, and that its support and + This test checks whether the `get_log_normal` function correctly constructs + a log normal distribution with the specified parameters. It verifies that + the returned object is an instance of `rv_frozen`, and that its support and parameters (log mean and log standard deviation) are correctly set. Args: diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py index 23e4fad58..c3ccbca79 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py +++ b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py @@ -7,7 +7,7 @@ class TestGetTruncatedNormal: """Unit tests for the `gempyor.utils.get_truncated_normal` function.""" - + @pytest.mark.parametrize( "mean,sd,a,b", [ @@ -21,10 +21,10 @@ class TestGetTruncatedNormal: ], ) def test_construct_distribution( - self, - mean: float | int, - sd: float | int, - a: float | int, + self, + mean: float | int, + sd: float | int, + a: float | int, b: float | int, ) -> None: """Test the construction of a truncated normal distribution. 
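The two distribution-utility test files above receive only whitespace cleanup in this patch, but their docstrings describe what the tests verify about `gempyor.utils.get_log_normal` and `gempyor.utils.get_truncated_normal`. As a rough illustration of that style of check (not taken from this patch, and assuming `get_log_normal(meanlog, sdlog)` returns a frozen `scipy.stats.lognorm` with `s=sdlog` and `scale=exp(meanlog)`):

```python
# Illustrative sketch only -- not part of this patch. The parameter mapping
# (s=sdlog, scale=exp(meanlog)) is an assumption made for the example.
import numpy as np
from scipy.stats import lognorm
from scipy.stats._distn_infrastructure import rv_frozen


def check_log_normal(dist, meanlog: float, sdlog: float) -> None:
    # The returned object should be a frozen distribution...
    assert isinstance(dist, rv_frozen)
    # ...supported on (0, inf)...
    lower, upper = dist.support()
    assert lower == 0.0 and upper == np.inf
    # ...whose shape and scale encode the log standard deviation and log mean.
    assert dist.kwds.get("s") == sdlog
    assert np.isclose(dist.kwds.get("scale"), np.exp(meanlog))


# Stand-in frozen distribution used here in place of get_log_normal(1.0, 0.5):
check_log_normal(lognorm(s=0.5, scale=np.exp(1.0)), meanlog=1.0, sdlog=0.5)
```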
diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 768451b2a..7013e0ea3 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -10,7 +10,11 @@ @pytest.mark.parametrize( - ("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet"),], + ("fname", "extension"), + [ + ("mobility", "csv"), + ("usa-geoid-params-output", "parquet"), + ], ) def test_read_df_and_write_success(fname, extension): os.chdir(tmp_path) @@ -23,13 +27,18 @@ def test_read_df_and_write_success(fname, extension): utils.write_df(tmp_path + "/data/" + fname, df2, extension=extension) assert os.path.isfile(tmp_path + "/data/" + fname + "." + extension) elif extension == "parquet": - df2 = pa.parquet.read_table(f"{DATA_DIR}/" + fname + "." + extension).to_pandas() + df2 = pa.parquet.read_table( + f"{DATA_DIR}/" + fname + "." + extension + ).to_pandas() assert df2.equals(df1) utils.write_df(tmp_path + "/data/" + fname, df2, extension=extension) assert os.path.isfile(tmp_path + "/data/" + fname + "." + extension) -@pytest.mark.parametrize(("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet")]) +@pytest.mark.parametrize( + ("fname", "extension"), + [("mobility", "csv"), ("usa-geoid-params-output", "parquet")], +) def test_read_df_and_write_fail(fname, extension): with pytest.raises(NotImplementedError, match=r".*Invalid.*extension.*Must.*"): os.chdir(tmp_path) @@ -41,7 +50,9 @@ def test_read_df_and_write_fail(fname, extension): assert df2.equals(df1) utils.write_df(tmp_path + "/data/" + fname, df2, extension="") elif extension == "parquet": - df2 = pa.parquet.read_table(f"{DATA_DIR}/" + fname + "." + extension).to_pandas() + df2 = pa.parquet.read_table( + f"{DATA_DIR}/" + fname + "." 
+ extension + ).to_pandas() assert df2.equals(df1) utils.write_df(tmp_path + "/data/" + fname, df2, extension="") @@ -91,9 +102,7 @@ def test_create_resume_out_filename(): filetype="spar", liketype="global", ) - expected_filename = ( - "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" - ) + expected_filename = "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" assert result == expected_filename result2 = utils.create_resume_out_filename( @@ -111,32 +120,59 @@ def test_create_resume_out_filename(): def test_create_resume_input_filename(): result = utils.create_resume_input_filename( - flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="spar", liketype="global" + flepi_slot_index="2", + resume_run_index="321", + flepi_prefix="output", + filetype="spar", + liketype="global", + ) + expect_filename = ( + "model_output/output/321/spar/global/final/000000002.321.spar.parquet" ) - expect_filename = "model_output/output/321/spar/global/final/000000002.321.spar.parquet" assert result == expect_filename result2 = utils.create_resume_input_filename( - flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="seed", liketype="chimeric" + flepi_slot_index="2", + resume_run_index="321", + flepi_prefix="output", + filetype="seed", + liketype="chimeric", + ) + expect_filename2 = ( + "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv" ) - expect_filename2 = "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv" assert result2 == expect_filename2 def test_get_filetype_resume_discard_seeding_true_flepi_block_index_1(): expected_types = ["spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_filetype_for_resume(resume_discard_seeding="true", flepi_block_index="1") == expected_types + assert ( + utils.get_filetype_for_resume( + resume_discard_seeding="true", flepi_block_index="1" + ) + == expected_types + ) def test_get_filetype_resume_discard_seeding_false_flepi_block_index_1(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1") == expected_types + assert ( + utils.get_filetype_for_resume( + resume_discard_seeding="false", flepi_block_index="1" + ) + == expected_types + ) def test_get_filetype_flepi_block_index_2(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] - assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="2") == expected_types + assert ( + utils.get_filetype_for_resume( + resume_discard_seeding="false", flepi_block_index="2" + ) + == expected_types + ) def test_create_resume_file_names_map(): diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils2.py b/flepimop/gempyor_pkg/tests/utils/test_utils2.py index 4b0ae59ba..0822604ed 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils2.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils2.py @@ -26,7 +26,9 @@ class SampleClass: def __init__(self): self.value = 11 - @utils.profile(output_file="get_value.prof", sort_by="time", lines_to_print=10, strip_dirs=True) + @utils.profile( + output_file="get_value.prof", sort_by="time", lines_to_print=10, strip_dirs=True + ) def get_value(self): return self.value @@ -198,7 +200,9 @@ def test_as_random_distribution_binomial_w_fraction_error(config): def test_as_random_distribution_truncnorm(config): - config.add({"value": {"distribution": "truncnorm", "mean": 0, 
"sd": 1, "a": -1, "b": 1}}) + config.add( + {"value": {"distribution": "truncnorm", "mean": 0, "sd": 1, "a": -1, "b": 1}} + ) dist = config["value"].as_random_distribution() rvs = dist(size=1000) assert len(rvs) == 1000 diff --git a/postprocessing/postprocess_auto.py b/postprocessing/postprocess_auto.py index aaf4a0bff..13f0849d6 100644 --- a/postprocessing/postprocess_auto.py +++ b/postprocessing/postprocess_auto.py @@ -26,7 +26,13 @@ def __init__(self, run_id, config_filepath=None, folder_path=None): self.folder_path = folder_path -def get_all_filenames(file_type, all_runs, finals_only=False, intermediates_only=False, ignore_chimeric=True) -> dict: +def get_all_filenames( + file_type, + all_runs, + finals_only=False, + intermediates_only=False, + ignore_chimeric=True, +) -> dict: """ return dictionanary for each run name """ @@ -159,7 +165,14 @@ def slack_multiple_files_v2(slack_token, message, file_list, channel): help="Maximum number of files to load for in depth plot and individual sim plot", ) def generate_pdf( - config_filepath, run_id, job_name, fs_results_path, slack_token, slack_channel, max_files, max_files_deep + config_filepath, + run_id, + job_name, + fs_results_path, + slack_token, + slack_channel, + max_files, + max_files_deep, ): print("Generating plots") print(f">> config {config_filepath} for run_id {run_id}") @@ -217,10 +230,14 @@ def generate_pdf( for filename in file_list: slot = int(filename.split("/")[-1].split(".")[0]) block = int(filename.split("/")[-1].split(".")[1]) - sim_str = filename.split("/")[-1].split(".")[2] # not necessarily a sim number now + sim_str = filename.split("/")[-1].split(".")[ + 2 + ] # not necessarily a sim number now if sim_str.isdigit(): sim = int(sim_str) - if block == 1 and (sim == 1 or sim % 5 == 0): ## first block, only one + if block == 1 and ( + sim == 1 or sim % 5 == 0 + ): ## first block, only one df_raw = pq.read_table(filename).to_pandas() df_raw["slot"] = slot df_raw["sim"] = sim @@ -238,7 +255,9 @@ def generate_pdf( # In[23]: - fig, axes = plt.subplots(len(node_names) + 1, 4, figsize=(4 * 4, len(node_names) * 3), sharex=True) + fig, axes = plt.subplots( + len(node_names) + 1, 4, figsize=(4 * 4, len(node_names) * 3), sharex=True + ) colors = ["b", "r", "y", "c"] icl = 0 @@ -255,32 +274,68 @@ def generate_pdf( lls = lls.cumsum() feature = "accepts, cumulative" axes[idp, ift].fill_between( - lls.index, lls.quantile(0.025, axis=1), lls.quantile(0.975, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.025, axis=1), + lls.quantile(0.975, axis=1), + alpha=0.1, + color=colors[icl], ) axes[idp, ift].fill_between( - lls.index, lls.quantile(0.25, axis=1), lls.quantile(0.75, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.25, axis=1), + lls.quantile(0.75, axis=1), + alpha=0.1, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, + lls.median(axis=1), + marker="o", + label=run_id, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3 ) - axes[idp, ift].plot(lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl]) - axes[idp, ift].plot(lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3) axes[idp, ift].set_title(f"National, {feature}") axes[idp, ift].grid() for idp, nn in enumerate(node_names): idp = idp + 1 - all_nn = full_df[full_df["subpop"] == nn][["sim", "slot", "ll", "accept", "accept_avg", "accept_prob"]] - for ift, feature in enumerate(["ll", "accept", "accept_avg", "accept_prob"]): + all_nn = 
full_df[full_df["subpop"] == nn][ + ["sim", "slot", "ll", "accept", "accept_avg", "accept_prob"] + ] + for ift, feature in enumerate( + ["ll", "accept", "accept_avg", "accept_prob"] + ): lls = all_nn.pivot(index="sim", columns="slot", values=feature) if feature == "accept": lls = lls.cumsum() feature = "accepts, cumulative" axes[idp, ift].fill_between( - lls.index, lls.quantile(0.025, axis=1), lls.quantile(0.975, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.025, axis=1), + lls.quantile(0.975, axis=1), + alpha=0.1, + color=colors[icl], ) axes[idp, ift].fill_between( - lls.index, lls.quantile(0.25, axis=1), lls.quantile(0.75, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.25, axis=1), + lls.quantile(0.75, axis=1), + alpha=0.1, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, + lls.median(axis=1), + marker="o", + label=run_id, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3 ) - axes[idp, ift].plot(lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl]) - axes[idp, ift].plot(lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3) axes[idp, ift].set_title(f"{nn}, {feature}") axes[idp, ift].grid() if idp == len(node_names) - 1: @@ -292,8 +347,11 @@ def generate_pdf( pass import gempyor.utils - llik_filenames = gempyor.utils.list_filenames(folder="model_output/", filters=["final", "llik" , ".parquet"]) - #get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False) + + llik_filenames = gempyor.utils.list_filenames( + folder="model_output/", filters=["final", "llik", ".parquet"] + ) + # get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False) # In[7]: resultST = [] for filename in llik_filenames: diff --git a/utilities/clean_s3.py b/utilities/clean_s3.py index 2998c65c2..08982b2b2 100644 --- a/utilities/clean_s3.py +++ b/utilities/clean_s3.py @@ -9,7 +9,9 @@ s3 = boto3.client("s3") paginator = s3.get_paginator("list_objects_v2") -pages = paginator.paginate(Bucket=bucket, Prefix="", Delimiter="/") # needs paginator cause more than 1000 files +pages = paginator.paginate( + Bucket=bucket, Prefix="", Delimiter="/" +) # needs paginator cause more than 1000 files to_prun = [] # folders: diff --git a/utilities/copy_for_continuation.py b/utilities/copy_for_continuation.py index 33d9da40f..05b7a803a 100644 --- a/utilities/copy_for_continuation.py +++ b/utilities/copy_for_continuation.py @@ -77,10 +77,14 @@ def detect_old_run_id(fp): fn = files[0] old_run_id = detect_old_run_id(fn) new_name = ( - fn.replace("seir", "cont").replace(f"{input_folder}/model_output", "model_output").replace(old_run_id, run_id) + fn.replace("seir", "cont") + .replace(f"{input_folder}/model_output", "model_output") + .replace(old_run_id, run_id) ) - print(f"detected old_run_id: {old_run_id} which will be replaced by user provided run_id: {run_id}") + print( + f"detected old_run_id: {old_run_id} which will be replaced by user provided run_id: {run_id}" + ) empty_str = "°" * len(input_folder) print(f"file: \n OLD NAME: {fn}\n NEW NAME: {empty_str}{new_name}") for fn in tqdm.tqdm(files): diff --git a/utilities/prune_by_llik.py b/utilities/prune_by_llik.py index 5b1f3224b..c783591b1 100644 --- a/utilities/prune_by_llik.py +++ b/utilities/prune_by_llik.py @@ -11,7 +11,11 @@ def get_all_filenames( - file_type, fs_results_path="to_prune/", finals_only=False, intermediates_only=True, ignore_chimeric=True + file_type, + 
fs_results_path="to_prune/", + finals_only=False, + intermediates_only=True, + ignore_chimeric=True, ) -> dict: """ return dictionary for each run name @@ -113,14 +117,18 @@ def get_all_filenames( if fill_missing: # Extract the numbers from the filenames numbers = [int(os.path.basename(filename).split(".")[0]) for filename in all_files] - missing_numbers = [num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers] + missing_numbers = [ + num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers + ] if missing_numbers: missing_filenames = [] for num in missing_numbers: filename = os.path.basename(all_files[0]) filename_prefix = re.search(r"^.*?(\d+)", filename).group() filename_suffix = re.search(r"(\..*?)$", filename).group() - missing_filename = os.path.join(os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}") + missing_filename = os.path.join( + os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}" + ) missing_filenames.append(missing_filename) print("The missing filenames with full paths are:") for missing_filename in missing_filenames: @@ -143,7 +151,7 @@ def copy_path(src, dst): file_types = [ "llik", - #"seed", + # "seed", "init", "snpi", "hnpi", @@ -160,7 +168,9 @@ def copy_path(src, dst): if fn in files_to_keep: for file_type in file_types: src = fn.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") @@ -169,7 +179,9 @@ def copy_path(src, dst): file_to_keep = np.random.choice(files_to_keep) for file_type in file_types: src = file_to_keep.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") diff --git a/utilities/prune_by_llik_and_proj.py b/utilities/prune_by_llik_and_proj.py index 53e623224..8bc50b163 100644 --- a/utilities/prune_by_llik_and_proj.py +++ b/utilities/prune_by_llik_and_proj.py @@ -11,7 +11,11 @@ def get_all_filenames( - file_type, fs_results_path="to_prune/", finals_only=False, intermediates_only=True, ignore_chimeric=True + file_type, + fs_results_path="to_prune/", + finals_only=False, + intermediates_only=True, + ignore_chimeric=True, ) -> dict: """ return dictionary for each run name @@ -23,7 +27,7 @@ def get_all_filenames( l = [] for f in Path(str(fs_results_path + "model_output")).rglob(f"*.{ext}"): f = str(f) - + if file_type in f: print(f) if ( @@ -61,7 +65,9 @@ def get_all_filenames( fs_results_path = "to_prune/" best_n = 200 -llik_filenames = get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False) +llik_filenames = get_all_filenames( + "llik", fs_results_path, finals_only=True, intermediates_only=False +) # In[7]: resultST = [] for filename in llik_filenames: @@ -100,7 +106,6 @@ def get_all_filenames( print(f" - {slot:4}, llik: {sorted_llik.loc[slot]['ll']:0.3f}") - #### RERUN FROM HERE TO CHANGE THE REGULARIZATION files_to_keep = list(full_df.loc[best_slots]["filename"].unique()) # important to sort by llik @@ -109,8 +114,9 @@ def get_all_filenames( files_to_keep = [] for fn in all_files: if fn in files_to_keep3: - outcome_fn = fn.replace("llik", "hosp") + outcome_fn = fn.replace("llik", "hosp") import 
gempyor.utils + outcomes_df = gempyor.utils.read_df(outcome_fn) outcomes_df = outcomes_df.set_index("date") reg = 1.5 @@ -118,19 +124,27 @@ def get_all_filenames( this_bad = 0 bad_subpops = [] for sp in outcomes_df["subpop"].unique(): - max_fit = outcomes_df[outcomes_df["subpop"]==sp]["incidC"][:"2024-04-08"].max() - max_summer = outcomes_df[outcomes_df["subpop"]==sp]["incidC"]["2024-04-08":"2024-09-30"].max() - if max_summer > max_fit*reg: + max_fit = outcomes_df[outcomes_df["subpop"] == sp]["incidC"][ + :"2024-04-08" + ].max() + max_summer = outcomes_df[outcomes_df["subpop"] == sp]["incidC"][ + "2024-04-08":"2024-09-30" + ].max() + if max_summer > max_fit * reg: this_bad += 1 - max_reg = max(max_reg, max_summer/max_fit) + max_reg = max(max_reg, max_summer / max_fit) bad_subpops.append(sp) - #print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%") - #print(f">>> MULT BY {max_summer/max_fit*mult:2f}") - #outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult - if this_bad>4 or max_reg>4: - print(f"{outcome_fn.split('/')[-1].split('.')[0]} >>> BAAD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}") + # print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%") + # print(f">>> MULT BY {max_summer/max_fit*mult:2f}") + # outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult + if this_bad > 4 or max_reg > 4: + print( + f"{outcome_fn.split('/')[-1].split('.')[0]} >>> BAAD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}" + ) else: - print(f"{outcome_fn.split('/')[-1].split('.')[0]} >>> GOOD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}") + print( + f"{outcome_fn.split('/')[-1].split('.')[0]} >>> GOOD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}" + ) files_to_keep.append(fn) print(len(files_to_keep)) ### END OF CODE @@ -146,14 +160,18 @@ def get_all_filenames( if fill_missing: # Extract the numbers from the filenames numbers = [int(os.path.basename(filename).split(".")[0]) for filename in all_files] - missing_numbers = [num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers] + missing_numbers = [ + num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers + ] if missing_numbers: missing_filenames = [] for num in missing_numbers: filename = os.path.basename(all_files[0]) filename_prefix = re.search(r"^.*?(\d+)", filename).group() filename_suffix = re.search(r"(\..*?)$", filename).group() - missing_filename = os.path.join(os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}") + missing_filename = os.path.join( + os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}" + ) missing_filenames.append(missing_filename) print("The missing filenames with full paths are:") for missing_filename in missing_filenames: @@ -191,7 +209,9 @@ def copy_path(src, dst): if fn in files_to_keep: for file_type in file_types: src = fn.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) 
if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") @@ -200,7 +220,9 @@ def copy_path(src, dst): file_to_keep = np.random.choice(files_to_keep) for file_type in file_types: src = file_to_keep.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") From 98dfd4c2ef37cd11777a864f61fe818c935666b7 Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 7 Aug 2024 09:12:47 -0400 Subject: [PATCH 8/9] Edits to formatting section of contribution docs Added a few more examples of commands to help jump start users into using black when working on python code for `flepiMoP`. Also added a note leaving the door open for future additions. --- .../python-guidelines-for-developers.md | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/documentation/gitbook/development/python-guidelines-for-developers.md b/documentation/gitbook/development/python-guidelines-for-developers.md index 9e6b172c0..4397e88ca 100644 --- a/documentation/gitbook/development/python-guidelines-for-developers.md +++ b/documentation/gitbook/development/python-guidelines-for-developers.md @@ -40,16 +40,21 @@ Before committing, make sure you **format your code** using black (see below) an ### Formatting -We try to remain close to python conventions and to follow the updated rules and best practices. For formatting, we use [black](https://github.com/psf/black), the _Uncompromising Code Formatter_ before submitting pull-requests. It provides a consistent style, which is useful when diffing. We use a custom length of 120 characters as the baseline is short for scientific code. Here is the line to use to format your code: +{% hint style="info" %} +Code formatters are necessary, but not sufficient for well formatted code and further style changes may be requested in PRs. Furthermore, the formatting/linting requirements for code contributed to `flepiMoP` are likely to be enhanced in the future and those changes will be reflected here when they come. +{% endhint %} + +For python code formatting the [black](https://black.readthedocs.io/en/stable/) code formatter is applied to all edits to python files being merged into `flepiMoP`. For installation and detailed usage guides please refer to the black documentation. For most use cases the following commands are sufficient: ```bash -black --line-length 120 . --exclude renv* +# See what style changes need to be made +black --diff . +# Reformat the python files automatically +black . +# Check if current work would be allowed to merged into flepiMoP +black --check . ``` -{% hint style="warning" %} -Please use type-hints as much as possible, as we are trying to move towards static checks. -{% endhint %} - ### Structure of the main classes The main classes, such as `Parameter`, `NPI`, `SeedingAndInitialConditions`,`Compartments` should tend to the same struture: From 165b18ee16743d10d58776c1979436df646463df Mon Sep 17 00:00:00 2001 From: Timothy Willard <9395586+TimothyWillard@users.noreply.github.com> Date: Wed, 7 Aug 2024 09:38:31 -0400 Subject: [PATCH 9/9] Fix code linting GitHub action Fix missing `with` in black formatter check step to contain `src` and `options`. 
--- .github/workflows/code-linting-ci.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-linting-ci.yml b/.github/workflows/code-linting-ci.yml index 26c3f51be..e22141eb3 100644 --- a/.github/workflows/code-linting-ci.yml +++ b/.github/workflows/code-linting-ci.yml @@ -28,5 +28,6 @@ jobs: lfs: true - name: Black Formatter Check uses: psf/black@stable - src: "." - options: "--check --quiet" + with: + src: "." + options: "--check --quiet"
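For context on this last fix: inputs to a published action such as `psf/black@stable` must be nested under the step's `with:` key, and `src` or `options` written directly at the step level are not valid step properties, so the workflow fails to validate before any linting runs. A sketch of how the corrected step sits inside the job is shown below; the job name, runner, and checkout details are assumptions carried over from the existing workflow rather than part of this patch.

```yaml
# Sketch of the corrected code-linting job; only the "Black Formatter Check"
# step comes from this patch, the surrounding job context is assumed.
jobs:
  black:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          lfs: true
      - name: Black Formatter Check
        uses: psf/black@stable
        with:
          src: "."
          options: "--check --quiet"
```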