diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
deleted file mode 100644
index 29e9a186d..000000000
--- a/.github/workflows/ci.yml
+++ /dev/null
@@ -1,70 +0,0 @@
-name: unit-tests
-
-on:
-  workflow_dispatch:
-  push:
-    branches:
-      - main
-      - dev
-  pull_request:
-    branches:
-      - main
-      - dev
-      - breaking-improvments
-
-jobs:
-  unit-tests:
-    runs-on: ubuntu-20.04
-    container:
-      image: hopkinsidd/flepimop:latest-dev
-      options: --user root
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-        with:
-          lfs: true
-      - name: Set up Rprofile
-        run: |
-          cp build/docker/Docker.Rprofile $HOME/.Rprofile
-          cp /home/app/.bashrc $HOME/.bashrc
-        shell: bash
-      - name: Install the gempyor package
-        run: |
-          source /var/python/3.10/virtualenv/bin/activate
-          python -m pip install --upgrade pip
-          python -m pip install "flepimop/gempyor_pkg[test]"
-        shell: bash
-      - name: Install local R packages
-        run: Rscript build/local_install.R
-        shell: bash
-      - name: Run gempyor tests
-        run: |
-          source /var/python/3.10/virtualenv/bin/activate
-          cd flepimop/gempyor_pkg
-          pytest -s
-        shell: bash
-      - name: Run gempyor-cli integration tests from examples
-        run: |
-          source /var/python/3.10/virtualenv/bin/activate
-          cd examples
-          pytest -s
-        shell: bash
-      - name: Run flepicommon tests
-        run: |
-          setwd("flepimop/R_packages/flepicommon")
-          devtools::test(stop_on_failure=TRUE)
-        shell: Rscript {0}
-      - name: Run inference tests
-        run: |
-          setwd("flepimop/R_packages/inference")
-          devtools::test(stop_on_failure=TRUE)
-        shell: Rscript {0}
-# - name: Run integration tests
-#   env:
-#     CENSUS_API_KEY: ${{ secrets.CENSUS_API_KEY }}
-#   run: |
-#     Rscript build/local_install.R
-#     cd test
-#     source /var/python/3.10/virtualenv/bin/activate
-#     pytest run_tests.py
-#     shell: bash
diff --git a/.github/workflows/code-linting-ci.yml b/.github/workflows/code-linting-ci.yml
new file mode 100644
index 000000000..e22141eb3
--- /dev/null
+++ b/.github/workflows/code-linting-ci.yml
@@ -0,0 +1,33 @@
+name: Code Linting
+
+on:
+  workflow_dispatch:
+  push:
+    paths:
+      - '**/*.py'
+    branches:
+      - main
+      - dev
+  pull_request:
+    paths:
+      - '**/*.py'
+    branches:
+      - main
+      - dev
+
+jobs:
+  unit-tests:
+    runs-on: ubuntu-latest
+    container:
+      image: hopkinsidd/flepimop:latest-dev
+      options: --user root
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Black Formatter Check
+        uses: psf/black@stable
+        with:
+          src: "."
+          options: "--check --quiet"
diff --git a/.github/workflows/flepicommon-ci.yml b/.github/workflows/flepicommon-ci.yml
new file mode 100644
index 000000000..5314c1b4f
--- /dev/null
+++ b/.github/workflows/flepicommon-ci.yml
@@ -0,0 +1,41 @@
+name: flepicommon-ci
+
+on:
+  workflow_dispatch:
+  push:
+    paths:
+      - flepimop/R_packages/flepicommon/**/*
+    branches:
+      - main
+      - dev
+  pull_request:
+    paths:
+      - flepimop/R_packages/flepicommon/**/*
+    branches:
+      - main
+      - dev
+
+jobs:
+  unit-tests:
+    runs-on: ubuntu-latest
+    container:
+      image: hopkinsidd/flepimop:latest-dev
+      options: --user root
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Set up Rprofile
+        run: |
+          cp build/docker/Docker.Rprofile $HOME/.Rprofile
+          cp /home/app/.bashrc $HOME/.bashrc
+        shell: bash
+      - name: Install local R packages
+        run: Rscript build/local_install.R
+        shell: bash
+      - name: Run flepicommon tests
+        run: |
+          setwd("flepimop/R_packages/flepicommon")
+          devtools::test(stop_on_failure=TRUE)
+        shell: Rscript {0}
diff --git a/.github/workflows/gempyor-ci.yml b/.github/workflows/gempyor-ci.yml
new file mode 100644
index 000000000..a2cb6e313
--- /dev/null
+++ b/.github/workflows/gempyor-ci.yml
@@ -0,0 +1,48 @@
+name: gempyor-ci
+
+on:
+  workflow_dispatch:
+  push:
+    paths:
+      - examples/**/*
+      - flepimop/gempyor_pkg/**/*
+    branches:
+      - main
+      - dev
+  pull_request:
+    paths:
+      - examples/**/*
+      - flepimop/gempyor_pkg/**/*
+    branches:
+      - main
+      - dev
+
+jobs:
+  unit-tests:
+    runs-on: ubuntu-latest
+    container:
+      image: hopkinsidd/flepimop:latest-dev
+      options: --user root
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Install the gempyor package
+        run: |
+          source /var/python/3.10/virtualenv/bin/activate
+          python -m pip install --upgrade pip
+          python -m pip install "flepimop/gempyor_pkg[test]"
+        shell: bash
+      - name: Run gempyor tests
+        run: |
+          source /var/python/3.10/virtualenv/bin/activate
+          cd flepimop/gempyor_pkg
+          pytest --exitfirst
+        shell: bash
+      - name: Run gempyor-cli integration tests from examples
+        run: |
+          source /var/python/3.10/virtualenv/bin/activate
+          cd examples
+          pytest --exitfirst
+        shell: bash
diff --git a/.github/workflows/inference-ci.yml b/.github/workflows/inference-ci.yml
new file mode 100644
index 000000000..2ca3d4897
--- /dev/null
+++ b/.github/workflows/inference-ci.yml
@@ -0,0 +1,41 @@
+name: inference-ci
+
+on:
+  workflow_dispatch:
+  push:
+    paths:
+      - flepimop/R_packages/inference/**/*
+    branches:
+      - main
+      - dev
+  pull_request:
+    paths:
+      - flepimop/R_packages/inference/**/*
+    branches:
+      - main
+      - dev
+
+jobs:
+  unit-tests:
+    runs-on: ubuntu-latest
+    container:
+      image: hopkinsidd/flepimop:latest-dev
+      options: --user root
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Set up Rprofile
+        run: |
+          cp build/docker/Docker.Rprofile $HOME/.Rprofile
+          cp /home/app/.bashrc $HOME/.bashrc
+        shell: bash
+      - name: Install local R packages
+        run: Rscript build/local_install.R
+        shell: bash
+      - name: Run inference tests
+        run: |
+          setwd("flepimop/R_packages/inference")
+          devtools::test(stop_on_failure=TRUE)
+        shell: Rscript {0}
diff --git a/batch/inference_job_launcher.py b/batch/inference_job_launcher.py
index a8b17ffb0..7b7e2908e 100755
--- a/batch/inference_job_launcher.py
+++ b/batch/inference_job_launcher.py
@@ -322,7 +322,9 @@ def launch_batch(
     # TODO: does this really save the config file?
if "inference" in config: config["inference"]["iterations_per_slot"] = sims_per_job - if not os.path.exists(pathlib.Path(data_path, config["inference"]["gt_data_path"])): + if not os.path.exists( + pathlib.Path(data_path, config["inference"]["gt_data_path"]) + ): print( f"ERROR: inference.data_path path {pathlib.Path(data_path, config['inference']['gt_data_path'])} does not exist!" ) @@ -403,23 +405,37 @@ def launch_batch( if "scenarios" in config["outcome_modifiers"]: outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"] - handler.launch(job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios) + handler.launch( + job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios + ) # Set job_name as environmental variable so it can be pulled for pushing to git os.environ["job_name"] = job_name # Set run_id as environmental variable so it can be pulled for pushing to git TODO - (rc, txt) = subprocess.getstatusoutput(f"git checkout -b run_{job_name}") # TODO: cd ... + (rc, txt) = subprocess.getstatusoutput( + f"git checkout -b run_{job_name}" + ) # TODO: cd ... print(txt) return rc -def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, num_blocks=None, batch_system=None): +def autodetect_params( + config, + data_path, + *, + num_jobs=None, + sims_per_job=None, + num_blocks=None, + batch_system=None, +): if num_jobs and sims_per_job and num_blocks: return (num_jobs, sims_per_job, num_blocks) if "inference" not in config or "iterations_per_slot" not in config["inference"]: - raise click.UsageError("inference::iterations_per_slot undefined in config, can't autodetect parameters") + raise click.UsageError( + "inference::iterations_per_slot undefined in config, can't autodetect parameters" + ) iterations_per_slot = int(config["inference"]["iterations_per_slot"]) if num_jobs is None: @@ -429,11 +445,17 @@ def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, nu if sims_per_job is None: if num_blocks is not None: sims_per_job = int(math.ceil(iterations_per_slot / num_blocks)) - print(f"Setting number of blocks to {num_blocks} [via num_blocks (-k) argument]") - print(f"Setting sims per job to {sims_per_job} [via {iterations_per_slot} iterations_per_slot in config]") + print( + f"Setting number of blocks to {num_blocks} [via num_blocks (-k) argument]" + ) + print( + f"Setting sims per job to {sims_per_job} [via {iterations_per_slot} iterations_per_slot in config]" + ) else: if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." 
+ ) geodata_fname = pathlib.Path(data_path) / config["subpop_setup"]["geodata"] with open(geodata_fname) as geodata_fp: num_subpops = sum(1 for line in geodata_fp) @@ -458,7 +480,9 @@ def autodetect_params(config, data_path, *, num_jobs=None, sims_per_job=None, nu if num_blocks is None: num_blocks = int(math.ceil(iterations_per_slot / sims_per_job)) - print(f"Setting number of blocks to {num_blocks} [via {iterations_per_slot} iterations_per_slot in config]") + print( + f"Setting number of blocks to {num_blocks} [via {iterations_per_slot} iterations_per_slot in config]" + ) return (num_jobs, sims_per_job, num_blocks) @@ -472,13 +496,17 @@ def get_aws_job_queues(job_queue_prefix): for q in resp["jobQueues"]: queue_name = q["jobQueueName"] if queue_name.startswith(job_queue_prefix): - job_list_resp = batch_client.list_jobs(jobQueue=queue_name, jobStatus="PENDING") + job_list_resp = batch_client.list_jobs( + jobQueue=queue_name, jobStatus="PENDING" + ) queues_with_jobs[queue_name] = len(job_list_resp["jobSummaryList"]) # Return the least-loaded queues first return sorted(queues_with_jobs, key=queues_with_jobs.get) -def aws_countfiles_autodetect_runid(s3_bucket, restart_from_location, restart_from_run_id, num_jobs, strict=False): +def aws_countfiles_autodetect_runid( + s3_bucket, restart_from_location, restart_from_run_id, num_jobs, strict=False +): import boto3 s3 = boto3.resource("s3") @@ -487,15 +515,24 @@ def aws_countfiles_autodetect_runid(s3_bucket, restart_from_location, restart_fr all_files = list(bucket.objects.filter(Prefix=prefix)) all_files = [f.key for f in all_files] if restart_from_run_id is None: - print("WARNING: no --restart_from_run_id specified, autodetecting... please wait querying S3 👀🔎...") + print( + "WARNING: no --restart_from_run_id specified, autodetecting... please wait querying S3 👀🔎..." + ) restart_from_run_id = all_files[0].split("/")[3] - if user_confirmation(question=f"Auto-detected run_id {restart_from_run_id}. Correct ?", default=True): + if user_confirmation( + question=f"Auto-detected run_id {restart_from_run_id}. Correct ?", + default=True, + ): print(f"great, continuing with run_id {restart_from_run_id}...") else: - raise ValueError(f"Abording, please specify --restart_from_run_id manually.") + raise ValueError( + f"Abording, please specify --restart_from_run_id manually." + ) final_llik = [f for f in all_files if ("llik" in f) and ("final" in f)] - if len(final_llik) == 0: # hacky: there might be a bucket with no llik files, e.g if init. + if ( + len(final_llik) == 0 + ): # hacky: there might be a bucket with no llik files, e.g if init. final_llik = [f for f in all_files if ("init" in f) and ("final" in f)] if len(final_llik) != num_jobs: @@ -583,8 +620,12 @@ def build_job_metadata(self, job_name): manifest = {} manifest["cmd"] = " ".join(sys.argv[:]) manifest["job_name"] = job_name - manifest["data_sha"] = subprocess.getoutput("cd {self.data_path}; git rev-parse HEAD") - manifest["flepimop_sha"] = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse HEAD") + manifest["data_sha"] = subprocess.getoutput( + "cd {self.data_path}; git rev-parse HEAD" + ) + manifest["flepimop_sha"] = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse HEAD" + ) # Save the manifest file to S3 with open("manifest.json", "w") as f: @@ -594,17 +635,27 @@ def build_job_metadata(self, job_name): # need these to be uploaded so they can be executed. 
this_file_path = os.path.dirname(os.path.realpath(__file__)) self.save_file( - source=os.path.join(this_file_path, "AWS_inference_runner.sh"), destination=f"{job_name}-runner.sh" + source=os.path.join(this_file_path, "AWS_inference_runner.sh"), + destination=f"{job_name}-runner.sh", ) self.save_file( - source=os.path.join(this_file_path, "AWS_inference_copy.sh"), destination=f"{job_name}-copy.sh" + source=os.path.join(this_file_path, "AWS_inference_copy.sh"), + destination=f"{job_name}-copy.sh", ) tarfile_name = f"{job_name}.tar.gz" self.tar_working_dir(tarfile_name=tarfile_name) - self.save_file(source=tarfile_name, destination=f"{job_name}.tar.gz", remove_source=True) + self.save_file( + source=tarfile_name, + destination=f"{job_name}.tar.gz", + remove_source=True, + ) - self.save_file(source="manifest.json", destination=f"{job_name}/manifest.json", remove_source=True) + self.save_file( + source="manifest.json", + destination=f"{job_name}/manifest.json", + remove_source=True, + ) def tar_working_dir(self, tarfile_name): # this tar file always has the structure: @@ -616,10 +667,14 @@ def tar_working_dir(self, tarfile_name): or q == "covid-dashboard-app" or q == "renv.cache" or q == "sample_data" - or q == "renv" # joseph: I added this to fix a bug, hopefully it doesn't break anything + or q + == "renv" # joseph: I added this to fix a bug, hopefully it doesn't break anything or q.startswith(".") ): - tar.add(os.path.join(self.flepi_path, q), arcname=os.path.join("flepiMoP", q)) + tar.add( + os.path.join(self.flepi_path, q), + arcname=os.path.join("flepiMoP", q), + ) elif q == "sample_data": for r in os.listdir(os.path.join(self.flepi_path, "sample_data")): if r != "united-states-commutes": @@ -629,10 +684,17 @@ def tar_working_dir(self, tarfile_name): ) # tar.add(os.path.join("flepiMoP", "sample_data", r)) for p in os.listdir(self.data_path): - if not (p.startswith(".") or p.endswith("tar.gz") or p in self.outputs or p == "flepiMoP"): + if not ( + p.startswith(".") + or p.endswith("tar.gz") + or p in self.outputs + or p == "flepiMoP" + ): tar.add( p, - filter=lambda x: None if os.path.basename(x.name).startswith(".") else x, + filter=lambda x: ( + None if os.path.basename(x.name).startswith(".") else x + ), ) tar.close() @@ -644,7 +706,9 @@ def save_file(self, source, destination, remove_source=False, prefix=""): import boto3 s3_client = boto3.client("s3") - s3_client.upload_file(source, self.s3_bucket, os.path.join(prefix, destination)) + s3_client.upload_file( + source, self.s3_bucket, os.path.join(prefix, destination) + ) if self.batch_system == "slurm": import shutil @@ -656,7 +720,13 @@ def save_file(self, source, destination, remove_source=False, prefix=""): if remove_source: os.remove(source) - def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_modifiers_scenarios): + def launch( + self, + job_name, + config_filepath, + seir_modifiers_scenarios, + outcome_modifiers_scenarios, + ): s3_results_path = f"s3://{self.s3_bucket}/{job_name}" if self.batch_system == "slurm": @@ -676,7 +746,10 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo ## TODO: check how each of these variables are used downstream base_env_vars = [ {"name": "BATCH_SYSTEM", "value": self.batch_system}, - {"name": "S3_MODEL_PROJECT_PATH", "value": f"s3://{self.s3_bucket}/{job_name}.tar.gz"}, + { + "name": "S3_MODEL_PROJECT_PATH", + "value": f"s3://{self.s3_bucket}/{job_name}.tar.gz", + }, {"name": "DVC_OUTPUTS", "value": " ".join(self.outputs)}, {"name": 
"S3_RESULTS_PATH", "value": s3_results_path}, {"name": "FS_RESULTS_PATH", "value": fs_results_path}, @@ -700,14 +773,22 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo }, {"name": "FLEPI_STOCHASTIC_RUN", "value": str(self.stochastic)}, {"name": "FLEPI_RESET_CHIMERICS", "value": str(self.reset_chimerics)}, - {"name": "FLEPI_MEM_PROFILE", "value": str(os.getenv("FLEPI_MEM_PROFILE", default="FALSE"))}, - {"name": "FLEPI_MEM_PROF_ITERS", "value": str(os.getenv("FLEPI_MEM_PROF_ITERS", default="50"))}, + { + "name": "FLEPI_MEM_PROFILE", + "value": str(os.getenv("FLEPI_MEM_PROFILE", default="FALSE")), + }, + { + "name": "FLEPI_MEM_PROF_ITERS", + "value": str(os.getenv("FLEPI_MEM_PROF_ITERS", default="50")), + }, {"name": "SLACK_CHANNEL", "value": str(self.slack_channel)}, ] with open(config_filepath) as f: config = yaml.full_load(f) - for ctr, (s, d) in enumerate(itertools.product(seir_modifiers_scenarios, outcome_modifiers_scenarios)): + for ctr, (s, d) in enumerate( + itertools.product(seir_modifiers_scenarios, outcome_modifiers_scenarios) + ): cur_job_name = f"{job_name}_{s}_{d}" # Create first job cur_env_vars = base_env_vars.copy() @@ -719,7 +800,12 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": "1"}) cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) if not (self.restart_from_location is None): - cur_env_vars.append({"name": "LAST_JOB_OUTPUT", "value": f"{self.restart_from_location}"}) + cur_env_vars.append( + { + "name": "LAST_JOB_OUTPUT", + "value": f"{self.restart_from_location}", + } + ) cur_env_vars.append( { "name": "OLD_FLEPI_RUN_INDEX", @@ -732,8 +818,18 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo if self.continuation: cur_env_vars.append({"name": "FLEPI_CONTINUATION", "value": f"TRUE"}) - cur_env_vars.append({"name": "FLEPI_CONTINUATION_RUN_ID", "value": f"{self.continuation_run_id}"}) - cur_env_vars.append({"name": "FLEPI_CONTINUATION_LOCATION", "value": f"{self.continuation_location}"}) + cur_env_vars.append( + { + "name": "FLEPI_CONTINUATION_RUN_ID", + "value": f"{self.continuation_run_id}", + } + ) + cur_env_vars.append( + { + "name": "FLEPI_CONTINUATION_LOCATION", + "value": f"{self.continuation_location}", + } + ) cur_env_vars.append( { "name": "FLEPI_CONTINUATION_FTYPE", @@ -743,7 +839,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo # First job: if self.batch_system == "aws": - cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}_block0"}) + cur_env_vars.append( + {"name": "JOB_NAME", "value": f"{cur_job_name}_block0"} + ) runner_script_path = f"s3://{self.s3_bucket}/{job_name}-runner.sh" s3_cp_run_script = f"aws s3 cp {runner_script_path} $PWD/run-flepimop-inference" # line to copy the runner script in wd as ./run-covid-pipeline command = [ @@ -814,7 +912,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo postprod_command, command_name="sbatch postprod", fail_on_fail=True ) postprod_job_id = stdout.decode().split(" ")[-1][:-1] - print(f">>> SUCCESS SCHEDULING POST-PROCESSING JOB. Slurm job id is {postprod_job_id}") + print( + f">>> SUCCESS SCHEDULING POST-PROCESSING JOB. 
Slurm job id is {postprod_job_id}" + ) elif self.batch_system == "local": cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}"}) @@ -831,12 +931,27 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo cur_env_vars = base_env_vars.copy() cur_env_vars.append({"name": "FLEPI_SEIR_SCENARIOS", "value": s}) cur_env_vars.append({"name": "FLEPI_OUTCOME_SCENARIOS", "value": d}) - cur_env_vars.append({"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"}) - cur_env_vars.append({"name": "FLEPI_BLOCK_INDEX", "value": f"{block_idx+1}"}) - cur_env_vars.append({"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) - cur_env_vars.append({"name": "OLD_FLEPI_RUN_INDEX", "value": f"{self.run_id}"}) - cur_env_vars.append({"name": "LAST_JOB_OUTPUT", "value": f"{s3_results_path}/"}) - cur_env_vars.append({"name": "JOB_NAME", "value": f"{cur_job_name}_block{block_idx}"}) + cur_env_vars.append( + {"name": "FLEPI_PREFIX", "value": f"{config['name']}_{s}_{d}"} + ) + cur_env_vars.append( + {"name": "FLEPI_BLOCK_INDEX", "value": f"{block_idx+1}"} + ) + cur_env_vars.append( + {"name": "FLEPI_RUN_INDEX", "value": f"{self.run_id}"} + ) + cur_env_vars.append( + {"name": "OLD_FLEPI_RUN_INDEX", "value": f"{self.run_id}"} + ) + cur_env_vars.append( + {"name": "LAST_JOB_OUTPUT", "value": f"{s3_results_path}/"} + ) + cur_env_vars.append( + { + "name": "JOB_NAME", + "value": f"{cur_job_name}_block{block_idx}", + } + ) cur_job = batch_client.submit_job( jobName=f"{cur_job_name}_block{block_idx}", jobQueue=cur_job_queue, @@ -862,7 +977,9 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo ] copy_script_path = f"s3://{self.s3_bucket}/{job_name}-copy.sh" - s3_cp_run_script = f"aws s3 cp {copy_script_path} $PWD/run-flepimop-copy" + s3_cp_run_script = ( + f"aws s3 cp {copy_script_path} $PWD/run-flepimop-copy" + ) cp_command = [ "sh", "-c", @@ -895,21 +1012,33 @@ def launch(self, job_name, config_filepath, seir_modifiers_scenarios, outcome_mo em = "" if self.resume_discard_seeding: em = f", discarding seeding results." 
- print(f" >> Resuming from run id is {self.restart_from_run_id} located in {self.restart_from_location}{em}") + print( + f" >> Resuming from run id is {self.restart_from_run_id} located in {self.restart_from_location}{em}" + ) if self.batch_system == "aws": print(f" >> Final output will be: {s3_results_path}/model_output/") elif self.batch_system == "slurm": print(f" >> Final output will be: {fs_results_path}/model_output/") if self.s3_upload: - print(f" >> Final output will be uploaded to {s3_results_path}/model_output/") + print( + f" >> Final output will be uploaded to {s3_results_path}/model_output/" + ) if self.continuation: - print(f" >> Continuing from run id is {self.continuation_run_id} located in {self.continuation_location}") + print( + f" >> Continuing from run id is {self.continuation_run_id} located in {self.continuation_location}" + ) print(f" >> Run id is {self.run_id}") print(f" >> config is {config_filepath.split('/')[-1]}") - flepimop_branch = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse --abbrev-ref HEAD") - data_branch = subprocess.getoutput(f"cd {self.data_path}; git rev-parse --abbrev-ref HEAD") + flepimop_branch = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse --abbrev-ref HEAD" + ) + data_branch = subprocess.getoutput( + f"cd {self.data_path}; git rev-parse --abbrev-ref HEAD" + ) data_hash = subprocess.getoutput(f"cd {self.data_path}; git rev-parse HEAD") - flepimop_hash = subprocess.getoutput(f"cd {self.flepi_path}; git rev-parse HEAD") + flepimop_hash = subprocess.getoutput( + f"cd {self.flepi_path}; git rev-parse HEAD" + ) print(f""" >> FLEPIMOP branch is {flepimop_branch} with hash {flepimop_hash}""") print(f""" >> DATA branch is {data_branch} with hash {data_hash}""") print(f" ------------------------- END -------------------------") diff --git a/batch/scenario_job.py b/batch/scenario_job.py index 1961974eb..4d94460b4 100755 --- a/batch/scenario_job.py +++ b/batch/scenario_job.py @@ -196,13 +196,20 @@ def launch_job_inner( tarfile_name = f"{job_name}.tar.gz" tar = tarfile.open(tarfile_name, "w:gz") for p in os.listdir("."): - if not (p.startswith(".") or p.endswith("tar.gz") or p in dvc_outputs or p == "batch"): + if not ( + p.startswith(".") + or p.endswith("tar.gz") + or p in dvc_outputs + or p == "batch" + ): tar.add(p, filter=lambda x: None if x.name.startswith(".") else x) tar.close() # Upload the tar'd contents of this directory and the runner script to S3 runner_script_name = f"{job_name}-runner.sh" - local_runner_script = os.path.join(os.path.dirname(os.path.realpath(__file__)), "AWS_scenario_runner.sh") + local_runner_script = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "AWS_scenario_runner.sh" + ) s3_client = boto3.client("s3") s3_client.upload_file(local_runner_script, s3_input_bucket, runner_script_name) s3_client.upload_file(tarfile_name, s3_input_bucket, tarfile_name) @@ -246,7 +253,9 @@ def launch_job_inner( containerOverrides=container_overrides, ) - print(f"Batch job with id {resp['jobId']} launched; output will be written to {results_path}") + print( + f"Batch job with id {resp['jobId']} launched; output will be written to {results_path}" + ) def get_dvc_outputs(): diff --git a/documentation/gitbook/development/python-guidelines-for-developers.md b/documentation/gitbook/development/python-guidelines-for-developers.md index 9e6b172c0..4397e88ca 100644 --- a/documentation/gitbook/development/python-guidelines-for-developers.md +++ 
b/documentation/gitbook/development/python-guidelines-for-developers.md @@ -40,16 +40,21 @@ Before committing, make sure you **format your code** using black (see below) an ### Formatting -We try to remain close to python conventions and to follow the updated rules and best practices. For formatting, we use [black](https://github.com/psf/black), the _Uncompromising Code Formatter_ before submitting pull-requests. It provides a consistent style, which is useful when diffing. We use a custom length of 120 characters as the baseline is short for scientific code. Here is the line to use to format your code: +{% hint style="info" %} +Code formatters are necessary, but not sufficient for well formatted code and further style changes may be requested in PRs. Furthermore, the formatting/linting requirements for code contributed to `flepiMoP` are likely to be enhanced in the future and those changes will be reflected here when they come. +{% endhint %} + +For python code formatting the [black](https://black.readthedocs.io/en/stable/) code formatter is applied to all edits to python files being merged into `flepiMoP`. For installation and detailed usage guides please refer to the black documentation. For most use cases the following commands are sufficient: ```bash -black --line-length 120 . --exclude renv* +# See what style changes need to be made +black --diff . +# Reformat the python files automatically +black . +# Check if current work would be allowed to merged into flepiMoP +black --check . ``` -{% hint style="warning" %} -Please use type-hints as much as possible, as we are trying to move towards static checks. -{% endhint %} - ### Structure of the main classes The main classes, such as `Parameter`, `NPI`, `SeedingAndInitialConditions`,`Compartments` should tend to the same struture: diff --git a/examples/test_cli.py b/examples/test_cli.py index 8b1d02982..1a7e726f3 100644 --- a/examples/test_cli.py +++ b/examples/test_cli.py @@ -1,4 +1,3 @@ - from click.testing import CliRunner from gempyor.simulate import simulate import os @@ -6,44 +5,48 @@ # See here to test click application https://click.palletsprojects.com/en/8.1.x/testing/ # would be useful to also call the command directly + def test_config_sample_2pop(): - os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'config_sample_2pop.yml']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "config_sample_2pop.yml"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output def test_sample_2pop_interventions_test(): - os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'config_sample_2pop_interventions_test.yml']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/tutorial_two_subpops") + runner = CliRunner() + result = runner.invoke( + simulate, ["-c", 
"config_sample_2pop_interventions_test.yml"] + ) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output def test_simple_usa_statelevel(): - os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'simple_usa_statelevel.yml', '-n', '1']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output + os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "simple_usa_statelevel.yml", "-n", "1"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output + def test_simple_usa_statelevel(): - os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") - runner = CliRunner() - result = runner.invoke(simulate, ['-c', 'simple_usa_statelevel.yml', '-n', '1']) - print(result.output) # useful for debug - print(result.exit_code) # useful for debug - print(result.exception) # useful for debug - assert result.exit_code == 0 - assert 'completed in' in result.output \ No newline at end of file + os.chdir(os.path.dirname(__file__) + "/simple_usa_statelevel") + runner = CliRunner() + result = runner.invoke(simulate, ["-c", "simple_usa_statelevel.yml", "-n", "1"]) + print(result.output) # useful for debug + print(result.exit_code) # useful for debug + print(result.exception) # useful for debug + assert result.exit_code == 0 + assert "completed in" in result.output diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py index ecbbba962..117f0c36b 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/MultiPeriodModifier.py @@ -21,7 +21,8 @@ def __init__( name=getattr( npi_config, "key", - (npi_config["scenario"].exists() and npi_config["scenario"].get()) or "unknown", + (npi_config["scenario"].exists() and npi_config["scenario"].get()) + or "unknown", ) ) @@ -32,7 +33,9 @@ def __init__( self.subpops = subpops self.pnames_overlap_operation_sum = pnames_overlap_operation_sum - self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod + self.pnames_overlap_operation_reductionprod = ( + pnames_overlap_operation_reductionprod + ) self.param_name = npi_config["parameter"].as_str().lower() @@ -68,14 +71,22 @@ def __init__( # if parameters are exceeding global start/end dates, index of parameter df will be out of range so check first if self.sanitize: - too_early = min([min(i) for i in self.parameters["start_date"]]) < self.start_date - too_late = max([max(i) for i in self.parameters["end_date"]]) > self.end_date + too_early = ( + min([min(i) for i in self.parameters["start_date"]]) < self.start_date + ) + too_late = ( + max([max(i) for i in self.parameters["end_date"]]) > self.end_date + ) if too_early or too_late: - raise ValueError("at least one period start or end date is not between global dates") + raise ValueError( + "at least one period start or end date is not between global dates" + ) for grp_config in npi_config["groups"]: affected_subpops_grp = 
self.__get_affected_subpops_grp(grp_config) - for sub_index in range(len(self.parameters["start_date"][affected_subpops_grp[0]])): + for sub_index in range( + len(self.parameters["start_date"][affected_subpops_grp[0]]) + ): period_range = pd.date_range( self.parameters["start_date"][affected_subpops_grp[0]][sub_index], self.parameters["end_date"][affected_subpops_grp[0]][sub_index], @@ -101,7 +112,9 @@ def __checkErrors(self): max_start_date = max([max(i) for i in self.parameters["start_date"]]) min_end_date = min([min(i) for i in self.parameters["end_date"]]) max_end_date = max([max(i) for i in self.parameters["end_date"]]) - if not ((self.start_date <= min_start_date) & (max_start_date <= self.end_date)): + if not ( + (self.start_date <= min_start_date) & (max_start_date <= self.end_date) + ): raise ValueError( f"at least one period_start_date [{min_start_date}, {max_start_date}] is not between global dates [{self.start_date}, {self.end_date}]" ) @@ -111,7 +124,9 @@ def __checkErrors(self): ) if not (self.parameters["start_date"] <= self.parameters["end_date"]).all(): - raise ValueError(f"at least one period_start_date is greater than the corresponding period end date") + raise ValueError( + f"at least one period_start_date is greater than the corresponding period end date" + ) for n in self.affected_subpops: if n not in self.subpops: @@ -135,7 +150,9 @@ def __createFromConfig(self, npi_config): self.affected_subpops = self.__get_affected_subpops(npi_config) - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] dist = npi_config["value"].as_random_distribution() self.parameters["modifier_name"] = self.name self.parameters["parameter"] = self.param_name @@ -153,7 +170,9 @@ def __createFromConfig(self, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups(grp_config, affected_subpops_grp) + this_spatial_group = helpers.get_spatial_groups( + grp_config, affected_subpops_grp + ) self.spatial_groups.append(this_spatial_group) # print(self.name, this_spatial_groups) @@ -182,7 +201,9 @@ def __createFromDf(self, loaded_df, npi_config): loaded_df = loaded_df[loaded_df["modifier_name"] == self.name] self.affected_subpops = self.__get_affected_subpops(npi_config) - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] self.parameters["modifier_name"] = self.name self.parameters["parameter"] = self.param_name @@ -194,7 +215,9 @@ def __createFromDf(self, loaded_df, npi_config): if self.sanitize: if len(self.affected_subpops) != len(self.parameters): print(f"loading {self.name} and we got {len(self.parameters)} subpops") - print(f"getting from config that it affects {len(self.affected_subpops)}") + print( + f"getting from config that it affects {len(self.affected_subpops)}" + ) self.spatial_groups = [] for grp_config in npi_config["groups"]: @@ -209,7 +232,9 @@ def __createFromDf(self, loaded_df, npi_config): else: start_dates = [self.start_date] end_dates = [self.end_date] - this_spatial_group = helpers.get_spatial_groups(grp_config, affected_subpops_grp) + this_spatial_group = helpers.get_spatial_groups( + grp_config, affected_subpops_grp + ) self.spatial_groups.append(this_spatial_group) for subpop in this_spatial_group["ungrouped"]: @@ -227,7 +252,9 @@ def __createFromDf(self, 
loaded_df, npi_config): for subpop in group: self.parameters.at[subpop, "start_date"] = start_dates self.parameters.at[subpop, "end_date"] = end_dates - self.parameters.at[subpop, "value"] = loaded_df.at[",".join(group), "value"] + self.parameters.at[subpop, "value"] = loaded_df.at[ + ",".join(group), "value" + ] else: dist = npi_config["value"].as_random_distribution() drawn_value = dist(size=1) @@ -258,11 +285,16 @@ def __get_affected_subpops(self, npi_config): affected_subpops_grp += [str(n.get()) for n in grp_config["subpop"]] affected_subpops = set(affected_subpops_grp) if len(affected_subpops) != len(affected_subpops_grp): - raise ValueError(f"In NPI {self.name}, some subpops belong to several groups. This is unsupported.") + raise ValueError( + f"In NPI {self.name}, some subpops belong to several groups. This is unsupported." + ) return affected_subpops def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 @@ -278,7 +310,9 @@ def getReductionToWrite(self): # self.parameters.index is a list of subpops for this_spatial_groups in self.spatial_groups: # spatially ungrouped dataframe - df_ungroup = self.parameters[self.parameters.index.isin(this_spatial_groups["ungrouped"])].copy() + df_ungroup = self.parameters[ + self.parameters.index.isin(this_spatial_groups["ungrouped"]) + ].copy() df_ungroup.index.name = "subpop" df_ungroup["start_date"] = df_ungroup["start_date"].apply( lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) @@ -301,7 +335,9 @@ def getReductionToWrite(self): "start_date": df_group["start_date"].apply( lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) ), - "end_date": df_group["end_date"].apply(lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l])), + "end_date": df_group["end_date"].apply( + lambda l: ",".join([d.strftime("%Y-%m-%d") for d in l]) + ), "value": df_group["value"], } ).set_index("subpop") diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py index e078ddeba..f2438321c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/SinglePeriodModifier.py @@ -21,7 +21,8 @@ def __init__( name=getattr( npi_config, "key", - (npi_config["scenario"].exists() and npi_config["scenario"].get()) or "unknown", + (npi_config["scenario"].exists() and npi_config["scenario"].get()) + or "unknown", ) ) @@ -29,7 +30,9 @@ def __init__( self.end_date = modinf.tf self.pnames_overlap_operation_sum = pnames_overlap_operation_sum - self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod + self.pnames_overlap_operation_reductionprod = ( + pnames_overlap_operation_reductionprod + ) self.subpops = subpops @@ -60,15 +63,22 @@ def __init__( self.__createFromConfig(npi_config) # if parameters are exceeding global start/end dates, index of parameter df will be out of range so check first - if self.parameters["start_date"].min() < self.start_date or self.parameters["end_date"].max() > self.end_date: - raise ValueError(f"""{self.name} : at least one period start or end date is not between global dates""") + if ( + self.parameters["start_date"].min() < self.start_date + or self.parameters["end_date"].max() > self.end_date + ): + raise ValueError( + f"""{self.name} : at least one period start or 
end date is not between global dates""" + ) # for index in self.parameters.index: # period_range = pd.date_range(self.parameters["start_date"][index], self.parameters["end_date"][index]) ## This the line that does the work # self.npi_old.loc[index, period_range] = np.tile(self.parameters["value"][index], (len(period_range), 1)).T - period_range = pd.date_range(self.parameters["start_date"].iloc[0], self.parameters["end_date"].iloc[0]) + period_range = pd.date_range( + self.parameters["start_date"].iloc[0], self.parameters["end_date"].iloc[0] + ) self.npi.loc[self.parameters.index, period_range] = np.tile( self.parameters["value"][:], (len(period_range), 1) ).T @@ -80,7 +90,9 @@ def __checkErrors(self): max_start_date = self.parameters["start_date"].max() min_end_date = self.parameters["end_date"].min() max_end_date = self.parameters["end_date"].max() - if not ((self.start_date <= min_start_date) & (max_start_date <= self.end_date)): + if not ( + (self.start_date <= min_start_date) & (max_start_date <= self.end_date) + ): raise ValueError( f"at least one period_start_date [{min_start_date}, {max_start_date}] is not between global dates [{self.start_date}, {self.end_date}]" ) @@ -90,7 +102,9 @@ def __checkErrors(self): ) if not (self.parameters["start_date"] <= self.parameters["end_date"]).all(): - raise ValueError(f"at least one period_start_date is greater than the corresponding period end date") + raise ValueError( + f"at least one period_start_date is greater than the corresponding period end date" + ) for n in self.affected_subpops: if n not in self.subpops: @@ -116,19 +130,27 @@ def __createFromConfig(self, npi_config): if npi_config["subpop"].exists() and npi_config["subpop"].get() != "all": self.affected_subpops = {str(n.get()) for n in npi_config["subpop"]} - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] # Create reduction self.dist = npi_config["value"].as_random_distribution() self.parameters["modifier_name"] = self.name self.parameters["start_date"] = ( - npi_config["period_start_date"].as_date() if npi_config["period_start_date"].exists() else self.start_date + npi_config["period_start_date"].as_date() + if npi_config["period_start_date"].exists() + else self.start_date ) self.parameters["end_date"] = ( - npi_config["period_end_date"].as_date() if npi_config["period_end_date"].exists() else self.end_date + npi_config["period_end_date"].as_date() + if npi_config["period_end_date"].exists() + else self.end_date ) self.parameters["parameter"] = self.param_name - self.spatial_groups = helpers.get_spatial_groups(npi_config, list(self.affected_subpops)) + self.spatial_groups = helpers.get_spatial_groups( + npi_config, list(self.affected_subpops) + ) if self.spatial_groups["ungrouped"]: self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = self.dist( size=len(self.spatial_groups["ungrouped"]) @@ -146,17 +168,23 @@ def __createFromDf(self, loaded_df, npi_config): if npi_config["subpop"].exists() and npi_config["subpop"].get() != "all": self.affected_subpops = {str(n.get()) for n in npi_config["subpop"]} - self.parameters = self.parameters[self.parameters.index.isin(self.affected_subpops)] + self.parameters = self.parameters[ + self.parameters.index.isin(self.affected_subpops) + ] self.parameters["modifier_name"] = self.name self.parameters["parameter"] = self.param_name # self.parameters = loaded_df[["modifier_name", "start_date", "end_date", 
"parameter", "value"]].copy() # dates are picked from config self.parameters["start_date"] = ( - npi_config["period_start_date"].as_date() if npi_config["period_start_date"].exists() else self.start_date + npi_config["period_start_date"].as_date() + if npi_config["period_start_date"].exists() + else self.start_date ) self.parameters["end_date"] = ( - npi_config["period_end_date"].as_date() if npi_config["period_end_date"].exists() else self.end_date + npi_config["period_end_date"].as_date() + if npi_config["period_end_date"].exists() + else self.end_date ) ## This is more legible to me, but if we change it here, we should change it in __createFromConfig as well # if npi_config["period_start_date"].exists(): @@ -175,17 +203,24 @@ def __createFromDf(self, loaded_df, npi_config): # TODO: to be consistent with MTR, we want to also draw the values for the subpops # that are not in the loaded_df. - self.spatial_groups = helpers.get_spatial_groups(npi_config, list(self.affected_subpops)) + self.spatial_groups = helpers.get_spatial_groups( + npi_config, list(self.affected_subpops) + ) if self.spatial_groups["ungrouped"]: - self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = loaded_df.loc[ - self.spatial_groups["ungrouped"], "value" - ] + self.parameters.loc[self.spatial_groups["ungrouped"], "value"] = ( + loaded_df.loc[self.spatial_groups["ungrouped"], "value"] + ) if self.spatial_groups["grouped"]: for group in self.spatial_groups["grouped"]: - self.parameters.loc[group, "value"] = loaded_df.loc[",".join(group), "value"] + self.parameters.loc[group, "value"] = loaded_df.loc[ + ",".join(group), "value" + ] def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 @@ -198,7 +233,9 @@ def getReduction(self, param): def getReductionToWrite(self): # spatially ungrouped dataframe - df = self.parameters[self.parameters.index.isin(self.spatial_groups["ungrouped"])].copy() + df = self.parameters[ + self.parameters.index.isin(self.spatial_groups["ungrouped"]) + ].copy() df.index.name = "subpop" df["start_date"] = df["start_date"].astype("str") df["end_date"] = df["end_date"].astype("str") diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py index 489a48fbb..21165b0c8 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/StackedModifier.py @@ -32,7 +32,9 @@ def __init__( self.end_date = modinf.tf self.pnames_overlap_operation_sum = pnames_overlap_operation_sum - self.pnames_overlap_operation_reductionprod = pnames_overlap_operation_reductionprod + self.pnames_overlap_operation_reductionprod = ( + pnames_overlap_operation_reductionprod + ) self.subpops = subpops self.param_name = [] @@ -47,7 +49,9 @@ def __init__( if isinstance(scenario, str): settings = modifiers_library.get(scenario) if settings is None: - raise RuntimeError(f"couldn't find scenario in config file [got: {scenario}]") + raise RuntimeError( + f"couldn't find scenario in config file [got: {scenario}]" + ) # via profiling: faster to recreate the confuse view than to fetch+resolve due to confuse isinstance # checks scenario_npi_config = confuse.RootView([settings]) @@ -68,12 +72,16 @@ def __init__( ) new_params = sub_npi.param_name # either a list (if stacked) or a string - new_params = 
[new_params] if isinstance(new_params, str) else new_params # convert to list + new_params = ( + [new_params] if isinstance(new_params, str) else new_params + ) # convert to list # Add each parameter at first encounter, with a neutral start for new_p in new_params: if new_p not in self.param_name: self.param_name.append(new_p) - if new_p in pnames_overlap_operation_sum: # re.match("^transition_rate [1234567890]+$",new_p): + if ( + new_p in pnames_overlap_operation_sum + ): # re.match("^transition_rate [1234567890]+$",new_p): self.reductions[new_p] = 0 else: # for the reductionprod and product method, the initial neutral is 1 ) self.reductions[new_p] = 1 @@ -81,7 +89,9 @@ def __init__( for param in self.param_name: # Get reduction return a neutral value for this overlap operation if no parameeter exists reduction = sub_npi.getReduction(param) - if param in pnames_overlap_operation_sum: # re.match("^transition_rate [1234567890]+$",param): + if ( + param in pnames_overlap_operation_sum + ): # re.match("^transition_rate [1234567890]+$",param): self.reductions[param] += reduction elif param in pnames_overlap_operation_reductionprod: self.reductions[param] *= 1 - reduction @@ -104,7 +114,9 @@ def __init__( self.reduction_params.clear() for param in self.param_name: - if param in pnames_overlap_operation_reductionprod: # re.match("^transition_rate \d+$",param): + if ( + param in pnames_overlap_operation_reductionprod + ): # re.match("^transition_rate \d+$",param): self.reductions[param] = 1 - self.reductions[param] # check that no NPI is called several times, and retourn them @@ -124,7 +136,10 @@ def __checkErrors(self): # ) def get_default(self, param): - if param in self.pnames_overlap_operation_sum or param in self.pnames_overlap_operation_reductionprod: + if ( + param in self.pnames_overlap_operation_sum + or param in self.pnames_overlap_operation_reductionprod + ): return 0.0 else: return 1.0 diff --git a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py index f964d4c6e..7a2b4ccc6 100644 --- a/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py +++ b/flepimop/gempyor_pkg/src/gempyor/NPI/helpers.py @@ -12,7 +12,9 @@ def reduce_parameter( if isinstance(modification, pd.DataFrame): modification = modification.T modification.index = pd.to_datetime(modification.index.astype(str)) - modification = modification.resample("1D").ffill().to_numpy() # Type consistency: + modification = ( + modification.resample("1D").ffill().to_numpy() + ) # Type consistency: if method == "reduction_product": return parameter * (1 - modification) elif method == "sum": @@ -43,16 +45,22 @@ def get_spatial_groups(grp_config, affected_subpops: list) -> dict: else: spatial_groups["grouped"] = grp_config["subpop_groups"].get() spatial_groups["ungrouped"] = list( - set(affected_subpops) - set(flatten_list_of_lists(spatial_groups["grouped"])) + set(affected_subpops) + - set(flatten_list_of_lists(spatial_groups["grouped"])) ) # flatten the list of lists of grouped subpops, so we can do some checks flat_grouped_list = flatten_list_of_lists(spatial_groups["grouped"]) # check that all subpops are either grouped or ungrouped if set(flat_grouped_list + spatial_groups["ungrouped"]) != set(affected_subpops): - print("set of grouped and ungrouped subpops", set(flat_grouped_list + spatial_groups["ungrouped"])) + print( + "set of grouped and ungrouped subpops", + set(flat_grouped_list + spatial_groups["ungrouped"]), + ) print("set of affected subpops ", set(affected_subpops)) - raise 
ValueError(f"The two above sets are differs for for intervention with config \n {grp_config}") + raise ValueError( + f"The two above sets are differs for for intervention with config \n {grp_config}" + ) if len(set(flat_grouped_list + spatial_groups["ungrouped"])) != len( flat_grouped_list + spatial_groups["ungrouped"] ): diff --git a/flepimop/gempyor_pkg/src/gempyor/calibrate.py b/flepimop/gempyor_pkg/src/gempyor/calibrate.py index e5cb287aa..41de8b673 100644 --- a/flepimop/gempyor_pkg/src/gempyor/calibrate.py +++ b/flepimop/gempyor_pkg/src/gempyor/calibrate.py @@ -158,7 +158,9 @@ def calibrate( # TODO here for resume if resume or resume_location is not None: - print("Doing a resume, this only work with the same number of slot and parameters right now") + print( + "Doing a resume, this only work with the same number of slot and parameters right now" + ) p0 = None if resume_location is not None: backend = emcee.backends.HDFBackend(resume_location) @@ -188,14 +190,19 @@ def calibrate( backend=backend, moves=moves, ) - state = sampler.run_mcmc(p0, niter, progress=True, skip_initial_state_check=True) + state = sampler.run_mcmc( + p0, niter, progress=True, skip_initial_state_check=True + ) print(f"Done, mean acceptance fraction: {np.mean(sampler.acceptance_fraction):.3f}") # plotting the chain sampler = emcee.backends.HDFBackend(filename, read_only=True) gempyor.postprocess_inference.plot_chains( - inferpar=gempyor_inference.inferpar, sampler_output=sampler, sampled_slots=None, save_to=f"{run_id}_chains.pdf" + inferpar=gempyor_inference.inferpar, + sampler_output=sampler, + sampled_slots=None, + save_to=f"{run_id}_chains.pdf", ) print("EMCEE Run done, doing sampling") @@ -203,11 +210,14 @@ def calibrate( shutil.rmtree(project_path + "model_output/", ignore_errors=True) max_indices = np.argsort(sampler.get_log_prob()[-1, :])[-nsamples:] - samples = sampler.get_chain()[-1, max_indices, :] # the last iteration, for selected slots + samples = sampler.get_chain()[ + -1, max_indices, : + ] # the last iteration, for selected slots gempyor_inference.set_save(True) with multiprocessing.Pool(ncpu) as pool: results = pool.starmap( - gempyor_inference.get_logloss_as_single_number, [(samples[i, :],) for i in range(len(max_indices))] + gempyor_inference.get_logloss_as_single_number, + [(samples[i, :],) for i in range(len(max_indices))], ) # results = [] # for fn in gempyor.utils.list_filenames(folder="model_output/", filters=[run_id, "hosp.parquet"]): diff --git a/flepimop/gempyor_pkg/src/gempyor/compartments.py b/flepimop/gempyor_pkg/src/gempyor/compartments.py index ec87cf7e5..cfbce6b50 100644 --- a/flepimop/gempyor_pkg/src/gempyor/compartments.py +++ b/flepimop/gempyor_pkg/src/gempyor/compartments.py @@ -13,7 +13,13 @@ class Compartments: # Minimal object to be easily picklable for // runs - def __init__(self, seir_config=None, compartments_config=None, compartments_file=None, transitions_file=None): + def __init__( + self, + seir_config=None, + compartments_config=None, + compartments_file=None, + transitions_file=None, + ): self.times_set = 0 ## Something like this is needed for check script: @@ -29,7 +35,7 @@ def __init__(self, seir_config=None, compartments_config=None, compartments_file return def constructFromConfig(self, seir_config, compartment_config): - """ + """ This method is called by the constructor if the compartments are not loaded from a file. It will parse the compartments and transitions from the configuration files. It will populate self.compartments and self.transitions. 
@@ -43,7 +49,7 @@ def __eq__(self, other): ).all().all() def parse_compartments(self, seir_config, compartment_config): - """ Parse the compartments from the configuration file: + """Parse the compartments from the configuration file: seir_config: the configuration file for the SEIR model compartment_config: the configuration file for the compartments Example: if config says: @@ -75,40 +81,57 @@ def parse_compartments(self, seir_config, compartment_config): else: compartment_df = pd.merge(compartment_df, tmp, on="key") compartment_df = compartment_df.drop(["key"], axis=1) - compartment_df["name"] = compartment_df.apply(lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1) + compartment_df["name"] = compartment_df.apply( + lambda x: reduce(lambda a, b: a + "_" + b, x), axis=1 + ) return compartment_df def parse_transitions(self, seir_config, fake_config=False): rc = reduce( - lambda a, b: pd.concat([a, self.parse_single_transition(seir_config, b, fake_config)]), + lambda a, b: pd.concat( + [a, self.parse_single_transition(seir_config, b, fake_config)] + ), seir_config["transitions"], pd.DataFrame(), ) rc = rc.reset_index(drop=True) return rc - def check_transition_element(self, single_transition_config, problem_dimension=None): + def check_transition_element( + self, single_transition_config, problem_dimension=None + ): return True def check_transition_elements(self, single_transition_config, problem_dimension): return True - def access_original_config_by_multi_index(self, config_piece, index, dimension=None, encapsulate_as_list=False): + def access_original_config_by_multi_index( + self, config_piece, index, dimension=None, encapsulate_as_list=False + ): if dimension is None: dimension = [None for i in index] tmp = [y for y in zip(index, range(len(index)), dimension)] tmp = zip(index, range(len(index)), dimension) - tmp = [list_access_element_safe(config_piece[x[1]], x[0], x[2], encapsulate_as_list) for x in tmp] + tmp = [ + list_access_element_safe( + config_piece[x[1]], x[0], x[2], encapsulate_as_list + ) + for x in tmp + ] return tmp def expand_transition_elements(self, single_transition_config, problem_dimension): - proportion_size = get_list_dimension(single_transition_config["proportional_to"]) + proportion_size = get_list_dimension( + single_transition_config["proportional_to"] + ) new_transition_config = single_transition_config.copy() # replace "source" by the actual source from the config for p_idx in range(proportion_size): if new_transition_config["proportional_to"][p_idx] == "source": - new_transition_config["proportional_to"][p_idx] = new_transition_config["source"] + new_transition_config["proportional_to"][p_idx] = new_transition_config[ + "source" + ] temp_array = np.zeros(problem_dimension) @@ -116,44 +139,80 @@ def expand_transition_elements(self, single_transition_config, problem_dimension new_transition_config["destination"] = np.zeros(problem_dimension, dtype=object) new_transition_config["rate"] = np.zeros(problem_dimension, dtype=object) - new_transition_config["proportional_to"] = np.zeros(problem_dimension, dtype=object) - new_transition_config["proportion_exponent"] = np.zeros(problem_dimension, dtype=object) + new_transition_config["proportional_to"] = np.zeros( + problem_dimension, dtype=object + ) + new_transition_config["proportion_exponent"] = np.zeros( + problem_dimension, dtype=object + ) - it = np.nditer(temp_array, flags=["multi_index"]) # it is an iterator that will go through all the indexes of the array + it = np.nditer( + temp_array, 
flags=["multi_index"] + ) # it is an iterator that will go through all the indexes of the array for x in it: try: - new_transition_config["source"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["source"], it.multi_index) + new_transition_config["source"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["source"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `source:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `source:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e try: - new_transition_config["destination"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["destination"], it.multi_index) + new_transition_config["destination"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["destination"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `destination:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `destination:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e - + try: - new_transition_config["rate"][it.multi_index] = list_recursive_convert_to_string( - self.access_original_config_by_multi_index(single_transition_config["rate"], it.multi_index) + new_transition_config["rate"][it.multi_index] = ( + list_recursive_convert_to_string( + self.access_original_config_by_multi_index( + single_transition_config["rate"], it.multi_index + ) + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `rate:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `rate:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e @@ -173,43 +232,68 @@ def expand_transition_elements(self, single_transition_config, problem_dimension ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition 
destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `proportional_to:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e - - if "proportion_exponent" in single_transition_config: # if proportion_exponent is not defined, it is set to 1 + + if ( + "proportion_exponent" in single_transition_config + ): # if proportion_exponent is not defined, it is set to 1 try: self.access_original_config_by_multi_index( single_transition_config["proportion_exponent"][0], it.multi_index, problem_dimension, ) - new_transition_config["proportion_exponent"][it.multi_index] = list_recursive_convert_to_string( - [ - self.access_original_config_by_multi_index( - single_transition_config["proportion_exponent"][p_idx], - it.multi_index, - problem_dimension, - ) - for p_idx in range(proportion_size) - ] + new_transition_config["proportion_exponent"][it.multi_index] = ( + list_recursive_convert_to_string( + [ + self.access_original_config_by_multi_index( + single_transition_config["proportion_exponent"][ + p_idx + ], + it.multi_index, + problem_dimension, + ) + for p_idx in range(proportion_size) + ] + ) ) except Exception as e: print(f"Error {e}:") - print(f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}") - print(f">>> this transition source is: {single_transition_config['source']}") - print(f">>> this transition destination is: {single_transition_config['destination']}") + print( + f">>> in expand_transition_elements for `proportion_exponent:` at index {it.multi_index}" + ) + print( + f">>> this transition source is: {single_transition_config['source']}" + ) + print( + f">>> this transition destination is: {single_transition_config['destination']}" + ) print(f"transition_dimension: {problem_dimension}") raise e else: - new_transition_config["proportion_exponent"][it.multi_index] = ["1"] * proportion_size + new_transition_config["proportion_exponent"][it.multi_index] = [ + "1" + ] * proportion_size return new_transition_config def format_source(self, source_column): - rc = [y for y in map(lambda x: reduce(lambda a, b: str(a) + "_" + str(b), x), source_column)] + rc = [ + y + for y in map( + lambda x: reduce(lambda a, b: str(a) + "_" + str(b), x), source_column + ) + ] return rc def unformat_source(self, source_column): @@ -231,7 +315,12 @@ def unformat_destination(self, destination_column): return rc def format_rate(self, rate_column): - rc = [y for y in map(lambda x: reduce(lambda a, b: str(a) + "%*%" + str(b), x), rate_column)] + rc = [ + y + for y in map( + lambda x: reduce(lambda a, b: str(a) + "%*%" + str(b), x), rate_column + ) + ] return rc def unformat_rate(self, rate_column, compartment_dimension): @@ -251,7 +340,9 @@ def format_proportional_to(self, proportional_to_column): lambda x: reduce( lambda a, b: str(a) + "_" + str(b), map( - lambda x: reduce(lambda a, b: str(a) + "+" + str(b), as_list(x)), + lambda x: reduce( + lambda a, b: str(a) + "+" + str(b), as_list(x) + ), x, ), ), @@ -284,27 +375,41 @@ def format_proportion_exponent(self, proportion_exponent_column): ] return rc - def unformat_proportion_exponent(self, proportion_exponent_column, compartment_dimension): + def unformat_proportion_exponent( + self, proportion_exponent_column, compartment_dimension + ): rc = [x.split("%*%") for 
x in proportion_exponent_column] for row in range(len(rc)): - rc[row] = [x.split("*", maxsplit=compartment_dimension - 1) for x in rc[row]] + rc[row] = [ + x.split("*", maxsplit=compartment_dimension - 1) for x in rc[row] + ] for elem in rc[row]: while len(elem) < compartment_dimension: elem.append(1) return rc - def parse_single_transition(self, seir_config, single_transition_config, fake_config=False): + def parse_single_transition( + self, seir_config, single_transition_config, fake_config=False + ): ## This method relies on having run parse_compartments if not fake_config: single_transition_config = single_transition_config.get() self.check_transition_element(single_transition_config["source"]) self.check_transition_element(single_transition_config["destination"]) - source_dimension = [get_list_dimension(x) for x in single_transition_config["source"]] - destination_dimension = [get_list_dimension(x) for x in single_transition_config["destination"]] - problem_dimension = reduce(lambda x, y: max(x, y), (source_dimension, destination_dimension)) + source_dimension = [ + get_list_dimension(x) for x in single_transition_config["source"] + ] + destination_dimension = [ + get_list_dimension(x) for x in single_transition_config["destination"] + ] + problem_dimension = reduce( + lambda x, y: max(x, y), (source_dimension, destination_dimension) + ) self.check_transition_elements(single_transition_config, problem_dimension) - transitions = self.expand_transition_elements(single_transition_config, problem_dimension) + transitions = self.expand_transition_elements( + single_transition_config, problem_dimension + ) tmp_array = np.zeros(problem_dimension) it = np.nditer(tmp_array, flags=["multi_index"]) @@ -316,8 +421,12 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co "source": [transitions["source"][it.multi_index]], "destination": [transitions["destination"][it.multi_index]], "rate": [transitions["rate"][it.multi_index]], - "proportional_to": [transitions["proportional_to"][it.multi_index]], - "proportion_exponent": [transitions["proportion_exponent"][it.multi_index]], + "proportional_to": [ + transitions["proportional_to"][it.multi_index] + ], + "proportion_exponent": [ + transitions["proportion_exponent"][it.multi_index] + ], }, index=[0], ) @@ -328,7 +437,10 @@ def parse_single_transition(self, seir_config, single_transition_config, fake_co return rc def toFile( - self, compartments_file="compartments.parquet", transitions_file="transitions.parquet", write_parquet=True + self, + compartments_file="compartments.parquet", + transitions_file="transitions.parquet", + write_parquet=True, ): out_df = self.compartments.copy() if write_parquet: @@ -341,8 +453,12 @@ def toFile( out_df["source"] = self.format_source(out_df["source"]) out_df["destination"] = self.format_destination(out_df["destination"]) out_df["rate"] = self.format_rate(out_df["rate"]) - out_df["proportional_to"] = self.format_proportional_to(out_df["proportional_to"]) - out_df["proportion_exponent"] = self.format_proportion_exponent(out_df["proportion_exponent"]) + out_df["proportional_to"] = self.format_proportional_to( + out_df["proportional_to"] + ) + out_df["proportion_exponent"] = self.format_proportion_exponent( + out_df["proportion_exponent"] + ) if write_parquet: pa_df = pa.Table.from_pandas(out_df, preserve_index=False) pa.parquet.write_table(pa_df, transitions_file) @@ -355,9 +471,15 @@ def fromFile(self, compartments_file, transitions_file): self.transitions = 
pq.read_table(transitions_file).to_pandas() compartment_dimension = self.compartments.shape[1] - 1 self.transitions["source"] = self.unformat_source(self.transitions["source"]) - self.transitions["destination"] = self.unformat_destination(self.transitions["destination"]) - self.transitions["rate"] = self.unformat_rate(self.transitions["rate"], compartment_dimension) - self.transitions["proportional_to"] = self.unformat_proportional_to(self.transitions["proportional_to"]) + self.transitions["destination"] = self.unformat_destination( + self.transitions["destination"] + ) + self.transitions["rate"] = self.unformat_rate( + self.transitions["rate"], compartment_dimension + ) + self.transitions["proportional_to"] = self.unformat_proportional_to( + self.transitions["proportional_to"] + ) self.transitions["proportion_exponent"] = self.unformat_proportion_exponent( self.transitions["proportion_exponent"], compartment_dimension ) @@ -371,7 +493,9 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i :param comp_dict: :return: """ - mask = pd.concat([self.compartments[k] == v for k, v in comp_dict.items()], axis=1).all(axis=1) + mask = pd.concat( + [self.compartments[k] == v for k, v in comp_dict.items()], axis=1 + ).all(axis=1) comp_idx = self.compartments[mask].index.values if len(comp_idx) != 1: raise ValueError( @@ -382,10 +506,11 @@ def get_comp_idx(self, comp_dict: dict, error_info: str = "no information") -> i def get_ncomp(self) -> int: return len(self.compartments) - def get_transition_array(self): with Timer("SEIR.compartments"): - transition_array = np.zeros((self.transitions.shape[1], self.transitions.shape[0]), dtype="int64") + transition_array = np.zeros( + (self.transitions.shape[1], self.transitions.shape[0]), dtype="int64" + ) for cit, colname in enumerate(("source", "destination")): for it, elem in enumerate(self.transitions[colname]): elem = reduce(lambda a, b: a + "_" + b, elem) @@ -395,7 +520,9 @@ def get_transition_array(self): rc = compartment if rc == -1: print(self.compartments) - raise ValueError(f"Could not find {colname} defined by {elem} in compartments") + raise ValueError( + f"Could not find {colname} defined by {elem} in compartments" + ) transition_array[cit, it] = rc unique_strings = [] @@ -417,8 +544,12 @@ def get_transition_array(self): # parenthesis are now supported # assert reduce(lambda a, b: a and b, [(x.find("(") == -1) for x in unique_strings]) # assert reduce(lambda a, b: a and b, [(x.find(")") == -1) for x in unique_strings]) - assert reduce(lambda a, b: a and b, [(x.find("%") == -1) for x in unique_strings]) - assert reduce(lambda a, b: a and b, [(x.find(" ") == -1) for x in unique_strings]) + assert reduce( + lambda a, b: a and b, [(x.find("%") == -1) for x in unique_strings] + ) + assert reduce( + lambda a, b: a and b, [(x.find(" ") == -1) for x in unique_strings] + ) for it, elem in enumerate(self.transitions["rate"]): candidate = reduce(lambda a, b: a + "*" + b, elem) @@ -454,8 +585,12 @@ def get_transition_array(self): # rc = compartment # if rc == -1: # raise ValueError(f"Could not find match for {elem3} in compartments") - proportion_info[0][current_proportion_sum_it] = current_proportion_sum_start - proportion_info[1][current_proportion_sum_it] = current_proportion_sum_start + len(elem_tmp) + proportion_info[0][ + current_proportion_sum_it + ] = current_proportion_sum_start + proportion_info[1][current_proportion_sum_it] = ( + current_proportion_sum_start + len(elem_tmp) + ) current_proportion_sum_it += 1 
current_proportion_sum_start += len(elem_tmp) proportion_compartment_index = 0 @@ -466,7 +601,9 @@ def get_transition_array(self): # candidate = candidate.replace("*1", "") if not candidate in unique_strings: raise ValueError("Something went wrong") - rc = [it for it, x in enumerate(unique_strings) if x == candidate][0] + rc = [it for it, x in enumerate(unique_strings) if x == candidate][ + 0 + ] proportion_info[2][proportion_compartment_index] = rc proportion_compartment_index += 1 @@ -490,7 +627,9 @@ def get_transition_array(self): if self.compartments["name"][compartment] == elem3: rc = compartment if rc == -1: - raise ValueError(f"Could not find proportional_to {elem3} in compartments") + raise ValueError( + f"Could not find proportional_to {elem3} in compartments" + ) proportion_array[proportion_index] = rc proportion_index += 1 @@ -528,21 +667,29 @@ def get_transition_array(self): def parse_parameters(self, parameters, parameter_names, unique_strings): # parsed_parameters_old = self.parse_parameter_strings_to_numpy_arrays(parameters, parameter_names, unique_strings) - parsed_parameters = self.parse_parameter_strings_to_numpy_arrays_v2(parameters, parameter_names, unique_strings) + parsed_parameters = self.parse_parameter_strings_to_numpy_arrays_v2( + parameters, parameter_names, unique_strings + ) # for i in range(len(unique_strings)): # print(unique_strings[i], (parsed_parameters[i]==parsed_parameters_old[i]).all()) return parsed_parameters - def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names, string_list): + def parse_parameter_strings_to_numpy_arrays_v2( + self, parameters, parameter_names, string_list + ): # is using eval a better way ??? import sympy as sp # Validate input lengths if len(parameters) != len(parameter_names): - raise ValueError("Number of parameter values does not match the number of parameter names.") + raise ValueError( + "Number of parameter values does not match the number of parameter names." + ) # Define the symbols used in the formulas - symbolic_parameters_namespace = {name: sp.symbols(name) for name in parameter_names} + symbolic_parameters_namespace = { + name: sp.symbols(name) for name in parameter_names + } symbolic_parameters = [sp.symbols(name) for name in parameter_names] @@ -554,12 +701,18 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names f = sp.sympify(formula, locals=symbolic_parameters_namespace) parsed_formulas.append(f) except Exception as e: - print(f"Cannot parse formula: '{formula}' from parameters {parameter_names}") + print( + f"Cannot parse formula: '{formula}' from parameters {parameter_names}" + ) raise (e) # Print the error message for debugging # the list order needs to be right. 
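# A minimal illustrative sketch of the sympify/lambdify pattern that
# parse_parameter_strings_to_numpy_arrays_v2 relies on: formula strings are
# parsed once, then evaluated in a single lambdified call against numpy
# parameter arrays. The names, shapes, and formulas below are hypothetical,
# not values taken from gempyor.
import numpy as np
import sympy as sp

names = ["beta", "gamma"]
values = [np.full((5, 3), 0.4), np.full((5, 3), 0.2)]  # e.g. (n_days, n_subpops)
namespace = {name: sp.symbols(name) for name in names}
symbols = [namespace[name] for name in names]
formulas = [sp.sympify(s, locals=namespace) for s in ["beta * gamma", "1 / gamma"]]
evaluate = sp.lambdify(symbols, formulas)
beta_gamma, inv_gamma = evaluate(*values)  # one numpy array per formula, same shape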
- parameter_values = {param: value for param, value in zip(symbolic_parameters, parameters)} - parameter_values_list = [parameter_values[param] for param in symbolic_parameters] + parameter_values = { + param: value for param, value in zip(symbolic_parameters, parameters) + } + parameter_values_list = [ + parameter_values[param] for param in symbolic_parameters + ] # Create a lambdify function for substitution substitution_function = sp.lambdify(symbolic_parameters, parsed_formulas) @@ -573,7 +726,9 @@ def parse_parameter_strings_to_numpy_arrays_v2(self, parameters, parameter_names if not isinstance(substituted_formulas[i], np.ndarray): for k in range(len(substituted_formulas)): if isinstance(substituted_formulas[k], np.ndarray): - substituted_formulas[i] = substituted_formulas[i] * np.ones_like(substituted_formulas[k]) + substituted_formulas[i] = substituted_formulas[ + i + ] * np.ones_like(substituted_formulas[k]) return np.array(substituted_formulas) @@ -621,19 +776,29 @@ def parse_parameter_strings_to_numpy_arrays( is_resolvable = [x[0] or x[1] for x in zip(is_numeric, is_parameter)] is_totally_resolvable = reduce(lambda a, b: a and b, is_resolvable) if not is_totally_resolvable: - not_resolvable_indices = [it for it, x in enumerate(is_resolvable) if not x] - - tmp_rc[not_resolvable_indices] = self.parse_parameter_strings_to_numpy_arrays( - parameters, - parameter_names, - [string[not is_resolvable]], - operator_reduce_lambdas, - operators[1:], + not_resolvable_indices = [ + it for it, x in enumerate(is_resolvable) if not x + ] + + tmp_rc[not_resolvable_indices] = ( + self.parse_parameter_strings_to_numpy_arrays( + parameters, + parameter_names, + [string[not is_resolvable]], + operator_reduce_lambdas, + operators[1:], + ) ) for numeric_index in [x for x in range(len(is_numeric)) if is_numeric[x]]: tmp_rc[numeric_index] = parameters[0] * 0 + float(string[numeric_index]) - for parameter_index in [x for x in range(len(is_parameter)) if is_parameter[x]]: - parameter_name_index = [it for it, x in enumerate(parameter_names) if x == string[parameter_index]] + for parameter_index in [ + x for x in range(len(is_parameter)) if is_parameter[x] + ]: + parameter_name_index = [ + it + for it, x in enumerate(parameter_names) + if x == string[parameter_index] + ] tmp_rc[parameter_index] = parameters[parameter_name_index] rc[sit] = reduce(operator_reduce_lambdas[operators[0]], tmp_rc) @@ -648,7 +813,9 @@ def get_compartments_explicitDF(self): df = df.rename(columns=rename_dict) return df - def plot(self, output_file="transition_graph", source_filters=[], destination_filters=[]): + def plot( + self, output_file="transition_graph", source_filters=[], destination_filters=[] + ): """ if source_filters is [["age0to17"], ["OMICRON", "WILD"]], it means filter all transitions that have as source age0to17 AND (OMICRON OR WILD). @@ -712,8 +879,12 @@ def list_access_element_safe(thing, idx, dimension=None, encapsulate_as_list=Fal except Exception as e: print(f"Error {e}:") print(f">>> in list_access_element_safe for {thing} at index {idx}") - print(">>> This is often, but not always because the object above is a list (there are brackets around it).") - print(">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it.") + print( + ">>> This is often, but not always because the object above is a list (there are brackets around it)." 
+ ) + print( + ">>> and in this case it is not broadcast, so if you want to it to be broadcasted, you need remove the brackets around it." + ) print(f"dimension: {dimension}") raise e @@ -755,7 +926,9 @@ def compartments(): def plot(): assert config["compartments"].exists() assert config["seir"].exists() - comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + comp = Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) # TODO: this should be a command like build compartments. ( @@ -774,7 +947,9 @@ def plot(): def export(): assert config["compartments"].exists() assert config["seir"].exists() - comp = Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + comp = Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) ( unique_strings, transition_array, diff --git a/flepimop/gempyor_pkg/src/gempyor/config_validator.py b/flepimop/gempyor_pkg/src/gempyor/config_validator.py index 95f6e3029..40f9423ff 100644 --- a/flepimop/gempyor_pkg/src/gempyor/config_validator.py +++ b/flepimop/gempyor_pkg/src/gempyor/config_validator.py @@ -1,33 +1,60 @@ import yaml -from pydantic import BaseModel, ValidationError, model_validator, Field, AfterValidator, validator +from pydantic import ( + BaseModel, + ValidationError, + model_validator, + Field, + AfterValidator, + validator, +) from datetime import date from typing import Dict, List, Union, Literal, Optional, Annotated, Any from functools import partial from gempyor import compartments + def read_yaml(file_path: str) -> dict: - with open(file_path, 'r') as stream: + with open(file_path, "r") as stream: config = yaml.safe_load(stream) - + return CheckConfig(**config).model_dump() - + + def allowed_values(v, values): assert v in values return v + # def parse_value(cls, values): # value = values.get('value') # parsed_val = compartments.Compartments.parse_parameter_strings_to_numpy_arrays_v2(value) # return parsed_val - + + class SubpopSetupConfig(BaseModel): geodata: str mobility: Optional[str] selected: List[str] = Field(default_factory=list) # state_level: Optional[bool] = False # pretty sure this doesn't exist anymore + class InitialConditionsConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['Default', 'SetInitialConditions', 'SetInitialConditionsFolderDraw', 'InitialConditionsFolderDraw', 'FromFile', 'plugin']))] = 'Default' + method: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=[ + "Default", + "SetInitialConditions", + "SetInitialConditionsFolderDraw", + "InitialConditionsFolderDraw", + "FromFile", + "plugin", + ], + ) + ), + ] = "Default" initial_file_type: Optional[str] = None initial_conditions_file: Optional[str] = None proportional: Optional[bool] = None @@ -36,105 +63,163 @@ class InitialConditionsConfig(BaseModel): ignore_population_checks: Optional[bool] = None plugin_file_path: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def validate_initial_file_check(cls, values): - method = values.get('method') - initial_conditions_file = values.get('initial_conditions_file') - initial_file_type = values.get('initial_file_type') - if method in {'FromFile', 'SetInitialConditions'} and not initial_conditions_file: - raise ValueError(f'Error in InitialConditions: An initial_conditions_file is required when method is {method}') - if method in {'InitialConditionsFolderDraw','SetInitialConditionsFolderDraw'} 
and not initial_file_type: - raise ValueError(f'Error in InitialConditions: initial_file_type is required when method is {method}') + method = values.get("method") + initial_conditions_file = values.get("initial_conditions_file") + initial_file_type = values.get("initial_file_type") + if ( + method in {"FromFile", "SetInitialConditions"} + and not initial_conditions_file + ): + raise ValueError( + f"Error in InitialConditions: An initial_conditions_file is required when method is {method}" + ) + if ( + method in {"InitialConditionsFolderDraw", "SetInitialConditionsFolderDraw"} + and not initial_file_type + ): + raise ValueError( + f"Error in InitialConditions: initial_file_type is required when method is {method}" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def plugin_filecheck(cls, values): - method = values.get('method') - plugin_file_path = values.get('plugin_file_path') - if method == 'plugin' and not plugin_file_path: - raise ValueError('Error in InitialConditions: a plugin file path is required when method is plugin') + method = values.get("method") + plugin_file_path = values.get("plugin_file_path") + if method == "plugin" and not plugin_file_path: + raise ValueError( + "Error in InitialConditions: a plugin file path is required when method is plugin" + ) return values class SeedingConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['NoSeeding', 'PoissonDistributed', 'FolderDraw', 'FromFile', 'plugin']))] = 'NoSeeding' # note: removed NegativeBinomialDistributed because no longer supported + method: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=[ + "NoSeeding", + "PoissonDistributed", + "FolderDraw", + "FromFile", + "plugin", + ], + ) + ), + ] = "NoSeeding" # note: removed NegativeBinomialDistributed because no longer supported lambda_file: Optional[str] = None seeding_file_type: Optional[str] = None seeding_file: Optional[str] = None plugin_file_path: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def validate_seedingfile(cls, values): - method = values.get('method') - lambda_file = values.get('lambda_file') - seeding_file_type = values.get('seeding_file_type') - seeding_file = values.get('seeding_file') - if method == 'PoissonDistributed' and not lambda_file: - raise ValueError(f'Error in Seeding: A lambda_file is required when method is {method}') - if method == 'FolderDraw' and not seeding_file_type: - raise ValueError('Error in Seeding: A seeding_file_type is required when method is FolderDraw') - if method == 'FromFile' and not seeding_file: - raise ValueError('Error in Seeding: A seeding_file is required when method is FromFile') + method = values.get("method") + lambda_file = values.get("lambda_file") + seeding_file_type = values.get("seeding_file_type") + seeding_file = values.get("seeding_file") + if method == "PoissonDistributed" and not lambda_file: + raise ValueError( + f"Error in Seeding: A lambda_file is required when method is {method}" + ) + if method == "FolderDraw" and not seeding_file_type: + raise ValueError( + "Error in Seeding: A seeding_file_type is required when method is FolderDraw" + ) + if method == "FromFile" and not seeding_file: + raise ValueError( + "Error in Seeding: A seeding_file is required when method is FromFile" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def plugin_filecheck(cls, values): - method = values.get('method') - plugin_file_path = 
values.get('plugin_file_path') - if method == 'plugin' and not plugin_file_path: - raise ValueError('Error in Seeding: a plugin file path is required when method is plugin') + method = values.get("method") + plugin_file_path = values.get("plugin_file_path") + if method == "plugin" and not plugin_file_path: + raise ValueError( + "Error in Seeding: a plugin file path is required when method is plugin" + ) return values - + + class IntegrationConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['rk4', 'rk4.jit', 'best.current', 'legacy']))] = 'rk4' + method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["rk4", "rk4.jit", "best.current", "legacy"]) + ), + ] = "rk4" dt: float = 2.0 + class ValueConfig(BaseModel): - distribution: str = 'fixed' - value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS + distribution: str = "fixed" + value: Optional[float] = None # NEED TO ADD ABILITY TO PARSE PARAMETERS mean: Optional[float] = None sd: Optional[float] = None a: Optional[float] = None b: Optional[float] = None - @model_validator(mode='before') + @model_validator(mode="before") def check_distr(cls, values): - distr = values.get('distribution') - value = values.get('value') - mean = values.get('mean') - sd = values.get('sd') - a = values.get('a') - b = values.get('b') - if distr != 'fixed': + distr = values.get("distribution") + value = values.get("value") + mean = values.get("mean") + sd = values.get("sd") + a = values.get("a") + b = values.get("b") + if distr != "fixed": if not mean and not sd: - raise ValueError('Error in value: mean and sd must be provided for non-fixed distributions') - if distr == 'truncnorm' and not a and not b: - raise ValueError('Error in value: a and b must be provided for truncated normal distributions') - if distr == 'fixed' and not value: - raise ValueError('Error in value: value must be provided for fixed distributions') + raise ValueError( + "Error in value: mean and sd must be provided for non-fixed distributions" + ) + if distr == "truncnorm" and not a and not b: + raise ValueError( + "Error in value: a and b must be provided for truncated normal distributions" + ) + if distr == "fixed" and not value: + raise ValueError( + "Error in value: value must be provided for fixed distributions" + ) return values + class BaseParameterConfig(BaseModel): value: Optional[ValueConfig] = None modifier_parameter: Optional[str] = None - name: Optional[str] = None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) + name: Optional[str] = ( + None # this is only for outcomes, to build outcome_prevalence_name (how to restrict this?) 
+ ) + class SeirParameterConfig(BaseParameterConfig): value: Optional[ValueConfig] = None - stacked_modifier_method: Annotated[str, AfterValidator(partial(allowed_values, values=['sum', 'product', 'reduction_product']))] = None + stacked_modifier_method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["sum", "product", "reduction_product"]) + ), + ] = None rolling_mean_windows: Optional[float] = None timeseries: Optional[str] = None - @model_validator(mode='before') + @model_validator(mode="before") def which_value(cls, values): - value = values.get('value') is not None - timeseries = values.get('timeseries') is not None + value = values.get("value") is not None + timeseries = values.get("timeseries") is not None if value and timeseries: - raise ValueError('Error in seir::parameters: your parameter is both a timeseries and a value, please choose one') + raise ValueError( + "Error in seir::parameters: your parameter is both a timeseries and a value, please choose one" + ) return values - - -class TransitionConfig(BaseModel): + + +class TransitionConfig(BaseModel): # !! sometimes these are lists of lists and sometimes they are lists... how to deal with this? source: List[List[str]] destination: List[List[str]] @@ -142,11 +227,15 @@ class TransitionConfig(BaseModel): proportion_exponent: List[List[str]] proportional_to: List[str] + class SeirConfig(BaseModel): - integration: IntegrationConfig # is this Optional? - parameters: Dict[str, SeirParameterConfig] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? + integration: IntegrationConfig # is this Optional? + parameters: Dict[ + str, SeirParameterConfig + ] # there was a previous issue that gempyor doesn't work if there are no parameters (eg if just numbers are used in the transitions) - do we want to get around this? transitions: List[TransitionConfig] + class SinglePeriodModifierConfig(BaseModel): method: Literal["SinglePeriodModifier"] parameter: str @@ -157,15 +246,18 @@ class SinglePeriodModifierConfig(BaseModel): value: ValueConfig perturbation: Optional[ValueConfig] = None + class MultiPeriodDatesConfig(BaseModel): start_date: date end_date: date - + + class MultiPeriodGroupsConfig(BaseModel): subpop: List[str] subpop_groups: Optional[str] = None periods: List[MultiPeriodDatesConfig] + class MultiPeriodModifierConfig(BaseModel): method: Literal["MultiPeriodModifier"] parameter: str @@ -173,37 +265,47 @@ class MultiPeriodModifierConfig(BaseModel): value: ValueConfig perturbation: Optional[ValueConfig] = None + class StackedModifierConfig(BaseModel): method: Literal["StackedModifier"] modifiers: List[str] + class ModifiersConfig(BaseModel): scenarios: List[str] modifiers: Dict[str, Any] - + @field_validator("modifiers") def validate_data_dict(cls, value: Dict[str, Any]) -> Dict[str, Any]: errors = [] for key, entry in value.items(): method = entry.get("method") - if method not in {"SinglePeriodModifier", "MultiPeriodModifier", "StackedModifier"}: + if method not in { + "SinglePeriodModifier", + "MultiPeriodModifier", + "StackedModifier", + }: errors.append(f"Invalid modifier method: {method}") if errors: raise ValueError("Errors in modifiers:\n" + "\n".join(errors)) return value -class SourceConfig(BaseModel): # set up only for incidence or prevalence. Can this be any name? i don't think so atm +class SourceConfig( + BaseModel +): # set up only for incidence or prevalence. Can this be any name? 
i don't think so atm incidence: Dict[str, str] = None - prevalence: Dict[str, str] = None + prevalence: Dict[str, str] = None # note: these dictionaries have to have compartment names... more complicated to set this up - @model_validator(mode='before') + @model_validator(mode="before") def which_source(cls, values): - incidence = values.get('incidence') - prevalence = values.get('prevalence') + incidence = values.get("incidence") + prevalence = values.get("prevalence") if incidence and prevalence: - raise ValueError('Error in outcomes::source. Can only be incidence or prevalence, not both.') + raise ValueError( + "Error in outcomes::source. Can only be incidence or prevalence, not both." + ) return values # @model_validator(mode='before') # DOES NOT WORK @@ -214,12 +316,13 @@ def which_source(cls, values): # source_names.append(key) # return source_names # Access keys using a loop + class DelayFrameConfig(BaseModel): source: Optional[SourceConfig] = None probability: Optional[BaseParameterConfig] = None delay: Optional[BaseParameterConfig] = None duration: Optional[BaseParameterConfig] = None - sum: Optional[List[str]] = None # only for sums of other outcomes + sum: Optional[List[str]] = None # only for sums of other outcomes # @validator("sum") # def validate_sum_elements(cls, value: Optional[List[str]]) -> Optional[List[str]]: @@ -233,65 +336,96 @@ class DelayFrameConfig(BaseModel): # return value # note: ^^ this doesn't work yet because it needs to somehow be a level above? to access all OTHER source names - @model_validator(mode='before') + @model_validator(mode="before") def check_outcome_type(cls, values): - sum_present = values.get('sum') is not None - source_present = values.get('source') is not None + sum_present = values.get("sum") is not None + source_present = values.get("source") is not None if sum_present and source_present: - raise ValueError(f"Error in outcome: Both 'sum' and 'source' are present. Choose one.") + raise ValueError( + f"Error in outcome: Both 'sum' and 'source' are present. Choose one." + ) elif not sum_present and not source_present: - raise ValueError(f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one.") + raise ValueError( + f"Error in outcome: Neither 'sum' nor 'source' is present. Choose one." + ) return values + class OutcomesConfig(BaseModel): - method: Literal["delayframe"] # Is this required? I don't see it anywhere in the gempyor code + method: Literal[ + "delayframe" + ] # Is this required? I don't see it anywhere in the gempyor code param_from_file: Optional[bool] = None param_subpop_file: Optional[str] = None outcomes: Dict[str, DelayFrameConfig] - - @model_validator(mode='before') + + @model_validator(mode="before") def check_paramfromfile_type(cls, values): - param_from_file = values.get('param_from_file') is not None - param_subpop_file = values.get('param_subpop_file') is not None + param_from_file = values.get("param_from_file") is not None + param_subpop_file = values.get("param_subpop_file") is not None if param_from_file and not param_subpop_file: - raise ValueError(f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True") + raise ValueError( + f"Error in outcome: 'param_subpop_file' is required when 'param_from_file' is True" + ) return values + class ResampleConfig(BaseModel): aggregator: Optional[str] = None freq: Optional[str] = None skipna: Optional[bool] = False + class LikelihoodParams(BaseModel): scale: float # are there other options here? 
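# An illustrative sketch of the Annotated + AfterValidator(partial(allowed_values, ...))
# idiom used throughout this module, shown on a hypothetical standalone model
# (DistModel is not part of the config schema).
from functools import partial
from typing import Annotated

from pydantic import AfterValidator, BaseModel, ValidationError


def allowed_values(v, values):
    assert v in values
    return v


class DistModel(BaseModel):
    dist: Annotated[
        str, AfterValidator(partial(allowed_values, values=["pois", "norm"]))
    ] = "pois"


DistModel(dist="norm")  # accepted
try:
    DistModel(dist="gamma")  # rejected: not in the allowed list
except ValidationError as err:
    print(err)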
+ class LikelihoodReg(BaseModel): - name: str + name: str + class LikelihoodConfig(BaseModel): - dist: Annotated[str, AfterValidator(partial(allowed_values, values=['pois', 'norm', 'norm_cov', 'nbinom', 'rmse', 'absolute_error']))] = None + dist: Annotated[ + str, + AfterValidator( + partial( + allowed_values, + values=["pois", "norm", "norm_cov", "nbinom", "rmse", "absolute_error"], + ) + ), + ] = None params: Optional[LikelihoodParams] = None + class StatisticsConfig(BaseModel): name: str sim_var: str data_var: str regularize: Optional[LikelihoodReg] = None resample: Optional[ResampleConfig] = None - scale: Optional[float] = None # is scale here or at likelihood level? - zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? + scale: Optional[float] = None # is scale here or at likelihood level? + zero_to_one: Optional[bool] = False # is this the same as add_one? remove_na? likelihood: LikelihoodConfig + class InferenceConfig(BaseModel): - method: Annotated[str, AfterValidator(partial(allowed_values, values=['emcee', 'default', 'classical']))] = 'default' # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options - iterations_per_slot: Optional[int] # i think this is optional because it is also set in command line?? - do_inference: bool + method: Annotated[ + str, + AfterValidator( + partial(allowed_values, values=["emcee", "default", "classical"]) + ), + ] = "default" # for now - i can only see emcee as an option here, otherwise ignored in classical - need to add these options + iterations_per_slot: Optional[ + int + ] # i think this is optional because it is also set in command line?? + do_inference: bool gt_data_path: str statistics: Dict[str, StatisticsConfig] + class CheckConfig(BaseModel): name: str setup_name: Optional[str] = None @@ -312,32 +446,36 @@ class CheckConfig(BaseModel): outcome_modifiers: Optional[ModifiersConfig] = None inference: Optional[InferenceConfig] = None -# add validator for if modifiers exist but seir/outcomes do not - -# there is an error in the one below - @model_validator(mode='before') + # add validator for if modifiers exist but seir/outcomes do not + + # there is an error in the one below + @model_validator(mode="before") def verify_inference(cls, values): - inference_present = values.get('inference') is not None - start_date_groundtruth = values.get('start_date_groundtruth') is not None + inference_present = values.get("inference") is not None + start_date_groundtruth = values.get("start_date_groundtruth") is not None if inference_present and not start_date_groundtruth: - raise ValueError('Inference mode is enabled but no groundtruth dates are provided') + raise ValueError( + "Inference mode is enabled but no groundtruth dates are provided" + ) elif start_date_groundtruth and not inference_present: - raise ValueError('Groundtruth dates are provided but inference mode is not enabled') + raise ValueError( + "Groundtruth dates are provided but inference mode is not enabled" + ) return values - - @model_validator(mode='before') + + @model_validator(mode="before") def check_dates(cls, values): - start_date = values.get('start_date') - end_date = values.get('end_date') + start_date = values.get("start_date") + end_date = values.get("end_date") if start_date and end_date: if end_date <= start_date: - raise ValueError('end_date must be greater than start_date') + raise ValueError("end_date must be greater than start_date") return values - - @model_validator(mode='before') + + 
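# An illustrative sketch of the mode="before" cross-field pattern used by the
# validators in CheckConfig, on a hypothetical standalone DateWindow model that
# mirrors check_dates: the validator receives the raw input mapping before
# field coercion, so related fields can be compared directly.
from datetime import date

from pydantic import BaseModel, ValidationError, model_validator


class DateWindow(BaseModel):
    start_date: date
    end_date: date

    @model_validator(mode="before")
    def check_dates(cls, values):
        start, end = values.get("start_date"), values.get("end_date")
        if start and end and end <= start:
            raise ValueError("end_date must be greater than start_date")
        return values


DateWindow(start_date=date(2020, 1, 1), end_date=date(2020, 6, 1))  # accepted
try:
    DateWindow(start_date=date(2020, 6, 1), end_date=date(2020, 1, 1))  # rejected
except ValidationError as err:
    print(err)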
@model_validator(mode="before") def init_or_seed(cls, values): - init = values.get('initial_conditions') - seed = values.get('seeding') + init = values.get("initial_conditions") + seed = values.get("seeding") if not init or seed: - raise ValueError('either initial_conditions or seeding must be provided') + raise ValueError("either initial_conditions or seeding must be provided") return values diff --git a/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py b/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py index a847a5f90..f9ecba633 100644 --- a/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/dev/dev_seir.py @@ -36,7 +36,9 @@ ) seeding_data = modinf.seeding.get_from_config(sim_id=100, modinf=modinf) -initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) +initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf +) mobility_subpop_indices = modinf.mobility.indices mobility_data_indices = modinf.mobility.indptr @@ -48,7 +50,9 @@ modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -81,7 +85,12 @@ True, ) df = seir.states2Df(modinf, states) -assert df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] > 1 +assert ( + df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[ + str(modinf.tf), "20002" + ] + > 1 +) print(df) ts = df cp = "R" diff --git a/flepimop/gempyor_pkg/src/gempyor/dev/steps.py b/flepimop/gempyor_pkg/src/gempyor/dev/steps.py index 43066e5ee..0213fb912 100644 --- a/flepimop/gempyor_pkg/src/gempyor/dev/steps.py +++ b/flepimop/gempyor_pkg/src/gempyor/dev/steps.py @@ -53,7 +53,11 @@ def ode_integration( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -65,7 +69,9 @@ def ode_integration( def rhs(t, x, today): print("rhs.t", t) states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -80,52 +86,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -142,9 +168,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -168,18 +200,24 @@ def rhs(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -257,7 +295,11 @@ def rk4_integration1( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -267,7 +309,9 @@ def rk4_integration1( def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -282,52 +326,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -344,9 +408,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -380,18 +450,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -447,7 +523,11 @@ def rk4_integration2( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -458,7 +538,9 @@ def rk4_integration2( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -473,52 +555,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -535,9 +637,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -572,18 +680,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -644,7 +758,11 @@ def rk4_integration3( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -655,7 +773,9 @@ def rk4_integration3( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -670,52 +790,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + 
proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -732,9 +872,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
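Editorial aside (not part of the patch): the rk4_integrate helpers that recur in this file are classical fourth-order Runge-Kutta steps over the rhs defined above. The textbook form of one such step is sketched below; this is a generic illustration, not the module's exact implementation, which also threads a `today` index through rhs:

import numpy as np

def rk4_step(rhs, t, x, dt):
    # One classical RK4 step: four slope evaluations, weighted 1-2-2-1.
    k1 = rhs(t, x)
    k2 = rhs(t + dt / 2.0, x + dt / 2.0 * k1)
    k3 = rhs(t + dt / 2.0, x + dt / 2.0 * k2)
    k4 = rhs(t + dt, x + dt * k3)
    return x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

# Example on dx/dt = -x, whose exact solution after dt = 0.1 is exp(-0.1) ~ 0.9048.
x = np.array([1.0])
print(rk4_step(lambda t, x: -x, 0.0, x, 0.1))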
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -753,16 +899,18 @@ def rk4_integrate(t, x, today): @jit(nopython=True) def day_wrapper_rk4(today, states_next): x_ = np.zeros((2, ncompartments, nspatial_nodes)) - for seeding_instance_idx in range(day_start_idx_dict[today], day_start_idx_dict[today + 1]): + for seeding_instance_idx in range( + day_start_idx_dict[today], day_start_idx_dict[today + 1] + ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_subpops_dict[seeding_instance_idx] seeding_sources = seeding_sources_dict[seeding_instance_idx] seeding_destinations = seeding_destinations_dict[seeding_instance_idx] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts # ADD TO cumulative, this is debatable, @@ -838,7 +986,11 @@ def rk4_integration4( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -849,7 +1001,9 @@ def rk4_integration4( @jit(nopython=True) # , fastmath=True, parallel=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -864,52 +1018,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - 
transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -926,9 +1100,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
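Editorial aside (not part of the patch): the mobility_data / mobility_data_indices pair sliced throughout these hunks is a CSR-style (compressed sparse row) layout, where the outgoing flows of spatial node i live in mobility_data[mobility_data_indices[i] : mobility_data_indices[i + 1]]. A small self-contained sketch of the percent_who_move computation under that assumption, with toy numbers:

import numpy as np

# Toy CSR layout for 3 nodes: node 0 has outgoing flows [10, 5],
# node 1 has [6], node 2 has no outgoing mobility.
mobility_data = np.array([10.0, 5.0, 6.0])
mobility_data_indices = np.array([0, 2, 3, 3])
population = np.array([100.0, 4.0, 50.0])

percent_who_move = np.zeros(len(population))
for node in range(len(population)):
    row = mobility_data[mobility_data_indices[node] : mobility_data_indices[node + 1]]
    # Cap at 1 so a node can never send out more than its whole population.
    percent_who_move[node] = min(row.sum() / population[node], 1.0)

print(percent_who_move)  # [0.15 1. 0.] - node 1 is capped at 1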
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -963,18 +1143,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1031,7 +1217,11 @@ def rk4_integration5( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1058,18 +1248,24 @@ def rk4_integration5( this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1092,7 +1288,9 @@ def rk4_integration5( x = x_ + kx[i - 1] * rk_coefs[i] states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1104,48 +1302,72 @@ def rk4_integration5( relevant_number_in_comp = np.zeros((nspatial_nodes)) relevant_exponent = np.ones((nspatial_nodes)) for proportion_sum_index in range( - proportion_info[proportion_sum_starts_col][proportion_index], + proportion_info[proportion_sum_starts_col][ + proportion_index + ], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] - # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][ - today + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] ] + # exponents should not be a proportion, since we don't sum them over sum compartments + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][ + proportion_index + ] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 - ) == transitions[transition_proportion_stop_col][transition_index] + transitions[transition_proportion_start_col][ + transition_index + ] + + 1 + ) == transitions[transition_proportion_stop_col][ + transition_index + ] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 + - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + 
mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][ - spatial_node - ] + * parameters[ + transitions[transition_rate_col][ + transition_index + ] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] rate_change_compartment = proportion_change_compartment @@ -1153,11 +1375,16 @@ def rk4_integration5( relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] ) - rate_change_compartment /= population[visiting_compartment] + rate_change_compartment /= population[ + visiting_compartment + ] rate_change_compartment *= parameters[ transitions[transition_rate_col][transition_index] ][today][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + total_rate[spatial_node] *= ( + rate_keep_compartment + + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1169,13 +1396,19 @@ def rk4_integration5( # ) # else: if True: - number_move = source_number * total_rate # * compound_adjusted_rate + number_move = ( + source_number * total_rate + ) # * compound_adjusted_rate # for spatial_node in range(nspatial_nodes): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
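Editorial aside (not part of the patch): the rate_keep_compartment / rate_change_compartment arithmetic repeated in each rhs splits a node's transition rate between the fraction of the day spent at home and the fraction spent visiting neighbouring nodes. A simplified standalone sketch of that weighting, assuming a single transition, an exponent of 1, and made-up parameter values:

import numpy as np

def mixed_rate(node, percent_day_away, percent_who_move, mobility_row,
               visited_nodes, infected, population, beta):
    # Effective shares of the day spent at home versus away.
    keep = 1.0 - percent_day_away * percent_who_move[node]
    change = percent_day_away * mobility_row / population[node]
    # Force of infection experienced at home ...
    rate_home = keep * infected[node] / population[node] * beta[node]
    # ... plus what is picked up in each visited node, weighted by the
    # share of the day spent there.
    rate_away = change * infected[visited_nodes] / population[visited_nodes] * beta[visited_nodes]
    return rate_home + rate_away.sum()

infected = np.array([10.0, 50.0])
population = np.array([100.0, 200.0])
beta = np.array([0.3, 0.3])
# Node 0 sends 20 travellers to node 1 and spends half the day away.
print(mixed_rate(0, 0.5, np.array([0.2, 0.0]), np.array([20.0]),
                 np.array([1]), infected, population, beta))  # ~0.0345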
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move states_diff[ 1, transitions[transition_destination_col][transition_index], : ] += number_move # Cumumlative @@ -1234,7 +1467,11 @@ def rk4_integration2_smart( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1248,7 +1485,9 @@ def rhs(t, x): if (today) > ndays: today = ndays - 1 states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1263,52 +1502,72 @@ def rhs(t, x): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * 
parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1325,9 +1584,15 @@ def rhs(t, x): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? # TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return np.reshape(states_diff, states_diff.size) # return a 1D vector @@ -1372,20 +1637,34 @@ def rk4_integrate(today, x): seeding_data["day_start_idx"][today + 1], ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] - seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] - seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_subpops = seeding_data["seeding_subpops"][ + seeding_instance_idx + ] + seeding_sources = seeding_data["seeding_sources"][ + seeding_instance_idx + ] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] - states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( + states_next[seeding_sources][ + seeding_subpops + ] -= this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * ( states_next[seeding_sources][seeding_subpops] > 0 ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape @@ -1450,9 +1729,12 @@ def rk4_integrate(today, x): ## Return "UniTuple(float64[:, :, :], 2) (" ## return states and cumlative states, both [ ndays x ncompartments x nspatial_nodes ] ## Dimensions - "int32," "int32," "int32," ## ncompartments ## nspatial_nodes ## Number of days + "int32," + "int32," + "int32," ## ncompartments ## nspatial_nodes ## Number of days ## Parameters - "float64[:, :, :]," "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt + "float64[:, :, :]," + "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt ## Transitions "int64[:, :]," ## transitions [ [source, destination, proportion_start, proportion_stop, rate] x ntransitions ] "int64[:, :]," ## proportions_info [ [sum_starts, sum_stops, exponent] x ntransition_proportions ] @@ -1504,7 +1786,11 @@ def rk4_integration_aot( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -1515,7 +1801,9 @@ def rk4_integration_aot( def rhs(t, x, today): # states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] states_current = x[0] - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -1530,52 +1818,72 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if 
source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -1592,9 +1900,15 @@ def rhs(t, x, today): # if number_move[spatial_node] > states_current[transitions[transition_source_col][transition_index]][spatial_node]: # number_move[spatial_node] = states_current[transitions[transition_source_col][transition_index]][spatial_node] # Not possible to enforce this anymore, but it shouldn't be a problem or maybe ? 
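Editorial aside (not part of the patch): the seeding blocks that recur in every integrator subtract the seeded amount from the source compartment, clamp that compartment at zero, and credit the destination compartment plus the day's incidence. A toy sketch of that bookkeeping with plain NumPy arrays and hypothetical names, for a single seeding event:

import numpy as np

# states[compartment, subpop]; two compartments (S, E) and two subpops.
states = np.array([[5.0, 100.0],
                   [0.0, 0.0]])
daily_incid = np.zeros_like(states)

amount, source, destination, subpop = 10.0, 0, 1, 0

# Remove the seeded individuals from the source compartment ...
states[source, subpop] -= amount
# ... clamp at zero so seeding can never drive a compartment negative
# (the diffed code does this with a boolean mask, which is equivalent) ...
states[source, subpop] = max(states[source, subpop], 0.0)
# ... and add them to the destination compartment and the daily incidence.
states[destination, subpop] += amount
daily_incid[destination, subpop] += amount

print(states)       # [[  0. 100.] [ 10.   0.]]
print(daily_incid)  # [[ 0.  0.] [10.  0.]]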
# TODO - states_diff[0, transitions[transition_source_col][transition_index]] -= number_move - states_diff[0, transitions[transition_destination_col][transition_index]] += number_move - states_diff[1, transitions[transition_destination_col][transition_index], :] += number_move # Cumumlative + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= number_move + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += number_move + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += number_move # Cumumlative # states_current = states_next.copy() return states_diff # return a 1D vector @@ -1628,18 +1942,24 @@ def rk4_integrate(t, x, today): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts ### Shape diff --git a/flepimop/gempyor_pkg/src/gempyor/inference.py b/flepimop/gempyor_pkg/src/gempyor/inference.py index 81b1c22d0..5cc6bb79c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference.py @@ -82,11 +82,21 @@ def simulation_atomic( np.random.seed(int.from_bytes(os.urandom(4), byteorder="little")) random_id = np.random.randint(0, 1e8) - npi_seir = seir.build_npi_SEIR(modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=snpi_df_in) + npi_seir = seir.build_npi_SEIR( + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + bypass_DF=snpi_df_in, + ) if modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( - modinf=modinf, load_ID=False, sim_id2load=None, config=config, bypass_DF=hnpi_df_in + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + bypass_DF=hnpi_df_in, ) else: npi_outcomes = None @@ -94,10 +104,14 @@ def simulation_atomic( # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) # Convert the seeding data dictionnary to a numba dictionnary - seeding_data_nbdict = nb.typed.Dict.empty(key_type=nb.types.unicode_type, value_type=nb.types.int64[:]) + 
seeding_data_nbdict = nb.typed.Dict.empty( + key_type=nb.types.unicode_type, value_type=nb.types.int64[:] + ) for k, v in seeding_data.items(): seeding_data_nbdict[k] = np.array(v, dtype=np.int64) @@ -151,7 +165,9 @@ def get_static_arguments(modinf): ) = modinf.compartments.get_transition_array() outcomes_parameters = outcomes.read_parameters_from_config(modinf) - npi_seir = seir.build_npi_SEIR(modinf=modinf, load_ID=False, sim_id2load=None, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=modinf, load_ID=False, sim_id2load=None, config=config + ) if modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( modinf=modinf, @@ -162,15 +178,23 @@ def get_static_arguments(modinf): else: npi_outcomes = None - p_draw = modinf.parameters.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_draw = modinf.parameters.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=0, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id=0, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=0, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_config( + sim_id=0, modinf=modinf + ) # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) if real_simulation: states = seir.steps_SEIR( @@ -190,7 +214,10 @@ def get_static_arguments(modinf): compartment_df = modinf.compartments.get_compartments_explicitDF() # Iterate over columns of the DataFrame and populate the dictionary for column in compartment_df.columns: - compartment_coords[column] = ("compartment", compartment_df[column].tolist()) + compartment_coords[column] = ( + "compartment", + compartment_df[column].tolist(), + ) coords = dict( date=pd.date_range(modinf.ti, modinf.tf, freq="D"), @@ -198,7 +225,9 @@ def get_static_arguments(modinf): subpop=modinf.subpop_struct.subpop_names, ) - zeros = np.zeros((len(coords["date"]), len(coords["mc_name"][1]), len(coords["subpop"]))) + zeros = np.zeros( + (len(coords["date"]), len(coords["mc_name"][1]), len(coords["subpop"])) + ) states = xr.Dataset( data_vars=dict( prevalence=(["date", "compartment", "subpop"], zeros), @@ -256,12 +285,16 @@ def autodetect_scenarios(config): seir_modifiers_scenarios = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - seir_modifiers_scenarios = config["seir_modifiers"]["scenarios"].as_str_seq() + seir_modifiers_scenarios = config["seir_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = None if config["outcomes"].exists() and config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"].as_str_seq() + outcome_modifiers_scenarios = config["outcome_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = as_list(outcome_modifiers_scenarios) seir_modifiers_scenarios = as_list(seir_modifiers_scenarios) @@ -275,41 +308,41 @@ def autodetect_scenarios(config): return seir_modifiers_scenarios[0], outcome_modifiers_scenarios[0] + # rewrite the get log loss functions as single functions, not in a class. 
This is not faster # def get_logloss(proposal, inferpar, logloss, static_sim_arguments, modinf, silent=True, save=False): # if not inferpar.check_in_bound(proposal=proposal): # if not silent: # print("OUT OF BOUND!!") # return -np.inf, -np.inf, -np.inf -# +# # snpi_df_mod, hnpi_df_mod = inferpar.inject_proposal( # proposal=proposal, # snpi_df=static_sim_arguments["snpi_df_ref"], # hnpi_df=static_sim_arguments["hnpi_df_ref"], # ) -# +# # ss = copy.deepcopy(static_sim_arguments) # ss["snpi_df_in"] = snpi_df_mod # ss["hnpi_df_in"] = hnpi_df_mod # del ss["snpi_df_ref"] # del ss["hnpi_df_ref"] -# +# # outcomes_df = simulation_atomic(**ss, modinf=modinf, save=save) -# +# # ll_total, logloss, regularizations = logloss.compute_logloss( # model_df=outcomes_df, subpop_names=modinf.subpop_struct.subpop_names # ) # if not silent: # print(f"llik is {ll_total}") -# +# # return ll_total, logloss, regularizations -# +# # def get_logloss_as_single_number(proposal, inferpar, logloss, static_sim_arguments, modinf, silent=True, save=False): # ll_total, logloss, regularizations = get_logloss(proposal, inferpar, logloss, static_sim_arguments, modinf, silent, save) # return ll_total - class GempyorInference: def __init__( self, @@ -333,12 +366,20 @@ def __init__( config.set_file(os.path.join(path_prefix, config_filepath)) - self.seir_modifiers_scenario, self.outcome_modifiers_scenario = autodetect_scenarios(config) + self.seir_modifiers_scenario, self.outcome_modifiers_scenario = ( + autodetect_scenarios(config) + ) if run_id is None: run_id = file_paths.run_id() if prefix is None: - prefix = config["name"].get() + f"_{self.seir_modifiers_scenario}_{self.outcome_modifiers_scenario}" + "/" + run_id + "/" + prefix = ( + config["name"].get() + + f"_{self.seir_modifiers_scenario}_{self.outcome_modifiers_scenario}" + + "/" + + run_id + + "/" + ) in_run_id = run_id if out_run_id is None: out_run_id = in_run_id @@ -387,7 +428,8 @@ def __init__( self.do_inference = True self.inference_method = "emcee" self.inferpar = inference_parameter.InferenceParameters( - global_config=config, subpop_names=self.modinf.subpop_struct.subpop_names + global_config=config, + subpop_names=self.modinf.subpop_struct.subpop_names, ) self.logloss = logloss.LogLoss( inference_config=config["inference"], @@ -412,7 +454,14 @@ def set_save(self, save): def get_all_sim_arguments(self): # inferpar, logloss, static_sim_arguments, modinf, proposal, silent, save - return [self.inferpar, self.logloss, self.static_sim_arguments, self.modinf, self.silent, self.save] + return [ + self.inferpar, + self.logloss, + self.static_sim_arguments, + self.modinf, + self.silent, + self.save, + ] def get_logloss(self, proposal): if not self.inferpar.check_in_bound(proposal=proposal): @@ -479,11 +528,15 @@ def update_run_id(self, new_run_id, new_out_run_id=None): else: self.modinf.out_run_id = new_out_run_id - def one_simulation_legacy(self, sim_id2write: int, load_ID: bool = False, sim_id2load: int = None): + def one_simulation_legacy( + self, sim_id2write: int, load_ID: bool = False, sim_id2load: int = None + ): sim_id2write = int(sim_id2write) if load_ID: sim_id2load = int(sim_id2load) - with Timer(f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}"): + with Timer( + f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}" + ): with Timer("onerun_SEIR"): seir.onerun_SEIR( sim_id2write=sim_id2write, @@ -533,16 +586,31 @@ def one_simulation( sim_id2load = int(sim_id2load) self.lastsim_sim_id2load = sim_id2load - with Timer(f">>> 
GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}"): + with Timer( + f">>> GEMPYOR onesim {'(loading file)' if load_ID else '(from config)'}" + ): if not self.already_built and self.modinf.outcomes_config is not None: - self.outcomes_parameters = outcomes.read_parameters_from_config(self.modinf) + self.outcomes_parameters = outcomes.read_parameters_from_config( + self.modinf + ) npi_outcomes = None if parallel: with Timer("//things"): - with ProcessPoolExecutor(max_workers=max(mp.cpu_count(), 3)) as executor: - ret_seir = executor.submit(seir.build_npi_SEIR, self.modinf, load_ID, sim_id2load, config) - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + with ProcessPoolExecutor( + max_workers=max(mp.cpu_count(), 3) + ) as executor: + ret_seir = executor.submit( + seir.build_npi_SEIR, + self.modinf, + load_ID, + sim_id2load, + config, + ) + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): ret_outcomes = executor.submit( outcomes.build_outcome_modifiers, self.modinf, @@ -551,7 +619,9 @@ def one_simulation( config, ) if not self.already_built: - ret_comparments = executor.submit(self.modinf.compartments.get_transition_array) + ret_comparments = executor.submit( + self.modinf.compartments.get_transition_array + ) # print("expections:", ret_seir.exception(), ret_outcomes.exception(), ret_comparments.exception()) @@ -564,15 +634,24 @@ def one_simulation( ) = ret_comparments.result() self.already_built = True npi_seir = ret_seir.result() - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): npi_outcomes = ret_outcomes.result() else: if not self.already_built: self.build_structure() npi_seir = seir.build_npi_SEIR( - modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + modinf=self.modinf, + load_ID=load_ID, + sim_id2load=sim_id2load, + config=config, ) - if self.modinf.outcomes_config is not None and self.modinf.npi_config_outcomes: + if ( + self.modinf.outcomes_config is not None + and self.modinf.npi_config_outcomes + ): npi_outcomes = outcomes.build_outcome_modifiers( modinf=self.modinf, load_ID=load_ID, @@ -585,7 +664,9 @@ def one_simulation( with Timer("SEIR.parameters"): # Draw or load parameters - p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) + p_draw = self.get_seir_parameters( + load_ID=load_ID, sim_id2load=sim_id2load + ) # reduce them parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) # Parse them @@ -598,8 +679,12 @@ def one_simulation( with Timer("onerun_SEIR.seeding"): if load_ID: - initial_conditions = self.modinf.initial_conditions.get_from_file(sim_id2load, modinf=self.modinf) - seeding_data, seeding_amounts = self.modinf.seeding.get_from_file(sim_id2load, modinf=self.modinf) + initial_conditions = self.modinf.initial_conditions.get_from_file( + sim_id2load, modinf=self.modinf + ) + seeding_data, seeding_amounts = self.modinf.seeding.get_from_file( + sim_id2load, modinf=self.modinf + ) else: initial_conditions = self.modinf.initial_conditions.get_from_config( sim_id2write, modinf=self.modinf @@ -645,7 +730,7 @@ def one_simulation( parameters=self.outcomes_parameters, loaded_values=loaded_values, npi=npi_outcomes, - bypass_seir_xr=states + bypass_seir_xr=states, ) self.lastsim_outcomes_df = outcomes_df self.lastsim_hpar_df = hpar_df @@ -660,14 +745,18 @@ def one_simulation( ) return 0 - def 
plot_transition_graph(self, output_file="transition_graph", source_filters=[], destination_filters=[]): + def plot_transition_graph( + self, output_file="transition_graph", source_filters=[], destination_filters=[] + ): self.modinf.compartments.plot( output_file=output_file, source_filters=source_filters, destination_filters=destination_filters, ) - def get_outcome_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_outcome_npi( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): npi_outcomes = None if self.modinf.npi_config_outcomes: npi_outcomes = outcomes.build_outcome_modifiers( @@ -680,7 +769,9 @@ def get_outcome_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypas ) return npi_outcomes - def get_seir_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_npi( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): npi_seir = seir.build_npi_SEIR( modinf=self.modinf, load_ID=load_ID, @@ -691,7 +782,9 @@ def get_seir_npi(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_F ) return npi_seir - def get_seir_parameters(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_parameters( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): param_df = None if bypass_DF is not None: param_df = bypass_DF @@ -712,7 +805,9 @@ def get_seir_parameters(self, load_ID=False, sim_id2load=None, bypass_DF=None, b ) return p_draw - def get_seir_parametersDF(self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None): + def get_seir_parametersDF( + self, load_ID=False, sim_id2load=None, bypass_DF=None, bypass_FN=None + ): p_draw = self.get_seir_parameters( load_ID=load_ID, sim_id2load=sim_id2load, @@ -767,7 +862,9 @@ def get_parsed_parameters_seir( if not self.already_built: self.build_structure() - npi_seir = seir.build_npi_SEIR(modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) @@ -784,7 +881,9 @@ def get_reduced_parameters_seir( # bypass_DF=None, # bypass_FN=None, ): - npi_seir = seir.build_npi_SEIR(modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_seir = seir.build_npi_SEIR( + modinf=self.modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) p_draw = self.get_seir_parameters(load_ID=load_ID, sim_id2load=sim_id2load) parameters = self.modinf.parameters.parameters_reduce(p_draw, npi_seir) @@ -805,7 +904,9 @@ def paramred_parallel(run_spec, snpi_fn): seir_modifiers_scenario="inference", # NPIs scenario to use outcome_modifiers_scenario="med", # Outcome scenario to use stoch_traj_flag=False, - path_prefix=run_spec["geodata"], # prefix where to find the folder indicated in subpop_setup$ + path_prefix=run_spec[ + "geodata" + ], # prefix where to find the folder indicated in subpop_setup$ ) snpi = pq.read_table(snpi_fn).to_pandas() @@ -816,7 +917,9 @@ def paramred_parallel(run_spec, snpi_fn): params_draw_arr = gempyor_inference.get_seir_parameters( bypass_FN=snpi_fn.replace("snpi", "spar") ) # could also accept (load_ID=True, sim_id2load=XXX) or (bypass_DF=) or (bypass_FN=) - param_reduc_from = gempyor_inference.get_seir_parameter_reduced(npi_seir=npi_seir, 
p_draw=params_draw_arr) + param_reduc_from = gempyor_inference.get_seir_parameter_reduced( + npi_seir=npi_seir, p_draw=params_draw_arr + ) return param_reduc_from @@ -831,7 +934,9 @@ def paramred_parallel_config(run_spec, dummy): seir_modifiers_scenario="inference", # NPIs scenario to use outcome_modifiers_scenario="med", # Outcome scenario to use stoch_traj_flag=False, - path_prefix=run_spec["geodata"], # prefix where to find the folder indicated in subpop_setup$ + path_prefix=run_spec[ + "geodata" + ], # prefix where to find the folder indicated in subpop_setup$ ) npi_seir = gempyor_inference.get_seir_npi() @@ -839,6 +944,8 @@ def paramred_parallel_config(run_spec, dummy): params_draw_arr = ( gempyor_inference.get_seir_parameters() ) # could also accept (load_ID=True, sim_id2load=XXX) or (bypass_DF=) or (bypass_FN=) - param_reduc_from = gempyor_inference.get_seir_parameter_reduced(npi_seir=npi_seir, p_draw=params_draw_arr) + param_reduc_from = gempyor_inference.get_seir_parameter_reduced( + npi_seir=npi_seir, p_draw=params_draw_arr + ) return param_reduc_from diff --git a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py index 8f1b0fc53..30cab0bba 100644 --- a/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py +++ b/flepimop/gempyor_pkg/src/gempyor/inference_parameter.py @@ -36,9 +36,14 @@ def add_modifier(self, pname, ptype, parameter_config, subpops): """ # identify spatial group affected_subpops = set(subpops) - if parameter_config["subpop"].exists() and parameter_config["subpop"].get() != "all": + if ( + parameter_config["subpop"].exists() + and parameter_config["subpop"].get() != "all" + ): affected_subpops = {str(n.get()) for n in parameter_config["subpop"]} - spatial_groups = NPI.helpers.get_spatial_groups(parameter_config, list(affected_subpops)) + spatial_groups = NPI.helpers.get_spatial_groups( + parameter_config, list(affected_subpops) + ) # ungrouped subpop (all affected subpop by default) have one parameter per subpop if spatial_groups["ungrouped"]: @@ -87,11 +92,15 @@ def build_from_config(self, global_config, subpop_names): for config_part in ["seir_modifiers", "outcome_modifiers"]: if global_config[config_part].exists(): for npi in global_config[config_part]["modifiers"].get(): - if global_config[config_part]["modifiers"][npi]["perturbation"].exists(): + if global_config[config_part]["modifiers"][npi][ + "perturbation" + ].exists(): self.add_modifier( pname=npi, ptype=config_part, - parameter_config=global_config[config_part]["modifiers"][npi], + parameter_config=global_config[config_part]["modifiers"][ + npi + ], subpops=subpop_names, ) @@ -156,7 +165,10 @@ def check_in_bound(self, proposal) -> bool: Returns: bool: True if the proposal is within bounds, False otherwise. 
""" - if self.hit_lbs(proposal=proposal).any() or self.hit_ubs(proposal=proposal).any(): + if ( + self.hit_lbs(proposal=proposal).any() + or self.hit_ubs(proposal=proposal).any() + ): return False return True diff --git a/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py b/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py index 554f3910c..eef32531f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py +++ b/flepimop/gempyor_pkg/src/gempyor/initial_conditions.py @@ -45,19 +45,30 @@ def __init__( if self.initial_conditions_config is not None: if "ignore_population_checks" in self.initial_conditions_config.keys(): - self.ignore_population_checks = self.initial_conditions_config["ignore_population_checks"].get(bool) + self.ignore_population_checks = self.initial_conditions_config[ + "ignore_population_checks" + ].get(bool) if "allow_missing_subpops" in self.initial_conditions_config.keys(): - self.allow_missing_subpops = self.initial_conditions_config["allow_missing_subpops"].get(bool) + self.allow_missing_subpops = self.initial_conditions_config[ + "allow_missing_subpops" + ].get(bool) if "allow_missing_compartments" in self.initial_conditions_config.keys(): - self.allow_missing_compartments = self.initial_conditions_config["allow_missing_compartments"].get(bool) + self.allow_missing_compartments = self.initial_conditions_config[ + "allow_missing_compartments" + ].get(bool) # TODO: add check, this option onlywork with tidy dataframe if "proportional" in self.initial_conditions_config.keys(): - self.proportional_ic = self.initial_conditions_config["proportional"].get(bool) + self.proportional_ic = self.initial_conditions_config[ + "proportional" + ].get(bool) def get_from_config(self, sim_id: int, modinf) -> np.ndarray: method = "Default" - if self.initial_conditions_config is not None and "method" in self.initial_conditions_config.keys(): + if ( + self.initial_conditions_config is not None + and "method" in self.initial_conditions_config.keys() + ): method = self.initial_conditions_config["method"].as_str() if method == "Default": @@ -66,13 +77,20 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: y0[0, :] = modinf.subpop_pop return y0 # we finish here: no rest and not proportionallity applies - if method == "SetInitialConditions" or method == "SetInitialConditionsFolderDraw": + if ( + method == "SetInitialConditions" + or method == "SetInitialConditionsFolderDraw" + ): # TODO Think about - Does not support the new way of doing compartment indexing if method == "SetInitialConditionsFolderDraw": - ic_df = modinf.read_simID(ftype=self.initial_conditions_config["initial_file_type"], sim_id=sim_id) + ic_df = modinf.read_simID( + ftype=self.initial_conditions_config["initial_file_type"], + sim_id=sim_id, + ) else: ic_df = read_df( - self.path_prefix / self.initial_conditions_config["initial_conditions_file"].get(), + self.path_prefix + / self.initial_conditions_config["initial_conditions_file"].get(), ) y0 = read_initial_condition_from_tidydataframe( ic_df=ic_df, @@ -85,11 +103,13 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: elif method == "InitialConditionsFolderDraw" or method == "FromFile": if method == "InitialConditionsFolderDraw": ic_df = modinf.read_simID( - ftype=self.initial_conditions_config["initial_file_type"].get(), sim_id=sim_id + ftype=self.initial_conditions_config["initial_file_type"].get(), + sim_id=sim_id, ) elif method == "FromFile": ic_df = read_df( - self.path_prefix / 
self.initial_conditions_config["initial_conditions_file"].get(), + self.path_prefix + / self.initial_conditions_config["initial_conditions_file"].get(), ) y0 = read_initial_condition_from_seir_output( @@ -99,10 +119,14 @@ def get_from_config(self, sim_id: int, modinf) -> np.ndarray: allow_missing_subpops=self.allow_missing_subpops, ) else: - raise NotImplementedError(f"unknown initial conditions method [got: {method}]") + raise NotImplementedError( + f"unknown initial conditions method [got: {method}]" + ) # check that the inputed values sums to the subpop population: - check_population(y0=y0, modinf=modinf, ignore_population_checks=self.ignore_population_checks) + check_population( + y0=y0, modinf=modinf, ignore_population_checks=self.ignore_population_checks + ) return y0 @@ -136,7 +160,11 @@ def check_population(y0, modinf, ignore_population_checks=False): def read_initial_condition_from_tidydataframe( - ic_df, modinf, allow_missing_subpops, allow_missing_compartments, proportional_ic=False + ic_df, + modinf, + allow_missing_subpops, + allow_missing_compartments, + proportional_ic=False, ): rests = [] # Places to allocate the rest of the population y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops)) @@ -145,9 +173,13 @@ def read_initial_condition_from_tidydataframe( states_pl = ic_df[ic_df["subpop"] == pl] for comp_idx, comp_name in modinf.compartments.compartments["name"].items(): if "mc_name" in states_pl.columns: - ic_df_compartment_val = states_pl[states_pl["mc_name"] == comp_name]["amount"] + ic_df_compartment_val = states_pl[ + states_pl["mc_name"] == comp_name + ]["amount"] else: - filters = modinf.compartments.compartments.iloc[comp_idx].drop("name") + filters = modinf.compartments.compartments.iloc[comp_idx].drop( + "name" + ) ic_df_compartment_val = states_pl.copy() for mc_name, mc_value in filters.items(): ic_df_compartment_val = ic_df_compartment_val[ @@ -177,7 +209,9 @@ def read_initial_condition_from_tidydataframe( logger.critical( f"No initial conditions for for subpop {pl}, assuming everyone (n={modinf.subpop_pop[pl_idx]}) in the first metacompartment ({modinf.compartments.compartments['name'].iloc[0]})" ) - raise ValueError("THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy") + raise ValueError( + "THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy" + ) # TODO: this is probably ok but highlighting for consistency if "proportional" in self.initial_conditions_config.keys(): if self.initial_conditions_config["proportional"].get(): @@ -202,7 +236,9 @@ def read_initial_condition_from_tidydataframe( return y0 -def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops, allow_missing_compartments): +def read_initial_condition_from_seir_output( + ic_df, modinf, allow_missing_subpops, allow_missing_compartments +): """ Read the initial conditions from the SEIR output. @@ -227,9 +263,13 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops ic_df["date"] = ic_df["date"].dt.date ic_df["date"] = ic_df["date"].astype(str) - ic_df = ic_df[(ic_df["date"] == str(modinf.ti)) & (ic_df["mc_value_type"] == "prevalence")] + ic_df = ic_df[ + (ic_df["date"] == str(modinf.ti)) & (ic_df["mc_value_type"] == "prevalence") + ] if ic_df.empty: - raise ValueError(f"There is no entry for initial time ti in the provided initial_conditions::states_file.") + raise ValueError( + f"There is no entry for initial time ti in the provided initial_conditions::states_file." 
+ ) y0 = np.zeros((modinf.compartments.compartments.shape[0], modinf.nsubpops)) for comp_idx, comp_name in modinf.compartments.compartments["name"].items(): @@ -239,7 +279,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops filters = modinf.compartments.compartments.iloc[comp_idx].drop("name") ic_df_compartment = ic_df.copy() for mc_name, mc_value in filters.items(): - ic_df_compartment = ic_df_compartment[ic_df_compartment["mc_" + mc_name] == mc_value] + ic_df_compartment = ic_df_compartment[ + ic_df_compartment["mc_" + mc_name] == mc_value + ] if len(ic_df_compartment) > 1: # ic_df_compartment = ic_df_compartment.iloc[0] @@ -248,7 +290,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops ) elif ic_df_compartment.empty: if allow_missing_compartments: - ic_df_compartment = pd.DataFrame(0, columns=ic_df_compartment.columns, index=[0]) + ic_df_compartment = pd.DataFrame( + 0, columns=ic_df_compartment.columns, index=[0] + ) else: raise ValueError( f"Initial Conditions: Could not set compartment {comp_name} (id: {comp_idx}) in subpop {pl} (id: {pl_idx}). The data from the init file is {ic_df_compartment[pl]}." @@ -262,7 +306,9 @@ def read_initial_condition_from_seir_output(ic_df, modinf, allow_missing_subpops if pl in ic_df.columns: y0[comp_idx, pl_idx] = float(ic_df_compartment[pl].iloc[0]) elif allow_missing_subpops: - raise ValueError("THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy") + raise ValueError( + "THERE IS A BUG; REPORT THIS MESSAGE. Past implemenation was buggy" + ) # TODO this should set the full subpop, not just the 0th commpartment logger.critical( f"No initial conditions for for subpop {pl}, assuming everyone (n={modinf.subpop_pop[pl_idx]}) in the first metacompartments ({modinf.compartments.compartments['name'].iloc[0]})" diff --git a/flepimop/gempyor_pkg/src/gempyor/logloss.py b/flepimop/gempyor_pkg/src/gempyor/logloss.py index 95892dac0..0ed62fad2 100644 --- a/flepimop/gempyor_pkg/src/gempyor/logloss.py +++ b/flepimop/gempyor_pkg/src/gempyor/logloss.py @@ -17,7 +17,13 @@ class LogLoss: - def __init__(self, inference_config: confuse.ConfigView, subpop_struct, time_setup, path_prefix: str = "."): + def __init__( + self, + inference_config: confuse.ConfigView, + subpop_struct, + time_setup, + path_prefix: str = ".", + ): # TODO: bad format for gt because each date must have a value for each column, but if it doesn't and you add NA # then this NA has a meaning that depends on skip NA, which is annoying. # A lot of things can go wrong here, in the previous approach where GT was cast to xarray as @@ -35,20 +41,36 @@ def __init__(self, inference_config: confuse.ConfigView, subpop_struct, time_set # made the controversial choice of storing the gt as an xarray dataset instead of a dictionary # of dataframes - self.gt_xr = xr.Dataset.from_dataframe(self.gt.reset_index().set_index(["date", "subpop"])) + self.gt_xr = xr.Dataset.from_dataframe( + self.gt.reset_index().set_index(["date", "subpop"]) + ) # Very important: subsample the subpop in the population, in the right order, and sort by the date index. - self.gt_xr = self.gt_xr.sortby("date").reindex({"subpop": subpop_struct.subpop_names}) + self.gt_xr = self.gt_xr.sortby("date").reindex( + {"subpop": subpop_struct.subpop_names} + ) # This will force at 0, if skipna is False, data of some variable that don't exist if iother exist # and damn python datetime types are ugly... 
- self.first_date = max(pd.to_datetime(self.gt_xr.date[0].values).date(), time_setup.ti) - self.last_date = min(pd.to_datetime(self.gt_xr.date[-1].values).date(), time_setup.tf) + self.first_date = max( + pd.to_datetime(self.gt_xr.date[0].values).date(), time_setup.ti + ) + self.last_date = min( + pd.to_datetime(self.gt_xr.date[-1].values).date(), time_setup.tf + ) self.statistics = {} for key, value in inference_config["statistics"].items(): self.statistics[key] = statistics.Statistic(key, value) - def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename=None, **kwargs): + def plot_gt( + self, + ax=None, + subpop=None, + statistic=None, + subplot=False, + filename=None, + **kwargs, + ): """Plots ground truth data. Args: @@ -68,7 +90,10 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= fig, axes = plt.subplots( len(self.gt["subpop"].unique()), len(self.gt.columns.drop("subpop")), - figsize=(4 * len(self.gt.columns.drop("subpop")), 3 * len(self.gt["subpop"].unique())), + figsize=( + 4 * len(self.gt.columns.drop("subpop")), + 3 * len(self.gt["subpop"].unique()), + ), dpi=250, sharex=True, ) @@ -81,7 +106,9 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= subpops = [subpop] if statistic is None: - statistics = self.gt.columns.drop("subpop") # Assuming other columns are statistics + statistics = self.gt.columns.drop( + "subpop" + ) # Assuming other columns are statistics else: statistics = [statistic] @@ -89,14 +116,18 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= # One subplot for each subpop/statistic combination for i, subpop in enumerate(subpops): for j, stat in enumerate(statistics): - data_to_plot = self.gt[(self.gt["subpop"] == subpop)][stat].sort_index() + data_to_plot = self.gt[(self.gt["subpop"] == subpop)][ + stat + ].sort_index() axes[i, j].plot(data_to_plot, **kwargs) axes[i, j].set_title(f"{subpop} - {stat}") else: # All lines in a single plot for subpop in subpops: for stat in statistics: - data_to_plot = self.gt[(self.gt["subpop"] == subpop)][stat].sort_index() + data_to_plot = self.gt[(self.gt["subpop"] == subpop)][ + stat + ].sort_index() data_to_plot.plot(ax=ax, **kwargs, label=f"{subpop} - {stat}") if len(statistics) > 1: ax.legend() @@ -107,7 +138,10 @@ def plot_gt(self, ax=None, subpop=None, statistic=None, subplot=False, filename= plt.savefig(filename, **kwargs) # Save the figure if subplot: - return fig, axes # Return figure and subplots for potential further customization + return ( + fig, + axes, + ) # Return figure and subplots for potential further customization else: return ax # Optionally return the axis @@ -121,13 +155,17 @@ def compute_logloss(self, model_df, subpop_names): coords = {"statistic": list(self.statistics.keys()), "subpop": subpop_names} logloss = xr.DataArray( - np.zeros((len(coords["statistic"]), len(coords["subpop"]))), dims=["statistic", "subpop"], coords=coords + np.zeros((len(coords["statistic"]), len(coords["subpop"]))), + dims=["statistic", "subpop"], + coords=coords, ) regularizations = 0 model_xr = ( - xr.Dataset.from_dataframe(model_df.reset_index().set_index(["date", "subpop"])) + xr.Dataset.from_dataframe( + model_df.reset_index().set_index(["date", "subpop"]) + ) .sortby("date") .reindex({"subpop": subpop_names}) ) diff --git a/flepimop/gempyor_pkg/src/gempyor/model_info.py b/flepimop/gempyor_pkg/src/gempyor/model_info.py index 54f981d8f..dc3477f8a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/model_info.py 
+++ b/flepimop/gempyor_pkg/src/gempyor/model_info.py @@ -1,6 +1,13 @@ import pandas as pd import datetime, os, logging, pathlib, confuse -from . import seeding, subpopulation_structure, parameters, compartments, file_paths, initial_conditions +from . import ( + seeding, + subpopulation_structure, + parameters, + compartments, + file_paths, + initial_conditions, +) from .utils import read_df, write_df logger = logging.getLogger(__name__) @@ -11,7 +18,9 @@ def __init__(self, config: confuse.ConfigView): self.ti = config["start_date"].as_date() self.tf = config["end_date"].as_date() if self.tf <= self.ti: - raise ValueError("tf (time to finish) is less than or equal to ti (time to start)") + raise ValueError( + "tf (time to finish) is less than or equal to ti (time to start)" + ) self.n_days = (self.tf - self.ti).days + 1 self.dates = pd.date_range(start=self.ti, end=self.tf, freq="D") @@ -29,7 +38,7 @@ class ModelInfo: seeding # One of seeding or initial_conditions is required when running seir outcomes # Required if running outcomes seir_modifiers # Not required. If exists, every modifier will be applied to seir parameters - outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes + outcomes_modifiers # Not required. If exists, every modifier will be applied to outcomes inference # Required if running inference ``` """ @@ -94,7 +103,9 @@ def __init__( # 3. What about subpopulations subpop_config = config["subpop_setup"] if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." + ) self.path_prefix = pathlib.Path(path_prefix) self.subpop_struct = subpopulation_structure.SubpopulationStructure( @@ -112,9 +123,13 @@ def __init__( self.seir_config = config["seir"] self.parameters_config = config["seir"]["parameters"] self.initial_conditions_config = ( - config["initial_conditions"] if config["initial_conditions"].exists() else None + config["initial_conditions"] + if config["initial_conditions"].exists() + else None + ) + self.seeding_config = ( + config["seeding"] if config["seeding"].exists() else None ) - self.seeding_config = config["seeding"] if config["seeding"].exists() else None if self.seeding_config is None and self.initial_conditions_config is None: logging.critical( @@ -130,25 +145,36 @@ def __init__( subpop_names=self.subpop_struct.subpop_names, path_prefix=self.path_prefix, ) - self.seeding = seeding.SeedingFactory(config=self.seeding_config, path_prefix=self.path_prefix) + self.seeding = seeding.SeedingFactory( + config=self.seeding_config, path_prefix=self.path_prefix + ) self.initial_conditions = initial_conditions.InitialConditionsFactory( config=self.initial_conditions_config, path_prefix=self.path_prefix ) # really ugly references to the config globally here. 
if config["compartments"].exists() and self.seir_config is not None: self.compartments = compartments.Compartments( - seir_config=self.seir_config, compartments_config=config["compartments"] + seir_config=self.seir_config, + compartments_config=config["compartments"], ) # SEIR modifiers self.npi_config_seir = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - self.npi_config_seir = config["seir_modifiers"]["modifiers"][seir_modifiers_scenario] - self.seir_modifiers_library = config["seir_modifiers"]["modifiers"].get() + self.npi_config_seir = config["seir_modifiers"]["modifiers"][ + seir_modifiers_scenario + ] + self.seir_modifiers_library = config["seir_modifiers"][ + "modifiers" + ].get() else: - self.seir_modifiers_library = config["seir_modifiers"]["modifiers"].get() - raise ValueError("Not implemented yet") # TODO create a Stacked from all + self.seir_modifiers_library = config["seir_modifiers"][ + "modifiers" + ].get() + raise ValueError( + "Not implemented yet" + ) # TODO create a Stacked from all elif self.seir_modifiers_scenario is not None: raise ValueError( "An seir modifiers scenario was provided to ModelInfo but no 'seir_modifiers' sections in config" @@ -157,21 +183,33 @@ def __init__( logging.info("Running ModelInfo with seir but without SEIR Modifiers") elif self.seir_modifiers_scenario is not None: - raise ValueError("A seir modifiers scenario was provided to ModelInfo but no 'seir:' sections in config") + raise ValueError( + "A seir modifiers scenario was provided to ModelInfo but no 'seir:' sections in config" + ) else: logging.critical("Running ModelInfo without SEIR") # 5. Outcomes - self.outcomes_config = config["outcomes"] if config["outcomes"].exists() else None + self.outcomes_config = ( + config["outcomes"] if config["outcomes"].exists() else None + ) if self.outcomes_config is not None: self.npi_config_outcomes = None if config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - self.npi_config_outcomes = config["outcome_modifiers"]["modifiers"][self.outcome_modifiers_scenario] - self.outcome_modifiers_library = config["outcome_modifiers"]["modifiers"].get() + self.npi_config_outcomes = config["outcome_modifiers"]["modifiers"][ + self.outcome_modifiers_scenario + ] + self.outcome_modifiers_library = config["outcome_modifiers"][ + "modifiers" + ].get() else: - self.outcome_modifiers_library = config["outcome_modifiers"]["modifiers"].get() - raise ValueError("Not implemented yet") # TODO create a Stacked from all + self.outcome_modifiers_library = config["outcome_modifiers"][ + "modifiers" + ].get() + raise ValueError( + "Not implemented yet" + ) # TODO create a Stacked from all ## NEED TO IMPLEMENT THIS -- CURRENTLY CANNOT USE outcome modifiers elif self.outcome_modifiers_scenario is not None: @@ -182,7 +220,9 @@ def __init__( else: self.outcome_modifiers_scenario = None else: - logging.info("Running ModelInfo with outcomes but without Outcomes Modifiers") + logging.info( + "Running ModelInfo with outcomes but without Outcomes Modifiers" + ) elif self.outcome_modifiers_scenario is not None: raise ValueError( "An outcome modifiers scenario was provided to ModelInfo but no 'outcomes:' sections in config" @@ -228,7 +268,9 @@ def __init__( os.makedirs(datadir, exist_ok=True) if self.write_parquet and self.write_csv: - print("Confused between reading .csv or parquet. Assuming input file is .parquet") + print( + "Confused between reading .csv or parquet. 
Assuming input file is .parquet" + ) if self.write_parquet: self.extension = "parquet" elif self.write_csv: @@ -244,7 +286,9 @@ def get_input_filename(self, ftype: str, sim_id: int, extension_override: str = extension_override=extension_override, ) - def get_output_filename(self, ftype: str, sim_id: int, extension_override: str = ""): + def get_output_filename( + self, ftype: str, sim_id: int, extension_override: str = "" + ): return self.path_prefix / self.get_filename( ftype=ftype, sim_id=sim_id, @@ -252,7 +296,9 @@ def get_output_filename(self, ftype: str, sim_id: int, extension_override: str = extension_override=extension_override, ) - def get_filename(self, ftype: str, sim_id: int, input: bool, extension_override: str = ""): + def get_filename( + self, ftype: str, sim_id: int, input: bool, extension_override: str = "" + ): """return a CSP formated filename.""" if extension_override: # empty strings are Falsy @@ -281,7 +327,9 @@ def get_filename(self, ftype: str, sim_id: int, input: bool, extension_override: def get_setup_name(self): return self.setup_name - def read_simID(self, ftype: str, sim_id: int, input: bool = True, extension_override: str = ""): + def read_simID( + self, ftype: str, sim_id: int, input: bool = True, extension_override: str = "" + ): fname = self.get_filename( ftype=ftype, sim_id=sim_id, diff --git a/flepimop/gempyor_pkg/src/gempyor/outcomes.py b/flepimop/gempyor_pkg/src/gempyor/outcomes.py index 5563f4d85..3c693518c 100644 --- a/flepimop/gempyor_pkg/src/gempyor/outcomes.py +++ b/flepimop/gempyor_pkg/src/gempyor/outcomes.py @@ -22,7 +22,9 @@ def run_parallel_outcomes(modinf, *, sim_id2write, nslots=1, n_jobs=1): sim_id2writes = np.arange(sim_id2write, sim_id2write + modinf.nslots) loaded_values = None - if (n_jobs == 1) or (modinf.nslots == 1): # run single process for debugging/profiling purposes + if (n_jobs == 1) or ( + modinf.nslots == 1 + ): # run single process for debugging/profiling purposes for sim_offset in np.arange(nslots): onerun_delayframe_outcomes( sim_id2write=sim_id2writes[sim_offset], @@ -100,7 +102,9 @@ def onerun_delayframe_outcomes( npi_outcomes = None if modinf.npi_config_outcomes: - npi_outcomes = build_outcome_modifiers(modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi_outcomes = build_outcome_modifiers( + modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) loaded_values = None if load_ID: @@ -117,7 +121,13 @@ def onerun_delayframe_outcomes( ) with Timer("onerun_delayframe_outcomes.postprocess"): - postprocess_and_write(sim_id=sim_id2write, modinf=modinf, outcomes_df=outcomes_df, hpar=hpar, npi=npi_outcomes) + postprocess_and_write( + sim_id=sim_id2write, + modinf=modinf, + outcomes_df=outcomes_df, + hpar=hpar, + npi=npi_outcomes, + ) def read_parameters_from_config(modinf: model_info.ModelInfo): @@ -129,7 +139,10 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): if modinf.outcomes_config["param_from_file"].exists(): if modinf.outcomes_config["param_from_file"].get(): # Load the actual csv file - branching_file = modinf.path_prefix / modinf.outcomes_config["param_subpop_file"].as_str() + branching_file = ( + modinf.path_prefix + / modinf.outcomes_config["param_subpop_file"].as_str() + ) branching_data = pa.parquet.read_table(branching_file).to_pandas() if "relative_probability" not in list(branching_data["quantity"]): raise ValueError( @@ -142,14 +155,18 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): "", end="", ) - branching_data = 
branching_data[branching_data["subpop"].isin(modinf.subpop_struct.subpop_names)] + branching_data = branching_data[ + branching_data["subpop"].isin(modinf.subpop_struct.subpop_names) + ] print( "Intersect with seir simulation: ", len(branching_data.subpop.unique()), "kept", ) - if len(branching_data.subpop.unique()) != len(modinf.subpop_struct.subpop_names): + if len(branching_data.subpop.unique()) != len( + modinf.subpop_struct.subpop_names + ): raise ValueError( f"Places in seir input files does not correspond to subpops in outcome probability file {branching_file}" ) @@ -162,7 +179,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): src_name = outcomes_config[new_comp]["source"].get() if isinstance(src_name, str): parameters[new_comp]["source"] = src_name - elif ("incidence" in src_name.keys()) or ("prevalence" in src_name.keys()): + elif ("incidence" in src_name.keys()) or ( + "prevalence" in src_name.keys() + ): parameters[new_comp]["source"] = dict(src_name) else: @@ -170,10 +189,16 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"unsure how to read outcome {new_comp}: not a str, nor an incidence or prevalence: {src_name}" ) - parameters[new_comp]["probability"] = outcomes_config[new_comp]["probability"]["value"] - if outcomes_config[new_comp]["probability"]["modifier_parameter"].exists(): + parameters[new_comp]["probability"] = outcomes_config[new_comp][ + "probability" + ]["value"] + if outcomes_config[new_comp]["probability"][ + "modifier_parameter" + ].exists(): parameters[new_comp]["probability::npi_param_name"] = ( - outcomes_config[new_comp]["probability"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["probability"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"probability of outcome {new_comp} is affected by intervention " @@ -181,13 +206,21 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::probability" ) else: - parameters[new_comp]["probability::npi_param_name"] = f"{new_comp}::probability".lower() + parameters[new_comp][ + "probability::npi_param_name" + ] = f"{new_comp}::probability".lower() if outcomes_config[new_comp]["delay"].exists(): - parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"]["value"] - if outcomes_config[new_comp]["delay"]["modifier_parameter"].exists(): + parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"][ + "value" + ] + if outcomes_config[new_comp]["delay"][ + "modifier_parameter" + ].exists(): parameters[new_comp]["delay::npi_param_name"] = ( - outcomes_config[new_comp]["delay"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["delay"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"delay of outcome {new_comp} is affected by intervention " @@ -195,18 +228,32 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::delay" ) else: - parameters[new_comp]["delay::npi_param_name"] = f"{new_comp}::delay".lower() + parameters[new_comp][ + "delay::npi_param_name" + ] = f"{new_comp}::delay".lower() else: - logging.critical(f"No delay for outcome {new_comp}, using a 0 delay") + logging.critical( + f"No delay for outcome {new_comp}, using a 0 delay" + ) outcomes_config[new_comp]["delay"] = {"value": 0} - parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"]["value"] - parameters[new_comp]["delay::npi_param_name"] = f"{new_comp}::delay".lower() + parameters[new_comp]["delay"] = outcomes_config[new_comp]["delay"][ + "value" + ] + 
parameters[new_comp][ + "delay::npi_param_name" + ] = f"{new_comp}::delay".lower() if outcomes_config[new_comp]["duration"].exists(): - parameters[new_comp]["duration"] = outcomes_config[new_comp]["duration"]["value"] - if outcomes_config[new_comp]["duration"]["modifier_parameter"].exists(): + parameters[new_comp]["duration"] = outcomes_config[new_comp][ + "duration" + ]["value"] + if outcomes_config[new_comp]["duration"][ + "modifier_parameter" + ].exists(): parameters[new_comp]["duration::npi_param_name"] = ( - outcomes_config[new_comp]["duration"]["modifier_parameter"].as_str().lower() + outcomes_config[new_comp]["duration"]["modifier_parameter"] + .as_str() + .lower() ) logging.debug( f"duration of outcome {new_comp} is affected by intervention " @@ -214,7 +261,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): f"instead of {new_comp}::duration" ) else: - parameters[new_comp]["duration::npi_param_name"] = f"{new_comp}::duration".lower() + parameters[new_comp][ + "duration::npi_param_name" + ] = f"{new_comp}::duration".lower() if outcomes_config[new_comp]["duration"]["name"].exists(): parameters[new_comp]["outcome_prevalence_name"] = ( @@ -223,7 +272,9 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): ) else: # parameters[class_name]["outcome_prevalence_name"] = new_comp + "_curr" + subclass - parameters[new_comp]["outcome_prevalence_name"] = new_comp + "_curr" + parameters[new_comp]["outcome_prevalence_name"] = ( + new_comp + "_curr" + ) if modinf.outcomes_config["param_from_file"].exists(): if modinf.outcomes_config["param_from_file"].get(): rel_probability = branching_data[ @@ -231,14 +282,22 @@ def read_parameters_from_config(modinf: model_info.ModelInfo): & (branching_data["quantity"] == "relative_probability") ].copy(deep=True) if len(rel_probability) > 0: - logging.debug(f"Using 'param_from_file' for relative probability in outcome {new_comp}") + logging.debug( + f"Using 'param_from_file' for relative probability in outcome {new_comp}" + ) # Sort it in case the relative probablity file is mispecified - rel_probability.subpop = rel_probability.subpop.astype("category") - rel_probability.subpop = rel_probability.subpop.cat.set_categories( - modinf.subpop_struct.subpop_names + rel_probability.subpop = rel_probability.subpop.astype( + "category" + ) + rel_probability.subpop = ( + rel_probability.subpop.cat.set_categories( + modinf.subpop_struct.subpop_names + ) ) rel_probability = rel_probability.sort_values(["subpop"]) - parameters[new_comp]["rel_probability"] = rel_probability["value"].to_numpy() + parameters[new_comp]["rel_probability"] = rel_probability[ + "value" + ].to_numpy() else: logging.debug( f"*NOT* Using 'param_from_file' for relative probability in outcome {new_comp}" @@ -348,7 +407,9 @@ def compute_all_multioutcomes( outcome_name=new_comp, ) else: - raise ValueError(f"Unknown type for seir simulation provided, got f{type(seir_sim)}") + raise ValueError( + f"Unknown type for seir simulation provided, got f{type(seir_sim)}" + ) # we don't keep source in this cases else: # already defined outcomes if source_name in all_data: @@ -358,28 +419,40 @@ def compute_all_multioutcomes( f"ERROR with outcome {new_comp}: the specified source {source_name} is not a dictionnary (for seir outcome) nor an existing pre-identified outcomes." 
) - if (loaded_values is not None) and (new_comp in loaded_values["outcome"].values): + if (loaded_values is not None) and ( + new_comp in loaded_values["outcome"].values + ): ## This may be unnecessary probabilities = loaded_values[ - (loaded_values["quantity"] == "probability") & (loaded_values["outcome"] == new_comp) + (loaded_values["quantity"] == "probability") + & (loaded_values["outcome"] == new_comp) + ]["value"].to_numpy() + delays = loaded_values[ + (loaded_values["quantity"] == "delay") + & (loaded_values["outcome"] == new_comp) ]["value"].to_numpy() - delays = loaded_values[(loaded_values["quantity"] == "delay") & (loaded_values["outcome"] == new_comp)][ - "value" - ].to_numpy() else: - probabilities = parameters[new_comp]["probability"].as_random_distribution()( + probabilities = parameters[new_comp][ + "probability" + ].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop if "rel_probability" in parameters[new_comp]: - probabilities = probabilities * parameters[new_comp]["rel_probability"] + probabilities = ( + probabilities * parameters[new_comp]["rel_probability"] + ) delays = parameters[new_comp]["delay"].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop probabilities[probabilities > 1] = 1 probabilities[probabilities < 0] = 0 - probabilities = np.repeat(probabilities[:, np.newaxis], len(dates), axis=1).T # duplicate in time - delays = np.repeat(delays[:, np.newaxis], len(dates), axis=1).T # duplicate in time + probabilities = np.repeat( + probabilities[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time + delays = np.repeat( + delays[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time delays = np.round(delays).astype(int) # Write hpar before NPI subpop_names_len = len(modinf.subpop_struct.subpop_names) @@ -387,7 +460,7 @@ def compute_all_multioutcomes( { "subpop": 2 * modinf.subpop_struct.subpop_names, "quantity": (subpop_names_len * ["probability"]) - + (subpop_names_len * ["delay"]), + + (subpop_names_len * ["delay"]), "outcome": 2 * subpop_names_len * [new_comp], "value": np.concatenate( ( @@ -402,42 +475,61 @@ def compute_all_multioutcomes( if npi is not None: delays = NPI.reduce_parameter( parameter=delays, - modification=npi.getReduction(parameters[new_comp]["delay::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["delay::npi_param_name"].lower() + ), ) delays = np.round(delays).astype(int) probabilities = NPI.reduce_parameter( parameter=probabilities, - modification=npi.getReduction(parameters[new_comp]["probability::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["probability::npi_param_name"].lower() + ), ) # Create new compartment incidence: all_data[new_comp] = np.empty_like(source_array) # Draw with from source compartment if modinf.stoch_traj_flag: - all_data[new_comp] = np.random.binomial(source_array.astype(np.int32), probabilities) + all_data[new_comp] = np.random.binomial( + source_array.astype(np.int32), probabilities + ) else: - all_data[new_comp] = source_array * (probabilities * np.ones_like(source_array)) + all_data[new_comp] = source_array * ( + probabilities * np.ones_like(source_array) + ) # Shift to account for the delay ## stoch_delay_flag is whether to use stochastic delays or not stoch_delay_flag = False - all_data[new_comp] = multishift(all_data[new_comp], delays, stoch_delay_flag=stoch_delay_flag) + all_data[new_comp] = multishift( + all_data[new_comp], delays, 
stoch_delay_flag=stoch_delay_flag + ) # Produce a dataframe an merge it - df_p = dataframe_from_array(all_data[new_comp], modinf.subpop_struct.subpop_names, dates, new_comp) + df_p = dataframe_from_array( + all_data[new_comp], modinf.subpop_struct.subpop_names, dates, new_comp + ) outcomes = pd.merge(outcomes, df_p) # Make duration if "duration" in parameters[new_comp]: - if (loaded_values is not None) and (new_comp in loaded_values["outcome"].values): + if (loaded_values is not None) and ( + new_comp in loaded_values["outcome"].values + ): durations = loaded_values[ - (loaded_values["quantity"] == "duration") & (loaded_values["outcome"] == new_comp) + (loaded_values["quantity"] == "duration") + & (loaded_values["outcome"] == new_comp) ]["value"].to_numpy() else: - durations = parameters[new_comp]["duration"].as_random_distribution()( + durations = parameters[new_comp][ + "duration" + ].as_random_distribution()( size=len(modinf.subpop_struct.subpop_names) ) # one draw per subpop - durations = np.repeat(durations[:, np.newaxis], len(dates), axis=1).T # duplicate in time + durations = np.repeat( + durations[:, np.newaxis], len(dates), axis=1 + ).T # duplicate in time durations = np.round(durations).astype(int) hpar = pd.DataFrame( data={ @@ -458,7 +550,9 @@ def compute_all_multioutcomes( # print(f"{new_comp}-duration".lower(), npi.getReduction(f"{new_comp}-duration".lower())) durations = NPI.reduce_parameter( parameter=durations, - modification=npi.getReduction(parameters[new_comp]["duration::npi_param_name"].lower()), + modification=npi.getReduction( + parameters[new_comp]["duration::npi_param_name"].lower() + ), ) # npi.getReduction(f"{new_comp}::duration".lower())) durations = np.round(durations).astype(int) # plt.imshow(durations) @@ -492,7 +586,9 @@ def compute_all_multioutcomes( for cmp in parameters[new_comp]["sum"]: sum_outcome += all_data[cmp] all_data[new_comp] = sum_outcome - df_p = dataframe_from_array(sum_outcome, modinf.subpop_struct.subpop_names, dates, new_comp) + df_p = dataframe_from_array( + sum_outcome, modinf.subpop_struct.subpop_names, dates, new_comp + ) outcomes = pd.merge(outcomes, df_p) # Concat our hpar dataframes hpar = ( @@ -525,7 +621,9 @@ def filter_seir_df(diffI, dates, subpops, filters, outcome_name) -> np.ndarray: df = df[df[f"mc_{mc_type}"].isin(mc_value)] for mcn in df["mc_name"].unique(): new_df = df[df["mc_name"] == mcn] - new_df = new_df.drop(["date"] + [c for c in new_df.columns if "mc_" in c], axis=1) + new_df = new_df.drop( + ["date"] + [c for c in new_df.columns if "mc_" in c], axis=1 + ) # new_df = new_df.drop("date", axis=1) incidI_arr = incidI_arr + new_df.to_numpy() return incidI_arr @@ -554,7 +652,9 @@ def filter_seir_xr(diffI, dates, subpops, filters, outcome_name) -> np.ndarray: if isinstance(mc_value, str): mc_value = [mc_value] # Filter data along the specified mc_type dimension - diffI_filtered = diffI_filtered.where(diffI_filtered[f"mc_{mc_type}"].isin(mc_value), drop=True) + diffI_filtered = diffI_filtered.where( + diffI_filtered[f"mc_{mc_type}"].isin(mc_value), drop=True + ) # Sum along the compartment dimension incidI_arr += diffI_filtered[vtype].sum(dim="compartment") @@ -626,7 +726,9 @@ def multishift(arr, shifts, stoch_delay_flag=True): # for k,case in enumerate(cases): # results[i+k][j] = cases[k] else: - for i in range(arr.shape[0]): # numba nopython does not allow iterating over 2D array + for i in range( + arr.shape[0] + ): # numba nopython does not allow iterating over 2D array for j in range(arr.shape[1]): if i + shifts[i, j] 
< arr.shape[0]: result[i + shifts[i, j], j] += arr[i, j] diff --git a/flepimop/gempyor_pkg/src/gempyor/parameters.py b/flepimop/gempyor_pkg/src/gempyor/parameters.py index b01fb52ab..b093322c9 100644 --- a/flepimop/gempyor_pkg/src/gempyor/parameters.py +++ b/flepimop/gempyor_pkg/src/gempyor/parameters.py @@ -41,12 +41,18 @@ def __init__( self.pdata = {} self.pnames2pindex = {} - self.stacked_modifier_method = {"sum": [], "product": [], "reduction_product": []} + self.stacked_modifier_method = { + "sum": [], + "product": [], + "reduction_product": [], + } self.pnames = self.pconfig.keys() self.npar = len(self.pnames) if self.npar != len(set([name.lower() for name in self.pnames])): - raise ValueError("Parameters of the SEIR model have the same name (remember that case is not sufficient!)") + raise ValueError( + "Parameters of the SEIR model have the same name (remember that case is not sufficient!)" + ) # Attributes of dictionary for idx, pn in enumerate(self.pnames): @@ -56,19 +62,29 @@ def __init__( # Parameter characterized by it's distribution if self.pconfig[pn]["value"].exists(): - self.pdata[pn]["dist"] = self.pconfig[pn]["value"].as_random_distribution() + self.pdata[pn]["dist"] = self.pconfig[pn][ + "value" + ].as_random_distribution() # Parameter given as a file elif self.pconfig[pn]["timeseries"].exists(): - fn_name = os.path.join(path_prefix, self.pconfig[pn]["timeseries"].get()) + fn_name = os.path.join( + path_prefix, self.pconfig[pn]["timeseries"].get() + ) df = utils.read_df(fn_name).set_index("date") df.index = pd.to_datetime(df.index) - if len(df.columns) == 1: # if only one ts, assume it applies to all subpops + if ( + len(df.columns) == 1 + ): # if only one ts, assume it applies to all subpops df = pd.DataFrame( - pd.concat([df] * len(subpop_names), axis=1).values, index=df.index, columns=subpop_names + pd.concat([df] * len(subpop_names), axis=1).values, + index=df.index, + columns=subpop_names, ) elif len(df.columns) >= len(subpop_names): # one ts per subpop - df = df[subpop_names] # make sure the order of subpops is the same as the reference + df = df[ + subpop_names + ] # make sure the order of subpops is the same as the reference # (subpop_names from spatial setup) and select the columns else: print("loaded col :", sorted(list(df.columns))) @@ -102,15 +118,23 @@ def __init__( self.pdata[pn]["ts"] = df if self.pconfig[pn]["stacked_modifier_method"].exists(): - self.pdata[pn]["stacked_modifier_method"] = self.pconfig[pn]["stacked_modifier_method"].as_str() + self.pdata[pn]["stacked_modifier_method"] = self.pconfig[pn][ + "stacked_modifier_method" + ].as_str() else: self.pdata[pn]["stacked_modifier_method"] = "product" - logging.debug(f"No 'stacked_modifier_method' for parameter {pn}, assuming multiplicative NPIs") + logging.debug( + f"No 'stacked_modifier_method' for parameter {pn}, assuming multiplicative NPIs" + ) if self.pconfig[pn]["rolling_mean_windows"].exists(): - self.pdata[pn]["rolling_mean_windows"] = self.pconfig[pn]["rolling_mean_windows"].get() + self.pdata[pn]["rolling_mean_windows"] = self.pconfig[pn][ + "rolling_mean_windows" + ].get() - self.stacked_modifier_method[self.pdata[pn]["stacked_modifier_method"]].append(pn.lower()) + self.stacked_modifier_method[ + self.pdata[pn]["stacked_modifier_method"] + ].append(pn.lower()) logging.debug(f"We have {self.npar} parameter: {self.pnames}") logging.debug(f"Data to sample is: {self.pdata}") @@ -146,7 +170,9 @@ def parameters_quick_draw(self, n_days: int, nsubpops: int) -> ndarray: return param_arr # we don't 
store it as a member because this object needs to be small to be pickable - def parameters_load(self, param_df: pd.DataFrame, n_days: int, nsubpops: int) -> ndarray: + def parameters_load( + self, param_df: pd.DataFrame, n_days: int, nsubpops: int + ) -> ndarray: """ drop-in equivalent to param_quick_draw() that take a file as written parameter_write() :param fname: @@ -165,7 +191,9 @@ def parameters_load(self, param_df: pd.DataFrame, n_days: int, nsubpops: int) -> elif "ts" in self.pdata[pn]: param_arr[idx] = self.pdata[pn]["ts"].values else: - print(f"PARAM: parameter {pn} NOT found in loadID file. Drawing from config distribution") + print( + f"PARAM: parameter {pn} NOT found in loadID file. Drawing from config distribution" + ) pval = self.pdata[pn]["dist"]() param_arr[idx] = np.full((n_days, nsubpops), pval) @@ -179,9 +207,15 @@ def getParameterDF(self, p_draw: ndarray) -> pd.DataFrame: """ # we don't write to disk time series parameters. out_df = pd.DataFrame( - [p_draw[idx, 0, 0] for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn]], + [ + p_draw[idx, 0, 0] + for idx, pn in enumerate(self.pnames) + if "dist" in self.pdata[pn] + ], columns=["value"], - index=[pn for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn]], + index=[ + pn for idx, pn in enumerate(self.pnames) if "dist" in self.pdata[pn] + ], ) out_df["parameter"] = out_df.index @@ -204,6 +238,8 @@ def parameters_reduce(self, p_draw: ndarray, npi: object) -> ndarray: ) p_reduced[idx] = npi_val if "rolling_mean_windows" in self.pdata[pn]: - p_reduced[idx] = utils.rolling_mean_pad(data=npi_val, window=self.pdata[pn]["rolling_mean_windows"]) + p_reduced[idx] = utils.rolling_mean_pad( + data=npi_val, window=self.pdata[pn]["rolling_mean_windows"] + ) return p_reduced diff --git a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py index 2b42e2944..fa5081c0d 100644 --- a/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py +++ b/flepimop/gempyor_pkg/src/gempyor/postprocess_inference.py @@ -21,7 +21,15 @@ import pandas as pd import pyarrow.parquet as pq import xarray as xr -from gempyor import config, model_info, outcomes, seir, inference_parameter, logloss, inference +from gempyor import ( + config, + model_info, + outcomes, + seir, + inference_parameter, + logloss, + inference, +) from gempyor.inference import GempyorInference import tqdm import os @@ -37,7 +45,9 @@ def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin): last_llik = sampler_output.get_log_prob()[-1, :] sampled_slots = last_llik > (last_llik.mean() - 1 * last_llik.std()) - print(f"there are {sampled_slots.sum()}/{len(sampled_slots)} good walkers... keeping these") + print( + f"there are {sampled_slots.sum()}/{len(sampled_slots)} good walkers... 
keeping these" + ) # TODO this function give back good_samples = sampler.get_chain()[:, sampled_slots, :] @@ -46,13 +56,15 @@ def find_walkers_to_sample(inferpar, sampler_output, nsamples, nwalker, nthin): exported_samples = np.empty((nsamples, inferpar.get_dim())) for i in range(nsamples): exported_samples[i, :] = good_samples[ - step_number - thin * (i // (sampled_slots.sum())), i % (sampled_slots.sum()), : + step_number - thin * (i // (sampled_slots.sum())), + i % (sampled_slots.sum()), + :, ] # parentesis around i//(sampled_slots.sum() are very important - - -def plot_chains(inferpar, chains, llik, save_to, sampled_slots=None, param_gt=None, llik_gt=None): +def plot_chains( + inferpar, chains, llik, save_to, sampled_slots=None, param_gt=None, llik_gt=None +): """ Plot the chains of the inference :param inferpar: the inference parameter object @@ -113,24 +125,47 @@ def plot_single_chain(frompt, ax, chain, label, gt=None): for sp in tqdm.tqdm(set(inferpar.subpops)): # find unique supopulation these_pars = inferpar.get_parameters_for_subpop(sp) - fig, axes = plt.subplots(max(len(these_pars), 2), 2, figsize=(6, (len(these_pars) + 1) * 2)) + fig, axes = plt.subplots( + max(len(these_pars), 2), 2, figsize=(6, (len(these_pars) + 1) * 2) + ) for idx, par_id in enumerate(these_pars): - plot_single_chain(first_thresh, axes[idx, 0], chains[:, :, par_id], labels[par_id], gt=param_gt[par_id] if param_gt is not None else None) - plot_single_chain(second_thresh, axes[idx, 1], chains[:, :, par_id], labels[par_id], gt=param_gt[par_id] if param_gt is not None else None) + plot_single_chain( + first_thresh, + axes[idx, 0], + chains[:, :, par_id], + labels[par_id], + gt=param_gt[par_id] if param_gt is not None else None, + ) + plot_single_chain( + second_thresh, + axes[idx, 1], + chains[:, :, par_id], + labels[par_id], + gt=param_gt[par_id] if param_gt is not None else None, + ) fig.tight_layout() pdf.savefig(fig) plt.close(fig) + def plot_fit(modinf, loss): subpop_names = modinf.subpop_struct.subpop_names fig, axes = plt.subplots( - len(subpop_names), len(loss.statistics), figsize=(3 * len(loss.statistics), 3 * len(subpop_names)), sharex=True + len(subpop_names), + len(loss.statistics), + figsize=(3 * len(loss.statistics), 3 * len(subpop_names)), + sharex=True, ) for j, subpop in enumerate(modinf.subpop_struct.subpop_names): gt_s = loss.gt[loss.gt["subpop"] == subpop].sort_index() first_date = max(gt_s.index.min(), results[0].index.min()) last_date = min(gt_s.index.max(), results[0].index.max()) - gt_s = gt_s.loc[first_date:last_date].drop(["subpop"], axis=1).resample("W-SAT").sum() + gt_s = ( + gt_s.loc[first_date:last_date] + .drop(["subpop"], axis=1) + .resample("W-SAT") + .sum() + ) for i, (stat_name, stat) in enumerate(loss.statistics.items()): ax = axes[j, i] diff --git a/flepimop/gempyor_pkg/src/gempyor/seeding.py b/flepimop/gempyor_pkg/src/gempyor/seeding.py index fe58657c0..f49114e5e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seeding.py +++ b/flepimop/gempyor_pkg/src/gempyor/seeding.py @@ -17,9 +17,13 @@ def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict: if not df["date"].is_monotonic_increasing: - raise ValueError("_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense") + raise ValueError( + "_DataFrame2NumbaDict got an unsorted dataframe, exposing itself to non-sense" + ) - cmp_grp_names = [col for col in modinf.compartments.compartments.columns if col != "name"] + cmp_grp_names = [ + col for col in modinf.compartments.compartments.columns if col != 
"name" + ] seeding_dict: nb.typed.Dict = nb.typed.Dict.empty( key_type=nb.types.unicode_type, value_type=nb.types.int64[:], @@ -45,16 +49,26 @@ def _DataFrame2NumbaDict(df, amounts, modinf) -> nb.typed.Dict: nb_seed_perday[(row["date"].date() - modinf.ti).days] = ( nb_seed_perday[(row["date"].date() - modinf.ti).days] + 1 ) - source_dict = {grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names} - destination_dict = {grp_name: row[f"destination_{grp_name}"] for grp_name in cmp_grp_names} + source_dict = { + grp_name: row[f"source_{grp_name}"] for grp_name in cmp_grp_names + } + destination_dict = { + grp_name: row[f"destination_{grp_name}"] + for grp_name in cmp_grp_names + } seeding_dict["seeding_sources"][idx] = modinf.compartments.get_comp_idx( - source_dict, error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)" + source_dict, + error_info=f"(seeding source at idx={idx}, row_index={row_index}, row=>>{row}<<)", + ) + seeding_dict["seeding_destinations"][idx] = ( + modinf.compartments.get_comp_idx( + destination_dict, + error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)", + ) ) - seeding_dict["seeding_destinations"][idx] = modinf.compartments.get_comp_idx( - destination_dict, - error_info=f"(seeding destination at idx={idx}, row_index={row_index}, row=>>{row}<<)", + seeding_dict["seeding_subpops"][idx] = ( + modinf.subpop_struct.subpop_names.index(row["subpop"]) ) - seeding_dict["seeding_subpops"][idx] = modinf.subpop_struct.subpop_names.index(row["subpop"]) seeding_amounts[idx] = amounts[idx] # id_seed+=1 else: @@ -97,7 +111,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: ) dupes = seeding[seeding.duplicated(["subpop", "date"])].index + 1 if not dupes.empty: - raise ValueError(f"Repeated subpop-date in rows {dupes.tolist()} of seeding::lambda_file.") + raise ValueError( + f"Repeated subpop-date in rows {dupes.tolist()} of seeding::lambda_file." + ) elif method == "FolderDraw": seeding = pd.read_csv( self.path_prefix @@ -127,7 +143,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: # print(seeding.shape) seeding = seeding.sort_values(by="date", axis="index").reset_index() # print(seeding) - mask = (seeding["date"].dt.date > modinf.ti) & (seeding["date"].dt.date <= modinf.tf) + mask = (seeding["date"].dt.date > modinf.ti) & ( + seeding["date"].dt.date <= modinf.tf + ) seeding = seeding.loc[mask].reset_index() # print(seeding.shape) # print(seeding) @@ -138,7 +156,9 @@ def get_from_config(self, sim_id: int, modinf) -> nb.typed.Dict: if method == "PoissonDistributed": amounts = np.random.poisson(seeding["amount"]) elif method == "NegativeBinomialDistributed": - raise ValueError("Seeding method 'NegativeBinomialDistributed' is not supported by flepiMoP anymore.") + raise ValueError( + "Seeding method 'NegativeBinomialDistributed' is not supported by flepiMoP anymore." 
+ ) elif method == "FolderDraw" or method == "FromFile": amounts = seeding["amount"] else: diff --git a/flepimop/gempyor_pkg/src/gempyor/seir.py b/flepimop/gempyor_pkg/src/gempyor/seir.py index 4e59761f2..d374bed25 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seir.py +++ b/flepimop/gempyor_pkg/src/gempyor/seir.py @@ -41,7 +41,9 @@ def build_step_source_arg( else: integration_method = "rk4.jit" dt = 2.0 - logging.info(f"Integration method not provided, assuming type {integration_method} with dt=2") + logging.info( + f"Integration method not provided, assuming type {integration_method} with dt=2" + ) ## The type is very important for the call to the compiled function, and e.g mixing an int64 for an int32 can ## result in serious error. Note that "In Microsoft C, even on a 64 bit system, the size of the long int data type @@ -58,7 +60,10 @@ def build_step_source_arg( assert type(transition_array[0][0]) == np.int64 assert type(proportion_array[0]) == np.int64 assert type(proportion_info[0][0]) == np.int64 - assert initial_conditions.shape == (modinf.compartments.compartments.shape[0], modinf.nsubpops) + assert initial_conditions.shape == ( + modinf.compartments.compartments.shape[0], + modinf.nsubpops, + ) assert type(initial_conditions[0][0]) == np.float64 # Test of empty seeding: assert len(seeding_data.keys()) == 4 @@ -150,7 +155,9 @@ def steps_SEIR( else: from .dev import steps as steps_experimental - logging.critical("Experimental !!! These methods are not ready for production ! ") + logging.critical( + "Experimental !!! These methods are not ready for production ! " + ) if integration_method in [ "scipy.solve_ivp", "scipy.odeint", @@ -162,7 +169,9 @@ def steps_SEIR( f"with method {integration_method}, only deterministic " f"integration is possible (got stoch_straj_flag={modinf.stoch_traj_flag}" ) - seir_sim = steps_experimental.ode_integration(**fnct_args, integration_method=integration_method) + seir_sim = steps_experimental.ode_integration( + **fnct_args, integration_method=integration_method + ) elif integration_method == "rk4.jit1": seir_sim = steps_experimental.rk4_integration1(**fnct_args) elif integration_method == "rk4.jit2": @@ -200,13 +209,17 @@ def steps_SEIR( **compartment_coords, subpop=modinf.subpop_struct.subpop_names, ), - attrs=dict(description="Dynamical simulation results", run_id=modinf.in_run_id), # TODO add more information + attrs=dict( + description="Dynamical simulation results", run_id=modinf.in_run_id + ), # TODO add more information ) return states -def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None): +def build_npi_SEIR( + modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_FN=None +): with Timer("SEIR.NPI"): loaded_df = None if bypass_DF is not None: @@ -223,8 +236,12 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_ modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, loaded_df=loaded_df, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) else: npi = NPI.NPIBase.execute( @@ -232,8 +249,12 @@ def build_npi_SEIR(modinf, load_ID, sim_id2load, config, bypass_DF=None, bypass_ modinf=modinf, 
modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) return npi @@ -248,7 +269,9 @@ def onerun_SEIR( np.random.seed() npi = None if modinf.npi_config_seir: - npi = build_npi_SEIR(modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config) + npi = build_npi_SEIR( + modinf=modinf, load_ID=load_ID, sim_id2load=sim_id2load, config=config + ) with Timer("onerun_SEIR.compartments"): ( @@ -260,11 +283,19 @@ def onerun_SEIR( with Timer("onerun_SEIR.seeding"): if load_ID: - initial_conditions = modinf.initial_conditions.get_from_file(sim_id2load, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id2load, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_file( + sim_id2load, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id2load, modinf=modinf + ) else: - initial_conditions = modinf.initial_conditions.get_from_config(sim_id2write, modinf=modinf) - seeding_data, seeding_amounts = modinf.seeding.get_from_config(sim_id2write, modinf=modinf) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id2write, modinf=modinf + ) + seeding_data, seeding_amounts = modinf.seeding.get_from_config( + sim_id2write, modinf=modinf + ) with Timer("onerun_SEIR.parameters"): # Draw or load parameters @@ -275,14 +306,18 @@ def onerun_SEIR( nsubpops=modinf.nsubpops, ) else: - p_draw = modinf.parameters.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_draw = modinf.parameters.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) # reduce them parameters = modinf.parameters.parameters_reduce(p_draw, npi) log_debug_parameters(p_draw, "Parameters without seir_modifiers") log_debug_parameters(parameters, "Parameters with seir_modifiers") # Parse them - parsed_parameters = modinf.compartments.parse_parameters(parameters, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + parameters, modinf.parameters.pnames, unique_strings + ) log_debug_parameters(parsed_parameters, "Unique Parameters used by transitions") with Timer("onerun_SEIR.compute"): @@ -310,7 +345,13 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1): if n_jobs == 1: # run single process for debugging/profiling purposes for sim_id in tqdm.tqdm(sim_ids): - onerun_SEIR(sim_id2write=sim_id, modinf=modinf, load_ID=False, sim_id2load=None, config=config) + onerun_SEIR( + sim_id2write=sim_id, + modinf=modinf, + load_ID=False, + sim_id2load=None, + config=config, + ) else: tqdm.contrib.concurrent.process_map( onerun_SEIR, @@ -322,7 +363,9 @@ def run_parallel_SEIR(modinf, config, *, n_jobs=1): max_workers=n_jobs, ) - logging.info(f""">> {modinf.nslots} seir simulations completed in {time.monotonic() - start:.1f} seconds""") + logging.info( + f""">> {modinf.nslots} seir simulations completed in {time.monotonic() - start:.1f} seconds""" + ) def states2Df(modinf, states): @@ -337,12 +380,17 @@ def states2Df(modinf, states): # states_diff = np.diff(states_diff, axis=0) ts_index = pd.MultiIndex.from_product( - [pd.date_range(modinf.ti, 
modinf.tf, freq="D"), modinf.compartments.compartments["name"]], + [ + pd.date_range(modinf.ti, modinf.tf, freq="D"), + modinf.compartments.compartments["name"], + ], names=["date", "mc_name"], ) # prevalence data, we use multi.index dataframe, sparring us the array manipulation we use to do prev_df = pd.DataFrame( - data=states["prevalence"].to_numpy().reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), + data=states["prevalence"] + .to_numpy() + .reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), index=ts_index, columns=modinf.subpop_struct.subpop_names, ).reset_index() @@ -355,12 +403,17 @@ def states2Df(modinf, states): prev_df.insert(loc=0, column="mc_value_type", value="prevalence") ts_index = pd.MultiIndex.from_product( - [pd.date_range(modinf.ti, modinf.tf, freq="D"), modinf.compartments.compartments["name"]], + [ + pd.date_range(modinf.ti, modinf.tf, freq="D"), + modinf.compartments.compartments["name"], + ], names=["date", "mc_name"], ) incid_df = pd.DataFrame( - data=states["incidence"].to_numpy().reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), + data=states["incidence"] + .to_numpy() + .reshape(modinf.n_days * modinf.compartments.get_ncomp(), modinf.nsubpops), index=ts_index, columns=modinf.subpop_struct.subpop_names, ).reset_index() @@ -384,7 +437,9 @@ def write_spar_snpi(sim_id, modinf, p_draw, npi): if npi is not None: modinf.write_simID(ftype="snpi", sim_id=sim_id, df=npi.getReductionDF()) # Parameters - modinf.write_simID(ftype="spar", sim_id=sim_id, df=modinf.parameters.getParameterDF(p_draw=p_draw)) + modinf.write_simID( + ftype="spar", sim_id=sim_id, df=modinf.parameters.getParameterDF(p_draw=p_draw) + ) def write_seir(sim_id, modinf, states): diff --git a/flepimop/gempyor_pkg/src/gempyor/simulate.py b/flepimop/gempyor_pkg/src/gempyor/simulate.py index 34fdd9d4b..d97b94a93 100644 --- a/flepimop/gempyor_pkg/src/gempyor/simulate.py +++ b/flepimop/gempyor_pkg/src/gempyor/simulate.py @@ -299,23 +299,31 @@ def simulate( seir_modifiers_scenarios = None if config["seir_modifiers"].exists(): if config["seir_modifiers"]["scenarios"].exists(): - seir_modifiers_scenarios = config["seir_modifiers"]["scenarios"].as_str_seq() + seir_modifiers_scenarios = config["seir_modifiers"][ + "scenarios" + ].as_str_seq() # Model Info handles the case of the default scneario if not outcome_modifiers_scenarios: outcome_modifiers_scenarios = None if config["outcomes"].exists() and config["outcome_modifiers"].exists(): if config["outcome_modifiers"]["scenarios"].exists(): - outcome_modifiers_scenarios = config["outcome_modifiers"]["scenarios"].as_str_seq() + outcome_modifiers_scenarios = config["outcome_modifiers"][ + "scenarios" + ].as_str_seq() outcome_modifiers_scenarios = as_list(outcome_modifiers_scenarios) seir_modifiers_scenarios = as_list(seir_modifiers_scenarios) print(outcome_modifiers_scenarios, seir_modifiers_scenarios) - scenarios_combinations = [[s, d] for s in seir_modifiers_scenarios for d in outcome_modifiers_scenarios] + scenarios_combinations = [ + [s, d] for s in seir_modifiers_scenarios for d in outcome_modifiers_scenarios + ] print("Combination of modifiers scenarios to be run: ") print(scenarios_combinations) for seir_modifiers_scenario, outcome_modifiers_scenario in scenarios_combinations: - print(f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier:{outcome_modifiers_scenario}") + print( + f"seir_modifier: {seir_modifiers_scenario}, outcomes_modifier:{outcome_modifiers_scenario}" + ) if not 
nslots: nslots = config["nslots"].as_number() @@ -354,7 +362,9 @@ def simulate( if config["seir"].exists(): seir.run_parallel_SEIR(modinf, config=config, n_jobs=jobs) if config["outcomes"].exists(): - outcomes.run_parallel_outcomes(sim_id2write=first_sim_index, modinf=modinf, nslots=nslots, n_jobs=jobs) + outcomes.run_parallel_outcomes( + sim_id2write=first_sim_index, modinf=modinf, nslots=nslots, n_jobs=jobs + ) print( f">>> {seir_modifiers_scenario}_{outcome_modifiers_scenario} completed in {time.monotonic() - start:.1f} seconds" ) diff --git a/flepimop/gempyor_pkg/src/gempyor/statistics.py b/flepimop/gempyor_pkg/src/gempyor/statistics.py index ea2cc72a3..a9b8ce8a6 100644 --- a/flepimop/gempyor_pkg/src/gempyor/statistics.py +++ b/flepimop/gempyor_pkg/src/gempyor/statistics.py @@ -48,7 +48,10 @@ def __init__(self, name, statistic_config: confuse.ConfigView): if resample_config["aggregator"].exists(): self.resample_aggregator_name = resample_config["aggregator"].get() self.resample_skipna = False # TODO - if resample_config["aggregator"].exists() and resample_config["skipna"].exists(): + if ( + resample_config["aggregator"].exists() + and resample_config["skipna"].exists() + ): self.resample_skipna = resample_config["skipna"].get() self.scale = False @@ -71,7 +74,10 @@ def _forecast_regularize(self, model_data, gt_data, **kwargs): last_n = kwargs.get("last_n", 4) mult = kwargs.get("mult", 2) - last_n_llik = self.llik(model_data.isel(date=slice(-last_n, None)), gt_data.isel(date=slice(-last_n, None))) + last_n_llik = self.llik( + model_data.isel(date=slice(-last_n, None)), + gt_data.isel(date=slice(-last_n, None)), + ) return mult * last_n_llik.sum().sum().values @@ -89,7 +95,9 @@ def __repr__(self) -> str: def apply_resample(self, data): if self.resample: - aggregator_method = getattr(data.resample(date=self.resample_freq), self.resample_aggregator_name) + aggregator_method = getattr( + data.resample(date=self.resample_freq), self.resample_aggregator_name + ) return aggregator_method(skipna=self.resample_skipna) else: return data @@ -113,7 +121,9 @@ def llik(self, model_data: xr.DataArray, gt_data: xr.DataArray): "norm_cov": lambda x, loc, scale: scipy.stats.norm.logpdf( x, loc=loc, scale=scale * loc.where(loc > 5, 5) ), # TODO: check, that it's really the loc - "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf(x, n=self.params.get("n"), p=model_data), + "nbinom": lambda x, n, p: scipy.stats.nbinom.logpmf( + x, n=self.params.get("n"), p=model_data + ), "rmse": lambda x, y: -np.log(np.nansum(np.sqrt((x - y) ** 2))), "absolute_error": lambda x, y: -np.log(np.nansum(np.abs(x - y))), } @@ -147,6 +157,8 @@ def compute_logloss(self, model_data, gt_data): regularization = 0 for reg_func, reg_config in self.regularizations: - regularization += reg_func(model_data=model_data, gt_data=gt_data, **reg_config) # Pass config parameters + regularization += reg_func( + model_data=model_data, gt_data=gt_data, **reg_config + ) # Pass config parameters return self.llik(model_data, gt_data).sum("date"), regularization diff --git a/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py b/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py index 2cb09c99d..74813be09 100644 --- a/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py +++ b/flepimop/gempyor_pkg/src/gempyor/steps_rk4.py @@ -55,7 +55,11 @@ def rk4_integration( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): proportion_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + 
mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -63,11 +67,19 @@ def rk4_integration( @jit(nopython=True) def rhs(t, x, today): states_current = np.reshape(x, (2, ncompartments, nspatial_nodes))[0] - st_next = states_current.copy() # this is used to make sure stochastic integration never goes below zero - transition_amounts = np.zeros((ntransitions, nspatial_nodes)) # keep track of the transitions + st_next = ( + states_current.copy() + ) # this is used to make sure stochastic integration never goes below zero + transition_amounts = np.zeros( + (ntransitions, nspatial_nodes) + ) # keep track of the transitions if (x < 0).any(): - print("Integration error: rhs got a negative x (pos, time)", np.where(x < 0), t) + print( + "Integration error: rhs got a negative x (pos, time)", + np.where(x < 0), + t, + ) for transition_index in range(ntransitions): total_rate = np.ones((nspatial_nodes)) @@ -85,56 +97,76 @@ def rhs(t, x, today): proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] # chadi: i believe what this mean that the first proportion is always the # source compartment. That's why there is nothing with n_spatial node here. # but (TODO) we should enforce that ? if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp # does this mean we need the first to be "source" ??? yes ! 
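For intuition, here is a minimal numpy sketch of the pattern the surrounding block applies; the numbers and the single-transition setup are made up for illustration and are not taken from this patch. The source compartment contributes source ** exponent / source to the per-node rate, and the legacy branch further on converts that per-day rate into a flow over one step dt via 1 - exp(-dt * rate).

    # Illustrative sketch (hypothetical values): turning a per-node transition rate
    # into a number of individuals moved, in the spirit of the deterministic legacy branch.
    import numpy as np

    dt = 0.25                                        # integration step, assumed
    source_number = np.array([100.0, 0.0, 50.0])     # individuals in the source compartment per node
    rate_param = np.array([0.3, 0.3, 0.3])           # per-day rate parameter, assumed
    exponent = np.array([1.0, 1.0, 1.0])

    total_rate = np.ones_like(source_number)
    mask = source_number > 0
    # source ** exponent / source, applied only where the source compartment is non-empty
    total_rate[mask] *= source_number[mask] ** exponent[mask] / source_number[mask]
    total_rate *= rate_param

    # probability of leaving within dt, then the expected (deterministic) flow
    compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate)
    number_move = source_number * compound_adjusted_rate
    print(number_move)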
if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * proportion_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * proportion_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_subpop = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = proportion_change_compartment * ( - relevant_number_in_comp[visiting_subpop] ** relevant_exponent[visiting_subpop] + relevant_number_in_comp[visiting_subpop] + ** relevant_exponent[visiting_subpop] ) rate_change_compartment /= population[visiting_subpop] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_subpop] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_subpop] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) # compute the number of individual transitioning from source to destination from the total rate # number_move has shape (nspatial_nodes) @@ -143,7 +175,9 @@ def rhs(t, x, today): elif method == "legacy": compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) if stochastic_p: - number_move = source_number * compound_adjusted_rate ## to initialize typ + number_move = ( + source_number * compound_adjusted_rate + ) ## to initialize typ for spatial_node in range(nspatial_nodes): number_move[spatial_node] = np.random.binomial( # number_move[spatial_node] = random.binomial( @@ -162,7 +196,9 @@ def rhs(t, x, today): @jit(nopython=True) def update_states(states, delta_t, transition_amounts): - states_diff = np.zeros((2, ncompartments, nspatial_nodes)) # first dim: 0 -> states_diff, 1: states_cum + states_diff = np.zeros( + (2, ncompartments, nspatial_nodes) + ) # first dim: 0 -> states_diff, 1: states_cum st_next = states.copy() st_next = np.reshape(st_next, (2, ncompartments, nspatial_nodes)) if method == "rk4": @@ -181,23 +217,34 @@ def update_states(states, delta_t, transition_amounts): ) if ( transition_amounts[transition_index][spatial_node] - >= st_next[0][transitions[transition_source_col][transition_index]][spatial_node] - float_tolerance + >= 
st_next[0][transitions[transition_source_col][transition_index]][ + spatial_node + ] + - float_tolerance ): transition_amounts[transition_index][spatial_node] = max( - st_next[0][transitions[transition_source_col][transition_index]][spatial_node] + st_next[0][ + transitions[transition_source_col][transition_index] + ][spatial_node] - float_tolerance, 0, ) - st_next[0][transitions[transition_source_col][transition_index]] -= transition_amounts[transition_index] - st_next[0][transitions[transition_destination_col][transition_index]] += transition_amounts[ - transition_index - ] - - states_diff[0, transitions[transition_source_col][transition_index]] -= transition_amounts[transition_index] - states_diff[0, transitions[transition_destination_col][transition_index]] += transition_amounts[ - transition_index - ] - states_diff[1, transitions[transition_destination_col][transition_index], :] += transition_amounts[ + st_next[0][ + transitions[transition_source_col][transition_index] + ] -= transition_amounts[transition_index] + st_next[0][ + transitions[transition_destination_col][transition_index] + ] += transition_amounts[transition_index] + + states_diff[ + 0, transitions[transition_source_col][transition_index] + ] -= transition_amounts[transition_index] + states_diff[ + 0, transitions[transition_destination_col][transition_index] + ] += transition_amounts[transition_index] + states_diff[ + 1, transitions[transition_destination_col][transition_index], : + ] += transition_amounts[ transition_index ] # Cumumlative @@ -214,7 +261,9 @@ def rk4_integrate(t, x, today): yesterday = -1 times = np.arange(0, (ndays - 1) + 1e-7, dt) - for time_index, time in tqdm.tqdm(enumerate(times), disable=silent): # , total=len(times) + for time_index, time in tqdm.tqdm( + enumerate(times), disable=silent + ): # , total=len(times) today = int(np.floor(time)) is_a_new_day = today != yesterday yesterday = today @@ -224,21 +273,31 @@ def rk4_integrate(t, x, today): states[today, :, :] = states_next for seeding_instance_idx in range( seeding_data["day_start_idx"][today], - seeding_data["day_start_idx"][min(today + int(np.ceil(dt)), len(seeding_data["day_start_idx"]) - 1)], + seeding_data["day_start_idx"][ + min( + today + int(np.ceil(dt)), len(seeding_data["day_start_idx"]) - 1 + ) + ], ): this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts x_ = np.zeros((2, ncompartments, nspatial_nodes)) x_[0] = states_next @@ -271,13 +330,19 @@ def rk4_integrate(t, x, today): error = False ## Perform some checks: if np.isnan(states_daily_incid).any() or np.isnan(states).any(): - logging.critical("Integration error: NaN detected in epidemic integration result. Failing...") + logging.critical( + "Integration error: NaN detected in epidemic integration result. Failing..." + ) error = True if not (np.isfinite(states_daily_incid).all() and np.isfinite(states).all()): - logging.critical("Integration error: Inf detected in epidemic integration result. Failing...") + logging.critical( + "Integration error: Inf detected in epidemic integration result. Failing..." + ) error = True if (states_daily_incid < 0).any() or (states < 0).any(): - logging.critical("Integration error: negative values detected in epidemic integration result. Failing...") + logging.critical( + "Integration error: negative values detected in epidemic integration result. Failing..." 
+ ) # todo: this, but smart so it doesn't fail if empty array # print( # f"STATES: NNZ:{states[states < 0].size}/{states.size}, max:{np.max(states[states < 0])}, min:{np.min(states[states < 0])}, mean:{np.mean(states[states < 0])} median:{np.median(states[states < 0])}" @@ -318,6 +383,8 @@ def rk4_integrate(t, x, today): print( "load the name space with: \nwith open('integration_dump.pkl','rb') as fn_dump:\n states, states_daily_incid, ncompartments, nspatial_nodes, ndays, parameters, dt, transitions, proportion_info, transition_sum_compartments, initial_conditions, seeding_data, seeding_amounts, mobility_data, mobility_row_indices, mobility_data_indices, population, stochastic_p, method = pickle.load(fn_dump)" ) - print("/!\\ Invalid integration, will cause problems for downstream users /!\\ ") + print( + "/!\\ Invalid integration, will cause problems for downstream users /!\\ " + ) # raise ValueError("Invalid Integration...") return states, states_daily_incid diff --git a/flepimop/gempyor_pkg/src/gempyor/steps_source.py b/flepimop/gempyor_pkg/src/gempyor/steps_source.py index b8af1d493..ba46badc7 100644 --- a/flepimop/gempyor_pkg/src/gempyor/steps_source.py +++ b/flepimop/gempyor_pkg/src/gempyor/steps_source.py @@ -30,9 +30,12 @@ ## Return "UniTuple(float64[:, :, :], 2) (" ## return states and cumlative states, both [ ndays x ncompartments x nspatial_nodes ] ## Dimensions - "int32," "int32," "int32," ## ncompartments ## nspatial_nodes ## Number of days + "int32," + "int32," + "int32," ## ncompartments ## nspatial_nodes ## Number of days ## Parameters - "float64[:, :, :]," "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt + "float64[:, :, :]," + "float64," ## Parameters [ nparameters x ndays x nspatial_nodes] ## dt ## Transitions "int64[:, :]," ## transitions [ [source, destination, proportion_start, proportion_stop, rate] x ntransitions ] "int64[:, :]," ## proportions_info [ [sum_starts, sum_stops, exponent] x ntransition_proportions ] @@ -84,7 +87,11 @@ def steps_SEIR_nb( percent_day_away = 0.5 for spatial_node in range(nspatial_nodes): percent_who_move[spatial_node] = min( - mobility_data[mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1]].sum() + mobility_data[ + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] + ].sum() / population[spatial_node], 1, ) @@ -111,18 +118,24 @@ def steps_SEIR_nb( this_seeding_amounts = seeding_amounts[seeding_instance_idx] seeding_subpops = seeding_data["seeding_subpops"][seeding_instance_idx] seeding_sources = seeding_data["seeding_sources"][seeding_instance_idx] - seeding_destinations = seeding_data["seeding_destinations"][seeding_instance_idx] + seeding_destinations = seeding_data["seeding_destinations"][ + seeding_instance_idx + ] # this_seeding_amounts = this_seeding_amounts < states_next[seeding_sources] ? 
this_seeding_amounts : states_next[seeding_instance_idx] states_next[seeding_sources][seeding_subpops] -= this_seeding_amounts - states_next[seeding_sources][seeding_subpops] = states_next[seeding_sources][seeding_subpops] * ( - states_next[seeding_sources][seeding_subpops] > 0 - ) - states_next[seeding_destinations][seeding_subpops] += this_seeding_amounts + states_next[seeding_sources][seeding_subpops] = states_next[ + seeding_sources + ][seeding_subpops] * (states_next[seeding_sources][seeding_subpops] > 0) + states_next[seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_seeded += this_seeding_amounts times_seeded += 1 # ADD TO cumulative, this is debatable, - states_daily_incid[today][seeding_destinations][seeding_subpops] += this_seeding_amounts + states_daily_incid[today][seeding_destinations][ + seeding_subpops + ] += this_seeding_amounts total_infected = 0 for transition_index in range(ntransitions): @@ -138,52 +151,72 @@ def steps_SEIR_nb( proportion_info[proportion_sum_starts_col][proportion_index], proportion_info[proportion_sum_stops_col][proportion_index], ): - relevant_number_in_comp += states_current[transition_sum_compartments[proportion_sum_index]] + relevant_number_in_comp += states_current[ + transition_sum_compartments[proportion_sum_index] + ] # exponents should not be a proportion, since we don't sum them over sum compartments - relevant_exponent = parameters[proportion_info[proportion_exponent_col][proportion_index]][today] + relevant_exponent = parameters[ + proportion_info[proportion_exponent_col][proportion_index] + ][today] if first_proportion: only_one_proportion = ( - transitions[transition_proportion_start_col][transition_index] + 1 + transitions[transition_proportion_start_col][transition_index] + + 1 ) == transitions[transition_proportion_stop_col][transition_index] first_proportion = False source_number = relevant_number_in_comp if source_number.max() > 0: total_rate[source_number > 0] *= ( - source_number[source_number > 0] ** relevant_exponent[source_number > 0] + source_number[source_number > 0] + ** relevant_exponent[source_number > 0] / source_number[source_number > 0] ) if only_one_proportion: - total_rate *= parameters[transitions[transition_rate_col][transition_index]][today] + total_rate *= parameters[ + transitions[transition_rate_col][transition_index] + ][today] else: for spatial_node in range(nspatial_nodes): - proportion_keep_compartment = 1 - percent_day_away * percent_who_move[spatial_node] + proportion_keep_compartment = ( + 1 - percent_day_away * percent_who_move[spatial_node] + ) proportion_change_compartment = ( percent_day_away * mobility_data[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[ + spatial_node + ] : mobility_data_indices[spatial_node + 1] ] / population[spatial_node] ) rate_keep_compartment = ( proportion_keep_compartment - * relevant_number_in_comp[spatial_node] ** relevant_exponent[spatial_node] + * relevant_number_in_comp[spatial_node] + ** relevant_exponent[spatial_node] / population[spatial_node] - * parameters[transitions[transition_rate_col][transition_index]][today][spatial_node] + * parameters[ + transitions[transition_rate_col][transition_index] + ][today][spatial_node] ) visiting_compartment = mobility_row_indices[ - mobility_data_indices[spatial_node] : mobility_data_indices[spatial_node + 1] + mobility_data_indices[spatial_node] : mobility_data_indices[ + spatial_node + 1 + ] ] rate_change_compartment = 
proportion_change_compartment rate_change_compartment *= ( - relevant_number_in_comp[visiting_compartment] ** relevant_exponent[visiting_compartment] + relevant_number_in_comp[visiting_compartment] + ** relevant_exponent[visiting_compartment] ) rate_change_compartment /= population[visiting_compartment] - rate_change_compartment *= parameters[transitions[transition_rate_col][transition_index]][ - today - ][visiting_compartment] - total_rate[spatial_node] *= rate_keep_compartment + rate_change_compartment.sum() + rate_change_compartment *= parameters[ + transitions[transition_rate_col][transition_index] + ][today][visiting_compartment] + total_rate[spatial_node] *= ( + rate_keep_compartment + rate_change_compartment.sum() + ) compound_adjusted_rate = 1.0 - np.exp(-dt * total_rate) @@ -221,14 +254,22 @@ def steps_SEIR_nb( for spatial_node in range(nspatial_nodes): if ( number_move[spatial_node] - > states_next[transitions[transition_source_col][transition_index]][spatial_node] - ): - number_move[spatial_node] = states_next[transitions[transition_source_col][transition_index]][ + > states_next[transitions[transition_source_col][transition_index]][ spatial_node ] - states_next[transitions[transition_source_col][transition_index]] -= number_move - states_next[transitions[transition_destination_col][transition_index]] += number_move - states_daily_incid[today, transitions[transition_destination_col][transition_index], :] += number_move + ): + number_move[spatial_node] = states_next[ + transitions[transition_source_col][transition_index] + ][spatial_node] + states_next[ + transitions[transition_source_col][transition_index] + ] -= number_move + states_next[ + transitions[transition_destination_col][transition_index] + ] += number_move + states_daily_incid[ + today, transitions[transition_destination_col][transition_index], : + ] += number_move states_current = states_next.copy() diff --git a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py index 1e2b9b8de..58192bb9e 100644 --- a/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py +++ b/flepimop/gempyor_pkg/src/gempyor/subpopulation_structure.py @@ -27,7 +27,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): self.setup_name = setup_name self.data = pd.read_csv( - geodata_file, converters={subpop_names_key: lambda x: str(x).strip()}, skipinitialspace=True + geodata_file, + converters={subpop_names_key: lambda x: str(x).strip()}, + skipinitialspace=True, ) # subpops and populations, strip whitespaces self.nsubpops = len(self.data) # K = # of locations @@ -44,7 +46,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): # subpop_names_key is the name of the column in geodata_file with subpops if subpop_names_key not in self.data: - raise ValueError(f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata.") + raise ValueError( + f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata." 
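As context for the reformatted geodata block, a self-contained sketch of the read-and-validate pattern it wraps; the file name here is hypothetical and the key column name is assumed to be "subpop".

    # Minimal sketch (hypothetical file and column name): read geodata, strip whitespace
    # from subpop names, then check that the key column exists and has no duplicates.
    import pandas as pd

    subpop_names_key = "subpop"          # assumed column name
    data = pd.read_csv(
        "geodata.csv",                   # hypothetical path
        converters={subpop_names_key: lambda x: str(x).strip()},
        skipinitialspace=True,
    )
    if subpop_names_key not in data:
        raise ValueError(f"subpop_names_key: {subpop_names_key} does not correspond to a column in geodata.")
    subpop_names = data[subpop_names_key].tolist()
    if len(subpop_names) != len(set(subpop_names)):
        raise ValueError("There are duplicate subpop_names in geodata.")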
+ ) self.subpop_names = self.data[subpop_names_key].tolist() if len(self.subpop_names) != len(set(self.subpop_names)): raise ValueError(f"There are duplicate subpop_names in geodata.") @@ -53,7 +57,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): mobility_file = path_prefix / subpop_config["mobility"].get() mobility_file = pathlib.Path(mobility_file) if mobility_file.suffix == ".txt": - print("Mobility files as matrices are not recommended. Please switch soon to long form csv files.") + print( + "Mobility files as matrices are not recommended. Please switch soon to long form csv files." + ) self.mobility = scipy.sparse.csr_matrix( np.loadtxt(mobility_file), dtype=int ) # K x K matrix of people moving @@ -64,17 +70,28 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): ) elif mobility_file.suffix == ".csv": - mobility_data = pd.read_csv(mobility_file, converters={"ori": str, "dest": str}, skipinitialspace=True) + mobility_data = pd.read_csv( + mobility_file, + converters={"ori": str, "dest": str}, + skipinitialspace=True, + ) nn_dict = {v: k for k, v in enumerate(self.subpop_names)} - mobility_data["ori_idx"] = mobility_data["ori"].apply(nn_dict.__getitem__) - mobility_data["dest_idx"] = mobility_data["dest"].apply(nn_dict.__getitem__) + mobility_data["ori_idx"] = mobility_data["ori"].apply( + nn_dict.__getitem__ + ) + mobility_data["dest_idx"] = mobility_data["dest"].apply( + nn_dict.__getitem__ + ) if any(mobility_data["ori_idx"] == mobility_data["dest_idx"]): raise ValueError( f"Mobility fluxes with same origin and destination in long form matrix. This is not supported" ) self.mobility = scipy.sparse.coo_matrix( - (mobility_data.amount, (mobility_data.ori_idx, mobility_data.dest_idx)), + ( + mobility_data.amount, + (mobility_data.ori_idx, mobility_data.dest_idx), + ), shape=(self.nsubpops, self.nsubpops), dtype=int, ).tocsr() @@ -115,7 +132,9 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): ) else: logging.critical("No mobility matrix specified -- assuming no one moves") - self.mobility = scipy.sparse.csr_matrix(np.zeros((self.nsubpops, self.nsubpops)), dtype=int) + self.mobility = scipy.sparse.csr_matrix( + np.zeros((self.nsubpops, self.nsubpops)), dtype=int + ) if subpop_config["selected"].exists(): selected = subpop_config["selected"].get() @@ -129,4 +148,6 @@ def __init__(self, *, setup_name, subpop_config, path_prefix=pathlib.Path(".")): self.subpop_names = selected self.nsubpops = len(self.data) # TODO: this needs to be tested - self.mobility = self.mobility[selected_subpop_indices][:, selected_subpop_indices] + self.mobility = self.mobility[selected_subpop_indices][ + :, selected_subpop_indices + ] diff --git a/flepimop/gempyor_pkg/src/gempyor/utils.py b/flepimop/gempyor_pkg/src/gempyor/utils.py index 6131b5c52..873f105d0 100644 --- a/flepimop/gempyor_pkg/src/gempyor/utils.py +++ b/flepimop/gempyor_pkg/src/gempyor/utils.py @@ -103,7 +103,9 @@ def command_safe_run(command, command_name="mycommand", fail_on_fail=True): import subprocess import shlex # using shlex to split the command because it's not obvious https://docs.python.org/3/library/subprocess.html#subprocess.Popen - sr = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + sr = subprocess.Popen( + shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) (stdout, stderr) = sr.communicate() if sr.returncode != 0: print(f"{command_name} failed failed with returncode 
{sr.returncode}") @@ -137,7 +139,9 @@ def wrapper(*args, **kwargs): return decorator -def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, class_name: str, **kwargs): +def search_and_import_plugins_class( + plugin_file_path: str, path_prefix: str, class_name: str, **kwargs +): # Look for all possible plugins and import them # https://stackoverflow.com/questions/67631/how-can-i-import-a-module-dynamically-given-the-full-path # unfortunatelly very complicated, this is cpython only ?? @@ -159,7 +163,9 @@ def search_and_import_plugins_class(plugin_file_path: str, path_prefix: str, cla from functools import wraps -def profile(output_file=None, sort_by="cumulative", lines_to_print=None, strip_dirs=False): +def profile( + output_file=None, sort_by="cumulative", lines_to_print=None, strip_dirs=False +): """A time profiler decorator. Inspired by and modified the profile decorator of Giampaolo Rodola: http://code.activestate.com/recipes/577817-profile-decorator/ @@ -307,14 +313,20 @@ def as_random_distribution(self): dist = self["distribution"].get() if dist == "fixed": return functools.partial( - np.random.uniform, self["value"].as_evaled_expression(), self["value"].as_evaled_expression(), + np.random.uniform, + self["value"].as_evaled_expression(), + self["value"].as_evaled_expression(), ) elif dist == "uniform": return functools.partial( - np.random.uniform, self["low"].as_evaled_expression(), self["high"].as_evaled_expression(), + np.random.uniform, + self["low"].as_evaled_expression(), + self["high"].as_evaled_expression(), ) elif dist == "poisson": - return functools.partial(np.random.poisson, self["lam"].as_evaled_expression()) + return functools.partial( + np.random.poisson, self["lam"].as_evaled_expression() + ) elif dist == "binomial": p = self["p"].as_evaled_expression() if (p < 0) or (p > 1): @@ -336,13 +348,18 @@ def as_random_distribution(self): ).rvs elif dist == "lognorm": return get_log_normal( - meanlog=self["meanlog"].as_evaled_expression(), sdlog=self["sdlog"].as_evaled_expression(), + meanlog=self["meanlog"].as_evaled_expression(), + sdlog=self["sdlog"].as_evaled_expression(), ).rvs else: raise NotImplementedError(f"unknown distribution [got: {dist}]") else: # we allow a fixed value specified directly: - return functools.partial(np.random.uniform, self.as_evaled_expression(), self.as_evaled_expression(),) + return functools.partial( + np.random.uniform, + self.as_evaled_expression(), + self.as_evaled_expression(), + ) def list_filenames( @@ -431,14 +448,14 @@ def rolling_mean_pad( [20.2, 21.2, 22.2, 23.2], [22.6, 23.6, 24.6, 25.6]]) """ - weights = (1. / window) * np.ones(window) + weights = (1.0 / window) * np.ones(window) output = scipy.ndimage.convolve1d(data, weights, axis=0, mode="nearest") if window % 2 == 0: rows, cols = data.shape i = rows - 1 - output[i, :] = 0. + output[i, :] = 0.0 window -= 1 - weight = 1. 
/ window + weight = 1.0 / window for l in range(-((window - 1) // 2), 1 + (window // 2)): i_star = min(max(i + l, 0), i) for j in range(cols): @@ -472,7 +489,12 @@ def bash(command): def create_resume_out_filename( - flepi_run_index: str, flepi_prefix: str, flepi_slot_index: str, flepi_block_index: str, filetype: str, liketype: str + flepi_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + flepi_block_index: str, + filetype: str, + liketype: str, ) -> str: prefix = f"{flepi_prefix}/{flepi_run_index}" inference_filepath_suffix = f"{liketype}/intermediate" @@ -493,7 +515,11 @@ def create_resume_out_filename( def create_resume_input_filename( - resume_run_index: str, flepi_prefix: str, flepi_slot_index: str, filetype: str, liketype: str + resume_run_index: str, + flepi_prefix: str, + flepi_slot_index: str, + filetype: str, + liketype: str, ) -> str: prefix = f"{flepi_prefix}/{resume_run_index}" inference_filepath_suffix = f"{liketype}/final" @@ -511,7 +537,9 @@ def create_resume_input_filename( ) -def get_filetype_for_resume(resume_discard_seeding: str, flepi_block_index: str) -> List[str]: +def get_filetype_for_resume( + resume_discard_seeding: str, flepi_block_index: str +) -> List[str]: """ Retrieves a list of parquet file types that are relevant for resuming a process based on specific environment variable settings. This function dynamically determines the list @@ -564,7 +592,8 @@ def create_resume_file_names_map( behavior. """ file_types = get_filetype_for_resume( - resume_discard_seeding=resume_discard_seeding, flepi_block_index=flepi_block_index + resume_discard_seeding=resume_discard_seeding, + flepi_block_index=flepi_block_index, ) resume_file_name_mapping = dict() liketypes = ["global", "chimeric"] @@ -638,10 +667,12 @@ def download_file_from_s3(name_map: Dict[str, str]) -> None: import boto3 from botocore.exceptions import ClientError except ModuleNotFoundError: - raise ModuleNotFoundError(( - "No module named 'boto3', which is required for " - "gempyor.utils.download_file_from_s3. Please install the aws target." - )) + raise ModuleNotFoundError( + ( + "No module named 'boto3', which is required for " + "gempyor.utils.download_file_from_s3. Please install the aws target." + ) + ) s3 = boto3.client("s3") first_output_filename = next(iter(name_map.values())) output_dir = os.path.dirname(first_output_filename) @@ -664,13 +695,13 @@ def move_file_at_local(name_map: Dict[str, str]) -> None: """ Moves files locally according to a given mapping. - This function takes a dictionary where the keys are source file paths and - the values are destination file paths. It ensures that the destination - directories exist and then copies the files from the source paths to the + This function takes a dictionary where the keys are source file paths and + the values are destination file paths. It ensures that the destination + directories exist and then copies the files from the source paths to the destination paths. Parameters: - name_map (Dict[str, str]): A dictionary mapping source file paths to + name_map (Dict[str, str]): A dictionary mapping source file paths to destination file paths. 
Returns: diff --git a/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py b/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py index 7e3c2bc59..ecf2a84e9 100644 --- a/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py +++ b/flepimop/gempyor_pkg/tests/npi/test_SinglePeriodModifier.py @@ -49,7 +49,10 @@ def test_SinglePeriodModifier_start_date_fail(self): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_test.yml") - with pytest.raises(ValueError, match=r".*at least one period start or end date is not between.*"): + with pytest.raises( + ValueError, + match=r".*at least one period start or end date is not between.*", + ): s = model_info.ModelInfo( setup_name="test_seir", config=config, @@ -72,7 +75,10 @@ def test_SinglePeriodModifier_end_date_fail(self): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_test.yml") - with pytest.raises(ValueError, match=r".*at least one period start or end date is not between.*"): + with pytest.raises( + ValueError, + match=r".*at least one period start or end date is not between.*", + ): s = model_info.ModelInfo( setup_name="test_seir", config=config, diff --git a/flepimop/gempyor_pkg/tests/npi/test_npis.py b/flepimop/gempyor_pkg/tests/npi/test_npis.py index bad306b1e..8a3e5bb2f 100644 --- a/flepimop/gempyor_pkg/tests/npi/test_npis.py +++ b/flepimop/gempyor_pkg/tests/npi/test_npis.py @@ -47,12 +47,18 @@ def test_full_npis_read_write(): # inference_simulator.s, load_ID=False, sim_id2load=None, config=config # ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet") + ) import random random.seed(10) @@ -74,10 +80,16 @@ def test_full_npis_read_write(): npi_outcomes = outcomes.build_outcome_modifiers( inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -98,10 +110,16 @@ def test_full_npis_read_write(): npi_outcomes = outcomes.build_outcome_modifiers( inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config ) - inference_simulator.modinf.write_simID(ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF()) + 
inference_simulator.modinf.write_simID( + ftype="hnpi", sim_id=1, df=npi_outcomes.getReductionDF() + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() @@ -116,10 +134,15 @@ def test_spatial_groups(): ) # Test build from config, value of the reduction array - npi = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config) + npi = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config + ) # all independent: r1 - assert len(npi.getReduction("r1")["2021-01-01"].unique()) == inference_simulator.modinf.nsubpops + assert ( + len(npi.getReduction("r1")["2021-01-01"].unique()) + == inference_simulator.modinf.nsubpops + ) assert npi.getReduction("r1").isna().sum().sum() == 0 # all the same: r2 @@ -127,26 +150,41 @@ def test_spatial_groups(): assert npi.getReduction("r2").isna().sum().sum() == 0 # two groups: r3 - assert len(npi.getReduction("r3")["2020-04-15"].unique()) == inference_simulator.modinf.nsubpops - 2 + assert ( + len(npi.getReduction("r3")["2020-04-15"].unique()) + == inference_simulator.modinf.nsubpops - 2 + ) assert npi.getReduction("r3").isna().sum().sum() == 0 - assert len(npi.getReduction("r3").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 - assert len(npi.getReduction("r3").loc[["04000", "06000"], "2020-04-15"].unique()) == 1 + assert ( + len(npi.getReduction("r3").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 + ) + assert ( + len(npi.getReduction("r3").loc[["04000", "06000"], "2020-04-15"].unique()) == 1 + ) # one group: r4 assert ( len(npi.getReduction("r4")["2020-04-15"].unique()) == 4 ) # 0 for these not included, 1 unique for the group, and two for the rest assert npi.getReduction("r4").isna().sum().sum() == 0 - assert len(npi.getReduction("r4").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 - assert len(npi.getReduction("r4").loc[["04000", "06000"], "2020-04-15"].unique()) == 2 + assert ( + len(npi.getReduction("r4").loc[["01000", "02000"], "2020-04-15"].unique()) == 1 + ) + assert ( + len(npi.getReduction("r4").loc[["04000", "06000"], "2020-04-15"].unique()) == 2 + ) assert (npi.getReduction("r4").loc[["05000", "08000"], "2020-04-15"] == 0).all() # mtr group: r5 assert npi.getReduction("r5").isna().sum().sum() == 0 assert len(npi.getReduction("r5")["2020-12-15"].unique()) == 2 assert len(npi.getReduction("r5")["2020-10-15"].unique()) == 4 - assert len(npi.getReduction("r5").loc[["01000", "04000"], "2020-10-15"].unique()) == 1 - assert len(npi.getReduction("r5").loc[["02000", "06000"], "2020-10-15"].unique()) == 2 + assert ( + len(npi.getReduction("r5").loc[["01000", "04000"], "2020-10-15"].unique()) == 1 + ) + assert ( + len(npi.getReduction("r5").loc[["02000", "06000"], "2020-10-15"].unique()) == 2 + ) # test the dataframes that are wrote. 
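These assertions lean on one convention worth spelling out: a grouped modifier is written as a single row whose subpop field is the comma-joined list of its members, so the tests split that string back apart before comparing. A tiny sketch with made-up subpop ids:

    # Illustrative sketch (hypothetical ids): how a grouped row's subpop string is
    # compared against the full subpop list in these tests.
    subpop_names = ["01000", "02000", "04000", "06000"]   # assumed subpop ids
    row_subpop = ",".join(subpop_names)                   # how a grouped row stores its subpops

    assert set(row_subpop.split(",")) == set(subpop_names)
    assert len(row_subpop.split(",")) == len(subpop_names)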
npi_df = npi.getReductionDF() @@ -160,7 +198,9 @@ def test_spatial_groups(): # all the same: r2 df = npi_df[npi_df["modifier_name"] == "all_together"] assert len(df) == 1 - assert set(df["subpop"].iloc[0].split(",")) == set(inference_simulator.modinf.subpop_struct.subpop_names) + assert set(df["subpop"].iloc[0].split(",")) == set( + inference_simulator.modinf.subpop_struct.subpop_names + ) assert len(df["subpop"].iloc[0].split(",")) == inference_simulator.modinf.nsubpops # two groups: r3 @@ -175,7 +215,10 @@ def test_spatial_groups(): df = npi_df[npi_df["modifier_name"] == "mt_reduce"] assert len(df) == 4 assert df.subpop.to_list() == ["09000,10000", "02000", "06000", "01000,04000"] - assert df[df["subpop"] == "09000,10000"]["start_date"].iloc[0] == "2020-12-01,2021-12-01" + assert ( + df[df["subpop"] == "09000,10000"]["start_date"].iloc[0] + == "2020-12-01,2021-12-01" + ) assert ( df[df["subpop"] == "01000,04000"]["start_date"].iloc[0] == df[df["subpop"] == "06000"]["start_date"].iloc[0] @@ -194,15 +237,21 @@ def test_spatial_groups(): ) # Test build from config, value of the reduction array - npi = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config) + npi = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=False, sim_id2load=None, config=config + ) npi_df = npi.getReductionDF() inference_simulator.modinf.write_simID(ftype="snpi", sim_id=1, df=npi_df) - snpi_read = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.105.snpi.parquet").to_pandas() + snpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.105.snpi.parquet" + ).to_pandas() snpi_read["value"] = np.random.random(len(snpi_read)) * 2 - 1 out_snpi = pa.Table.from_pandas(snpi_read, preserve_index=False) - pa.parquet.write_table(out_snpi, file_paths.create_file_name(106, "", 1, "snpi", "parquet")) + pa.parquet.write_table( + out_snpi, file_paths.create_file_name(106, "", 1, "snpi", "parquet") + ) inference_simulator = gempyor.GempyorInference( config_filepath=f"{config_filepath_prefix}config_test_spatial_group_npi.yml", @@ -213,22 +262,42 @@ def test_spatial_groups(): out_run_id=107, ) - npi_seir = seir.build_npi_SEIR(inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config) - inference_simulator.modinf.write_simID(ftype="snpi", sim_id=1, df=npi_seir.getReductionDF()) + npi_seir = seir.build_npi_SEIR( + inference_simulator.modinf, load_ID=True, sim_id2load=1, config=config + ) + inference_simulator.modinf.write_simID( + ftype="snpi", sim_id=1, df=npi_seir.getReductionDF() + ) - snpi_read = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.106.snpi.parquet").to_pandas() - snpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/snpi/000000001.107.snpi.parquet").to_pandas() + snpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.106.snpi.parquet" + ).to_pandas() + snpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/snpi/000000001.107.snpi.parquet" + ).to_pandas() # now the order can change, so we need to sort by subpop and start_date - snpi_wrote = snpi_wrote.sort_values(by=["subpop", "start_date"]).reset_index(drop=True) - snpi_read = snpi_read.sort_values(by=["subpop", "start_date"]).reset_index(drop=True) + snpi_wrote = snpi_wrote.sort_values(by=["subpop", "start_date"]).reset_index( + drop=True + ) + snpi_read = snpi_read.sort_values(by=["subpop", "start_date"]).reset_index( + drop=True + ) assert (snpi_read == snpi_wrote).all().all() npi_read = 
seir.build_npi_SEIR( - inference_simulator.modinf, load_ID=False, sim_id2load=1, config=config, bypass_DF=snpi_read + inference_simulator.modinf, + load_ID=False, + sim_id2load=1, + config=config, + bypass_DF=snpi_read, ) npi_wrote = seir.build_npi_SEIR( - inference_simulator.modinf, load_ID=False, sim_id2load=1, config=config, bypass_DF=snpi_wrote + inference_simulator.modinf, + load_ID=False, + sim_id2load=1, + config=config, + bypass_DF=snpi_wrote, ) assert (npi_read.getReductionDF() == npi_wrote.getReductionDF()).all().all() diff --git a/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py b/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py index 56df652cf..aa186010c 100644 --- a/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py +++ b/flepimop/gempyor_pkg/tests/outcomes/make_seir_test_file.py @@ -36,7 +36,9 @@ prefix = "" sim_id = 1 -a = pd.read_parquet(file_paths.create_file_name(run_id, prefix, sim_id, "seir", "parquet")) +a = pd.read_parquet( + file_paths.create_file_name(run_id, prefix, sim_id, "seir", "parquet") +) print(a) # created by running SEIR test_seir.py (comment line 530 to remove file tree) first b = pd.read_parquet("../../SEIR/test/model_output/seir/000000101.test.seir.parquet") @@ -54,7 +56,9 @@ diffI = np.arange(5) * 2 date_data = datetime.date(2020, 4, 15) for i in range(5): - b.loc[(b["mc_value_type"] == "incidence") & (b["date"] == str(date_data)), subpop[i]] = diffI[i] + b.loc[ + (b["mc_value_type"] == "incidence") & (b["date"] == str(date_data)), subpop[i] + ] = diffI[i] pa_df = pa.Table.from_pandas(b, preserve_index=False) pa.parquet.write_table(pa_df, "new_test_no_vacc.parquet") diff --git a/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py b/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py index 35f490920..ed0942663 100644 --- a/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py +++ b/flepimop/gempyor_pkg/tests/outcomes/test_outcomes.py @@ -40,87 +40,150 @@ def test_outcome(): stoch_traj_flag=False, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.1.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.1.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + 
== diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet" + ).to_pandas() for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + 
& (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 ) @@ -136,15 +199,23 @@ def test_outcome_modifiers_scenario_with_load(): stoch_traj_flag=False, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hpar_config = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet").to_pandas() - hpar_rel = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet").to_pandas() + hpar_config = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.1.hpar.parquet" + ).to_pandas() + hpar_rel = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet" + ).to_pandas() for out in ["incidH", "incidD", "incidICU"]: for i, place in enumerate(subpop): a = hpar_rel[(hpar_rel["outcome"] == out) & (hpar_rel["subpop"] == place)] - b = hpar_config[(hpar_rel["outcome"] == out) & (hpar_config["subpop"] == place)] + b = hpar_config[ + (hpar_rel["outcome"] == out) & (hpar_config["subpop"] == place) + ] assert len(a) == len(b) for j in range(len(a)): if b.iloc[j]["quantity"] in ["delay", "duration"]: @@ -171,16 +242,30 @@ def test_outcomes_read_write_hpar(): stoch_traj_flag=False, out_run_id=3, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.3.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.2.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.3.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.2.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.3.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.2.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.3.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.2.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.3.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.2.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.3.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -201,7 +286,9 @@ def test_multishift_notstochdelays(): [36, 29], ] ) - shifts = np.array([[1, 0], [2, 1], [1, 0], [2, 2], [1, 2], [0, 1], [1, 1], [1, 2], [1, 2], [1, 0]]) + shifts = np.array( + [[1, 0], [2, 1], [1, 0], [2, 2], [1, 2], [0, 1], [1, 1], [1, 2], [1, 2], [1, 0]] + ) expected = np.array( [ [0, 39], @@ -216,7 +303,9 @@ def test_multishift_notstochdelays(): [12, 32], ] ) - assert 
(outcomes.multishift(array, shifts, stoch_delay_flag=False) == expected).all() + assert ( + outcomes.multishift(array, shifts, stoch_delay_flag=False) == expected + ).all() def test_outcomes_npi(): @@ -230,89 +319,152 @@ def test_outcomes_npi(): stoch_traj_flag=False, out_run_id=105, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() # Doubled 
everything from previous config.yaml for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 * 2 ) @@ -330,17 +482,31 @@ def test_outcomes_read_write_hnpi(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() - hosp_wrote = 
pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -356,17 +522,27 @@ def test_outcomes_read_write_hnpi2(): out_run_id=106, ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, "", 1, "hnpi", "parquet") + ) import random random.seed(10) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -378,16 +554,30 @@ def test_outcomes_read_write_hnpi2(): stoch_traj_flag=False, out_run_id=107, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() + hosp_wrote 
= pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -402,89 +592,152 @@ def test_outcomes_npi_custom_pname(): stoch_traj_flag=False, out_run_id=105, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False, sim_id2load=1 + ) - hosp = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() + hosp = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() hosp.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: assert hosp[hosp["subpop"] == place]["incidI"][dt] == diffI[i] - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == diffI[i] * 0.01 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == diffI[i] * 0.1 * 0.4 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == diffI[i] * 0.01 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == diffI[i] * 0.1 * 0.4 + ) for j in range(7): - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + j)] == diffI[i] * 0.1 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + j) + ] + == diffI[i] * 0.1 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidH"][dt + datetime.timedelta(7)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidI"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place]["incidICU"][dt + datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["hosp_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt + datetime.timedelta(2)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidICU"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["hosp_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place]["incidH"][dt] == 0 - assert hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place]["incidI"][dt - datetime.timedelta(7)] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place]["incidD"][dt - datetime.timedelta(4)] + == 0 + ) assert hosp[hosp["subpop"] == place]["incidICU"][dt] == 0 - hpar = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() + hpar = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + 
).to_pandas() # Doubled everything from previous config.yaml for i, place in enumerate(subpop): assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.1 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidH") & (hpar["quantity"] == "duration")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidH") + & (hpar["quantity"] == "duration") + ]["value"].iloc[0] ) == 7 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.01 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidD") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidD") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 2 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "probability")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "probability") + ]["value"].iloc[0] ) == 0.4 * 2 ) assert ( float( - hpar[(hpar["subpop"] == place) & (hpar["outcome"] == "incidICU") & (hpar["quantity"] == "delay")][ - "value" - ].iloc[0] + hpar[ + (hpar["subpop"] == place) + & (hpar["outcome"] == "incidICU") + & (hpar["quantity"] == "delay") + ]["value"].iloc[0] ) == 0 * 2 ) @@ -502,16 +755,30 @@ def test_outcomes_read_write_hnpi_custom_pname(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.105.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet").to_pandas() - 
hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.105.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -520,10 +787,14 @@ def test_outcomes_read_write_hnpi2_custom_pname(): prefix = "" - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() hnpi_read["value"] = np.random.random(len(hnpi_read)) * 2 - 1 out_hnpi = pa.Table.from_pandas(hnpi_read, preserve_index=False) - pa.parquet.write_table(out_hnpi, file_paths.create_file_name(105, prefix, 1, "hnpi", "parquet")) + pa.parquet.write_table( + out_hnpi, file_paths.create_file_name(105, prefix, 1, "hnpi", "parquet") + ) import random random.seed(10) @@ -537,10 +808,16 @@ def test_outcomes_read_write_hnpi2_custom_pname(): out_run_id=106, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.105.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() # runs with the new, random NPI @@ -553,16 +830,30 @@ def test_outcomes_read_write_hnpi2_custom_pname(): out_run_id=107, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.106.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.107.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.106.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.107.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet").to_pandas() + hosp_read = 
pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.106.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.107.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() @@ -580,7 +871,9 @@ def test_outcomes_pcomp(): ) p_compmult = [1, 3] - seir = pq.read_table(f"{config_filepath_prefix}model_output/seir/000000001.105.seir.parquet").to_pandas() + seir = pq.read_table( + f"{config_filepath_prefix}model_output/seir/000000001.105.seir.parquet" + ).to_pandas() seir2 = seir.copy() seir2["mc_vaccination_stage"] = "first_dose" @@ -591,10 +884,16 @@ def test_outcomes_pcomp(): seir2[pl] = seir2[pl] * p_compmult[1] new_seir = pd.concat([seir, seir2]) out_df = pa.Table.from_pandas(new_seir, preserve_index=False) - pa.parquet.write_table(out_df, file_paths.create_file_name(110, prefix, 1, "seir", "parquet")) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False) + pa.parquet.write_table( + out_df, file_paths.create_file_name(110, prefix, 1, "seir", "parquet") + ) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=False + ) - hosp_f = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet").to_pandas() + hosp_f = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet" + ).to_pandas() hosp_f.set_index("date", drop=True, inplace=True) # same as config.yaml (doubled, then NPI halve it) for k, p_comp in enumerate(["0dose", "1dose"]): @@ -602,42 +901,90 @@ def test_outcomes_pcomp(): for i, place in enumerate(subpop): for dt in hosp.index: if dt.date() == date_data: - assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] == diffI[i] * p_compmult[k] assert ( - hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt + datetime.timedelta(7)] + hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] + == diffI[i] * p_compmult[k] + ) + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][ + dt + datetime.timedelta(7) + ] - diffI[i] * 0.1 * p_compmult[k] < 1e-8 ) assert ( - hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt + datetime.timedelta(2)] + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt + datetime.timedelta(2) + ] - diffI[i] * 0.01 * p_compmult[k] < 1e-8 ) assert ( - hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt + datetime.timedelta(7)] + hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][ + dt + datetime.timedelta(7) + ] - diffI[i] * 0.1 * 0.4 * p_compmult[k] < 1e-8 ) for j in range(7): assert ( - hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7 + j)] + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7 + j) + ] - diffI[i] * 0.1 * p_compmult[k] < 1e-8 ) - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7 + 8)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7 + 8) + ] + == 0 + ) elif dt.date() < date_data: - assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][ + dt + datetime.timedelta(7) + ] + == 0 + ) assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt] == 0 - assert hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt + datetime.timedelta(2)] == 0 - assert hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt + datetime.timedelta(7)] == 0 - assert 
hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][dt + datetime.timedelta(7)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt + datetime.timedelta(2) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][ + dt + datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidH_{p_comp}_curr"][ + dt + datetime.timedelta(7) + ] + == 0 + ) elif dt.date() > (date_data + datetime.timedelta(7)): assert hosp[hosp["subpop"] == place][f"incidH_{p_comp}"][dt] == 0 - assert hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][dt - datetime.timedelta(7)] == 0 - assert hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][dt - datetime.timedelta(4)] == 0 + assert ( + hosp[hosp["subpop"] == place][f"incidI_{p_comp}"][ + dt - datetime.timedelta(7) + ] + == 0 + ) + assert ( + hosp[hosp["subpop"] == place][f"incidD_{p_comp}"][ + dt - datetime.timedelta(4) + ] + == 0 + ) assert hosp[hosp["subpop"] == place][f"incidICU_{p_comp}"][dt] == 0 - hpar_f = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet").to_pandas() + hpar_f = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet" + ).to_pandas() # Doubled everything from previous config.yaml # for k, p_comp in enumerate(["unvaccinated", "first_dose"]): for k, p_comp in enumerate(["0dose", "1dose"]): @@ -727,16 +1074,30 @@ def test_outcomes_pcomp_read_write(): out_run_id=112, ) - outcomes.onerun_delayframe_outcomes(sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1) + outcomes.onerun_delayframe_outcomes( + sim_id2write=1, modinf=inference_simulator.modinf, load_ID=True, sim_id2load=1 + ) - hpar_read = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet").to_pandas() - hpar_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hpar/000000001.112.hpar.parquet").to_pandas() + hpar_read = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.111.hpar.parquet" + ).to_pandas() + hpar_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hpar/000000001.112.hpar.parquet" + ).to_pandas() assert (hpar_read == hpar_wrote).all().all() - hnpi_read = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.111.hnpi.parquet").to_pandas() - hnpi_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hnpi/000000001.112.hnpi.parquet").to_pandas() + hnpi_read = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.111.hnpi.parquet" + ).to_pandas() + hnpi_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hnpi/000000001.112.hnpi.parquet" + ).to_pandas() assert (hnpi_read == hnpi_wrote).all().all() - hosp_read = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet").to_pandas() - hosp_wrote = pq.read_table(f"{config_filepath_prefix}model_output/hosp/000000001.112.hosp.parquet").to_pandas() + hosp_read = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.111.hosp.parquet" + ).to_pandas() + hosp_wrote = pq.read_table( + f"{config_filepath_prefix}model_output/hosp/000000001.112.hosp.parquet" + ).to_pandas() assert (hosp_read == hosp_wrote).all().all() diff --git a/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py b/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py index 1e0915b82..53c34e039 100644 --- a/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py +++ b/flepimop/gempyor_pkg/tests/seir/dev_new_test0.py @@ -42,7 +42,9 @@ def 
test_parameters_from_timeserie_file(): ) # p = inference_simulator.s.parameters - p_draw = p.parameters_quick_draw(n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes) + p_draw = p.parameters_quick_draw( + n_days=inference_simulator.s.n_days, nnodes=inference_simulator.s.nnodes + ) p_df = p.getParameterDF(p_draw)["parameter"] diff --git a/flepimop/gempyor_pkg/tests/seir/test_compartments.py b/flepimop/gempyor_pkg/tests/seir/test_compartments.py index 1d4319e3b..0b40b5423 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_compartments.py +++ b/flepimop/gempyor_pkg/tests/seir/test_compartments.py @@ -10,7 +10,14 @@ import pyarrow.parquet as pq import filecmp -from gempyor import compartments, seir, NPI, file_paths, model_info, subpopulation_structure +from gempyor import ( + compartments, + seir, + NPI, + file_paths, + model_info, + subpopulation_structure, +) from gempyor.utils import config @@ -24,7 +31,9 @@ def test_check_transitions_parquet_creation(): config.set_file(f"{DATA_DIR}/config_compartmental_model_format.yml") original_compartments_file = f"{DATA_DIR}/parsed_compartment_compartments.parquet" original_transitions_file = f"{DATA_DIR}/parsed_compartment_transitions.parquet" - lhs = compartments.Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + lhs = compartments.Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) rhs = compartments.Compartments( seir_config=config["seir"], compartments_file=original_compartments_file, @@ -43,10 +52,16 @@ def test_check_transitions_parquet_writing_and_loading(): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config_compartmental_model_format.yml") - lhs = compartments.Compartments(seir_config=config["seir"], compartments_config=config["compartments"]) + lhs = compartments.Compartments( + seir_config=config["seir"], compartments_config=config["compartments"] + ) temp_compartments_file = f"{DATA_DIR}/parsed_compartment_compartments.test.parquet" temp_transitions_file = f"{DATA_DIR}/parsed_compartment_transitions.test.parquet" - lhs.toFile(compartments_file=temp_compartments_file, transitions_file=temp_transitions_file, write_parquet=True) + lhs.toFile( + compartments_file=temp_compartments_file, + transitions_file=temp_transitions_file, + write_parquet=True, + ) rhs = compartments.Compartments( seir_config=config["seir"], compartments_file=temp_compartments_file, @@ -86,4 +101,3 @@ def test_ModelInfo_has_compartments_component(): ) assert type(s.compartments) == compartments.Compartments assert type(s.compartments) == compartments.Compartments - diff --git a/flepimop/gempyor_pkg/tests/seir/test_ic.py b/flepimop/gempyor_pkg/tests/seir/test_ic.py index b4cd240ee..9e61aafcf 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_ic.py +++ b/flepimop/gempyor_pkg/tests/seir/test_ic.py @@ -21,7 +21,9 @@ def test_IC_success(self): outcome_modifiers_scenario=None, write_csv=False, ) - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) assert sic.initial_conditions_config == s.initial_conditions_config def test_IC_allow_missing_node_compartments_success(self): @@ -40,11 +42,15 @@ def test_IC_allow_missing_node_compartments_success(self): s.initial_conditions_config["allow_missing_nodes"] = True s.initial_conditions_config["allow_missing_compartments"] = True - sic = 
initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) sic.get_from_config(sim_id=100, modinf=s) def test_IC_IC_notImplemented_fail(self): - with pytest.raises(NotImplementedError, match=r".*unknown.*initial.*conditions.*"): + with pytest.raises( + NotImplementedError, match=r".*unknown.*initial.*conditions.*" + ): config.clear() config.read(user=False) config.set_file(f"{DATA_DIR}/config.yml") @@ -58,6 +64,8 @@ def test_IC_IC_notImplemented_fail(self): write_csv=False, ) s.initial_conditions_config["method"] = "unknown" - sic = initial_conditions.InitialConditionsFactory(config=s.initial_conditions_config) + sic = initial_conditions.InitialConditionsFactory( + config=s.initial_conditions_config + ) sic.get_from_config(sim_id=100, modinf=s) diff --git a/flepimop/gempyor_pkg/tests/seir/test_parameters.py b/flepimop/gempyor_pkg/tests/seir/test_parameters.py index 9e03bf87d..045e6643b 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_parameters.py +++ b/flepimop/gempyor_pkg/tests/seir/test_parameters.py @@ -10,7 +10,14 @@ import pyarrow.parquet as pq import filecmp -from gempyor import model_info, seir, NPI, file_paths, parameters, subpopulation_structure +from gempyor import ( + model_info, + seir, + NPI, + file_paths, + parameters, + subpopulation_structure, +) from gempyor.utils import config, write_df, read_df @@ -65,7 +72,9 @@ def test_parameters_from_config_plus_read_write(): tf=s.tf, subpop_names=s.subpop_struct.subpop_names, ) - p_load = rhs.parameters_load(param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops) + p_load = rhs.parameters_load( + param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops + ) assert (p_draw == p_load).all() @@ -102,9 +111,13 @@ def test_parameters_quick_draw_old(): assert params.pnames == ["alpha", "sigma", "gamma", "R0s"] assert params.npar == 4 assert params.stacked_modifier_method["sum"] == [] - assert params.stacked_modifier_method["product"] == [pn.lower() for pn in params.pnames] + assert params.stacked_modifier_method["product"] == [ + pn.lower() for pn in params.pnames + ] - p_array = params.parameters_quick_draw(n_days=modinf.n_days, nsubpops=modinf.nsubpops) + p_array = params.parameters_quick_draw( + n_days=modinf.n_days, nsubpops=modinf.nsubpops + ) print(p_array.shape) alpha = p_array[params.pnames2pindex["alpha"]] @@ -122,7 +135,12 @@ def test_parameters_quick_draw_old(): assert ((2 <= R0s) & (R0s <= 3)).all() assert sigma.shape == (modinf.n_days, modinf.nsubpops) - assert (sigma == config["seir"]["parameters"]["sigma"]["value"]["value"].as_evaled_expression()).all() + assert ( + sigma + == config["seir"]["parameters"]["sigma"]["value"][ + "value" + ].as_evaled_expression() + ).all() assert gamma.shape == (modinf.n_days, modinf.nsubpops) assert len(np.unique(gamma)) == 1 @@ -174,6 +192,8 @@ def test_parameters_from_timeseries_file(): tf=s.tf, subpop_names=s.subpop_struct.subpop_names, ) - p_load = rhs.parameters_load(param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops) + p_load = rhs.parameters_load( + param_df=read_df("test_pwrite.parquet"), n_days=n_days, nsubpops=nsubpops + ) assert (p_draw == p_load).all() diff --git a/flepimop/gempyor_pkg/tests/seir/test_seir.py b/flepimop/gempyor_pkg/tests/seir/test_seir.py index 99a4cc236..94a40d7d9 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_seir.py +++ b/flepimop/gempyor_pkg/tests/seir/test_seir.py @@ -73,8 +73,12 @@ 
def test_constant_population_legacy_integration(): ) integration_method = "legacy" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -82,7 +86,9 @@ def test_constant_population_legacy_integration(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -94,7 +100,9 @@ def test_constant_population_legacy_integration(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, @@ -142,16 +150,24 @@ def test_constant_population_rk4jit_integration_fail(): ) modinf.seir_config["integration"]["method"] = "rk4.jit" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, modinf=modinf, modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, - pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method[ + "sum" + ], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -163,7 +179,9 @@ def test_constant_population_rk4jit_integration_fail(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, @@ -212,8 +230,12 @@ def test_constant_population_rk4jit_integration(): # s.integration_method = "rk4.jit" assert modinf.seir_config["integration"]["method"].get() == "rk4" - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + 
sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -221,7 +243,9 @@ def test_constant_population_rk4jit_integration(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -233,7 +257,9 @@ def test_constant_population_rk4jit_integration(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) states = seir.steps_SEIR( modinf, parsed_parameters, @@ -280,8 +306,12 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -289,7 +319,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -301,7 +333,9 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( @@ -316,11 +350,17 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "10001"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "10001"] > 1 ) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] > 1 ) @@ -336,11 +376,26 @@ def test_steps_SEIR_nb_simple_spread_with_txt_matrices(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] > 1 ) - assert 
df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["20002"] > 0 - assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["10001"] > 0 + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["20002"] + > 0 + ) + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["10001"] + > 0 + ) def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): @@ -366,8 +421,12 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -375,7 +434,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -387,7 +448,9 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( @@ -402,8 +465,20 @@ def test_steps_SEIR_nb_simple_spread_with_csv_matrices(): ) df = seir.states2Df(modinf, states) - assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["20002"] > 0 - assert df[(df["mc_value_type"] == "incidence") & (df["mc_infection_stage"] == "I1")].max()["10001"] > 0 + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["20002"] + > 0 + ) + assert ( + df[ + (df["mc_value_type"] == "incidence") + & (df["mc_infection_stage"] == "I1") + ].max()["10001"] + > 0 + ) def test_steps_SEIR_no_spread(): @@ -426,8 +501,12 @@ def test_steps_SEIR_no_spread(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) modinf.mobility.data = modinf.mobility.data * 0 @@ -437,7 +516,9 @@ def test_steps_SEIR_no_spread(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + 
"reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -449,7 +530,9 @@ def test_steps_SEIR_no_spread(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(10): states = seir.steps_SEIR( @@ -464,7 +547,10 @@ def test_steps_SEIR_no_spread(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] == 0.0 ) @@ -480,7 +566,10 @@ def test_steps_SEIR_no_spread(): ) df = seir.states2Df(modinf, states) assert ( - df[(df["mc_value_type"] == "prevalence") & (df["mc_infection_stage"] == "R")].loc[str(modinf.tf), "20002"] + df[ + (df["mc_value_type"] == "prevalence") + & (df["mc_infection_stage"] == "R") + ].loc[str(modinf.tf), "20002"] == 0.0 ) @@ -515,7 +604,9 @@ def test_continuation_resume(): seir.onerun_SEIR(sim_id2write=int(sim_id2write), modinf=modinf, config=config) states_old = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, 100, "seir", "parquet"), + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, 100, "seir", "parquet" + ), ).to_pandas() states_old = states_old[states_old["date"] == "2020-03-15"].reset_index(drop=True) @@ -547,7 +638,9 @@ def test_continuation_resume(): seir.onerun_SEIR(sim_id2write=sim_id2write, modinf=modinf, config=config) states_new = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write, "seir", "parquet"), + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write, "seir", "parquet" + ), ).to_pandas() states_new = states_new[states_new["date"] == "2020-03-15"].reset_index(drop=True) assert ( @@ -560,10 +653,16 @@ def test_continuation_resume(): ) seir.onerun_SEIR( - sim_id2write=sim_id2write + 1, modinf=modinf, sim_id2load=sim_id2write, load_ID=True, config=config + sim_id2write=sim_id2write + 1, + modinf=modinf, + sim_id2load=sim_id2write, + load_ID=True, + config=config, ) states_new = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "seir", "parquet"), + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "seir", "parquet" + ), ).to_pandas() states_new = states_new[states_new["date"] == "2020-03-15"].reset_index(drop=True) for path in ["model_output/seir", "model_output/snpi", "model_output/spar"]: @@ -587,7 +686,9 @@ def test_inference_resume(): spatial_config = config["subpop_setup"] if "data_path" in config: - raise ValueError("The config has a data_path section. This is no longer supported.") + raise ValueError( + "The config has a data_path section. This is no longer supported." 
+ ) # spatial_base_path = pathlib.Path(config["data_path"].get()) modinf = model_info.ModelInfo( config=config, @@ -604,7 +705,9 @@ def test_inference_resume(): seir.onerun_SEIR(sim_id2write=int(sim_id2write), modinf=modinf, config=config) npis_old = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write, "snpi", "parquet") + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write, "snpi", "parquet" + ) ).to_pandas() config.clear() @@ -632,14 +735,24 @@ def test_inference_resume(): ) seir.onerun_SEIR( - sim_id2write=sim_id2write + 1, modinf=modinf, sim_id2load=sim_id2write, load_ID=True, config=config + sim_id2write=sim_id2write + 1, + modinf=modinf, + sim_id2load=sim_id2write, + load_ID=True, + config=config, ) npis_new = pq.read_table( - file_paths.create_file_name(modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "snpi", "parquet") + file_paths.create_file_name( + modinf.in_run_id, modinf.in_prefix, sim_id2write + 1, "snpi", "parquet" + ) ).to_pandas() assert npis_old["modifier_name"].isin(["None", "Wuhan", "KansasCity"]).all() - assert npis_new["modifier_name"].isin(["None", "Wuhan", "KansasCity", "BrandNew"]).all() + assert ( + npis_new["modifier_name"] + .isin(["None", "Wuhan", "KansasCity", "BrandNew"]) + .all() + ) # assert((['None', 'Wuhan', 'KansasCity']).isin(npis_old["modifier_name"]).all()) # assert((['None', 'Wuhan', 'KansasCity', 'BrandNew']).isin(npis_new["modifier_name"]).all()) assert (npis_old["start_date"] == "2020-04-01").all() @@ -674,8 +787,12 @@ def test_parallel_compartments_with_vacc(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -683,7 +800,9 @@ def test_parallel_compartments_with_vacc(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -695,7 +814,9 @@ def test_parallel_compartments_with_vacc(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( @@ -761,8 +882,12 @@ def test_parallel_compartments_no_vacc(): out_prefix=prefix, ) - seeding_data, seeding_amounts = modinf.seeding.get_from_file(sim_id=100, modinf=modinf) - initial_conditions = modinf.initial_conditions.get_from_config(sim_id=100, modinf=modinf) + seeding_data, seeding_amounts = modinf.seeding.get_from_file( + sim_id=100, modinf=modinf + ) + initial_conditions = modinf.initial_conditions.get_from_config( + sim_id=100, modinf=modinf + ) npi = NPI.NPIBase.execute( npi_config=modinf.npi_config_seir, @@ -770,7 +895,9 @@ 
def test_parallel_compartments_no_vacc(): modifiers_library=modinf.seir_modifiers_library, subpops=modinf.subpop_struct.subpop_names, pnames_overlap_operation_sum=modinf.parameters.stacked_modifier_method["sum"], - pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method["reduction_product"], + pnames_overlap_operation_reductionprod=modinf.parameters.stacked_modifier_method[ + "reduction_product" + ], ) params = modinf.parameters.parameters_quick_draw(modinf.n_days, modinf.nsubpops) @@ -782,7 +909,9 @@ def test_parallel_compartments_no_vacc(): proportion_array, proportion_info, ) = modinf.compartments.get_transition_array() - parsed_parameters = modinf.compartments.parse_parameters(params, modinf.parameters.pnames, unique_strings) + parsed_parameters = modinf.compartments.parse_parameters( + params, modinf.parameters.pnames, unique_strings + ) for i in range(5): states = seir.steps_SEIR( diff --git a/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py b/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py index 34df630c3..9a11531a3 100644 --- a/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py +++ b/flepimop/gempyor_pkg/tests/seir/test_subpopulationstructure.py @@ -25,7 +25,9 @@ def test_subpopulation_structure_mobility(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -59,7 +61,9 @@ def test_subpopulation_structure_mobility_txt(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -89,7 +93,9 @@ def test_subpopulation_structure_subpop_population_zero_fail(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -110,7 +116,9 @@ def test_subpopulation_structure_dulpicate_subpop_names_fail(): mobility: {DATA_DIR}/mobility.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -130,12 +138,16 @@ def test_subpopulation_structure_mobility_shape_fail(): mobility: {DATA_DIR}/mobility_2x3.txt """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content 
temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path - with pytest.raises(ValueError, match=r"mobility data must have dimensions of length of geodata.*"): + with pytest.raises( + ValueError, match=r"mobility data must have dimensions of length of geodata.*" + ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] ) @@ -150,12 +162,16 @@ def test_subpopulation_structure_mobility_fluxes_same_ori_and_dest_fail(): mobility: {DATA_DIR}/mobility_same_ori_dest.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path - with pytest.raises(ValueError, match=r"Mobility fluxes with same origin and destination.*"): + with pytest.raises( + ValueError, match=r"Mobility fluxes with same origin and destination.*" + ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] ) @@ -170,12 +186,16 @@ def test_subpopulation_structure_mobility_npz_shape_fail(): mobility: {DATA_DIR}/mobility_2x3.npz """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path - with pytest.raises(ValueError, match=r"mobility data must have dimensions of length of geodata.*"): + with pytest.raises( + ValueError, match=r"mobility data must have dimensions of length of geodata.*" + ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] ) @@ -190,7 +210,9 @@ def test_subpopulation_structure_mobility_no_extension_fail(): mobility: {DATA_DIR}/mobility """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path @@ -210,13 +232,16 @@ def test_subpopulation_structure_mobility_exceed_source_node_pop_fail(): mobility: {DATA_DIR}/mobility1001.csv """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path with pytest.raises( - ValueError, match=r"The following entries in the mobility data exceed the source subpop populations.*" + ValueError, + match=r"The following entries in the mobility data exceed the source subpop populations.*", ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] @@ -232,13 +257,16 @@ def 
test_subpopulation_structure_mobility_rows_exceed_source_node_pop_fail(): mobility: {DATA_DIR}/mobility_row_exceeed.txt """ - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path with pytest.raises( - ValueError, match=r"The following entries in the mobility data exceed the source subpop populations.*" + ValueError, + match=r"The following entries in the mobility data exceed the source subpop populations.*", ): subpop_struct = subpopulation_structure.SubpopulationStructure( setup_name=TEST_SETUP_NAME, subpop_config=config["subpop_setup"] @@ -252,7 +280,9 @@ def test_subpopulation_structure_mobility_no_mobility_matrix_specified(): """ config.clear() config.read(user=False) - with tempfile.NamedTemporaryFile(delete=False) as temp_file: # Creates a temporary file + with tempfile.NamedTemporaryFile( + delete=False + ) as temp_file: # Creates a temporary file temp_file.write(subpop_config_str.encode("utf-8")) # Write the content temp_file.close() # Ensure the file is closed config.set_file(temp_file.name) # Load from the temporary file path diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py index 367a7f550..b12989ec4 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py +++ b/flepimop/gempyor_pkg/tests/utils/test_get_log_normal.py @@ -7,7 +7,7 @@ class TestGetLogNormal: """Unit tests for the `gempyor.utils.get_log_normal` function.""" - + @pytest.mark.parametrize( "meanlog,sdlog", [ @@ -22,15 +22,15 @@ class TestGetLogNormal: ], ) def test_construct_distribution( - self, - meanlog: float | int, - sdlog: float | int, + self, + meanlog: float | int, + sdlog: float | int, ) -> None: """Test the construction of a log normal distribution. - This test checks whether the `get_log_normal` function correctly constructs - a log normal distribution with the specified parameters. It verifies that - the returned object is an instance of `rv_frozen`, and that its support and + This test checks whether the `get_log_normal` function correctly constructs + a log normal distribution with the specified parameters. It verifies that + the returned object is an instance of `rv_frozen`, and that its support and parameters (log mean and log standard deviation) are correctly set. Args: diff --git a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py index 23e4fad58..c3ccbca79 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py +++ b/flepimop/gempyor_pkg/tests/utils/test_get_truncated_normal.py @@ -7,7 +7,7 @@ class TestGetTruncatedNormal: """Unit tests for the `gempyor.utils.get_truncated_normal` function.""" - + @pytest.mark.parametrize( "mean,sd,a,b", [ @@ -21,10 +21,10 @@ class TestGetTruncatedNormal: ], ) def test_construct_distribution( - self, - mean: float | int, - sd: float | int, - a: float | int, + self, + mean: float | int, + sd: float | int, + a: float | int, b: float | int, ) -> None: """Test the construction of a truncated normal distribution. 
diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils.py b/flepimop/gempyor_pkg/tests/utils/test_utils.py index 768451b2a..7013e0ea3 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils.py @@ -10,7 +10,11 @@ @pytest.mark.parametrize( - ("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet"),], + ("fname", "extension"), + [ + ("mobility", "csv"), + ("usa-geoid-params-output", "parquet"), + ], ) def test_read_df_and_write_success(fname, extension): os.chdir(tmp_path) @@ -23,13 +27,18 @@ def test_read_df_and_write_success(fname, extension): utils.write_df(tmp_path + "/data/" + fname, df2, extension=extension) assert os.path.isfile(tmp_path + "/data/" + fname + "." + extension) elif extension == "parquet": - df2 = pa.parquet.read_table(f"{DATA_DIR}/" + fname + "." + extension).to_pandas() + df2 = pa.parquet.read_table( + f"{DATA_DIR}/" + fname + "." + extension + ).to_pandas() assert df2.equals(df1) utils.write_df(tmp_path + "/data/" + fname, df2, extension=extension) assert os.path.isfile(tmp_path + "/data/" + fname + "." + extension) -@pytest.mark.parametrize(("fname", "extension"), [("mobility", "csv"), ("usa-geoid-params-output", "parquet")]) +@pytest.mark.parametrize( + ("fname", "extension"), + [("mobility", "csv"), ("usa-geoid-params-output", "parquet")], +) def test_read_df_and_write_fail(fname, extension): with pytest.raises(NotImplementedError, match=r".*Invalid.*extension.*Must.*"): os.chdir(tmp_path) @@ -41,7 +50,9 @@ def test_read_df_and_write_fail(fname, extension): assert df2.equals(df1) utils.write_df(tmp_path + "/data/" + fname, df2, extension="") elif extension == "parquet": - df2 = pa.parquet.read_table(f"{DATA_DIR}/" + fname + "." + extension).to_pandas() + df2 = pa.parquet.read_table( + f"{DATA_DIR}/" + fname + "." 
+ extension + ).to_pandas() assert df2.equals(df1) utils.write_df(tmp_path + "/data/" + fname, df2, extension="") @@ -91,9 +102,7 @@ def test_create_resume_out_filename(): filetype="spar", liketype="global", ) - expected_filename = ( - "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" - ) + expected_filename = "model_output/output/123/spar/global/intermediate/000000002.000000001.000000001.123.spar.parquet" assert result == expected_filename result2 = utils.create_resume_out_filename( @@ -111,32 +120,59 @@ def test_create_resume_out_filename(): def test_create_resume_input_filename(): result = utils.create_resume_input_filename( - flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="spar", liketype="global" + flepi_slot_index="2", + resume_run_index="321", + flepi_prefix="output", + filetype="spar", + liketype="global", + ) + expect_filename = ( + "model_output/output/321/spar/global/final/000000002.321.spar.parquet" ) - expect_filename = "model_output/output/321/spar/global/final/000000002.321.spar.parquet" assert result == expect_filename result2 = utils.create_resume_input_filename( - flepi_slot_index="2", resume_run_index="321", flepi_prefix="output", filetype="seed", liketype="chimeric" + flepi_slot_index="2", + resume_run_index="321", + flepi_prefix="output", + filetype="seed", + liketype="chimeric", + ) + expect_filename2 = ( + "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv" ) - expect_filename2 = "model_output/output/321/seed/chimeric/final/000000002.321.seed.csv" assert result2 == expect_filename2 def test_get_filetype_resume_discard_seeding_true_flepi_block_index_1(): expected_types = ["spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_filetype_for_resume(resume_discard_seeding="true", flepi_block_index="1") == expected_types + assert ( + utils.get_filetype_for_resume( + resume_discard_seeding="true", flepi_block_index="1" + ) + == expected_types + ) def test_get_filetype_resume_discard_seeding_false_flepi_block_index_1(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "init"] - assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="1") == expected_types + assert ( + utils.get_filetype_for_resume( + resume_discard_seeding="false", flepi_block_index="1" + ) + == expected_types + ) def test_get_filetype_flepi_block_index_2(): expected_types = ["seed", "spar", "snpi", "hpar", "hnpi", "host", "llik", "init"] - assert utils.get_filetype_for_resume(resume_discard_seeding="false", flepi_block_index="2") == expected_types + assert ( + utils.get_filetype_for_resume( + resume_discard_seeding="false", flepi_block_index="2" + ) + == expected_types + ) def test_create_resume_file_names_map(): diff --git a/flepimop/gempyor_pkg/tests/utils/test_utils2.py b/flepimop/gempyor_pkg/tests/utils/test_utils2.py index 4b0ae59ba..0822604ed 100644 --- a/flepimop/gempyor_pkg/tests/utils/test_utils2.py +++ b/flepimop/gempyor_pkg/tests/utils/test_utils2.py @@ -26,7 +26,9 @@ class SampleClass: def __init__(self): self.value = 11 - @utils.profile(output_file="get_value.prof", sort_by="time", lines_to_print=10, strip_dirs=True) + @utils.profile( + output_file="get_value.prof", sort_by="time", lines_to_print=10, strip_dirs=True + ) def get_value(self): return self.value @@ -198,7 +200,9 @@ def test_as_random_distribution_binomial_w_fraction_error(config): def test_as_random_distribution_truncnorm(config): - config.add({"value": {"distribution": "truncnorm", "mean": 0, 
"sd": 1, "a": -1, "b": 1}}) + config.add( + {"value": {"distribution": "truncnorm", "mean": 0, "sd": 1, "a": -1, "b": 1}} + ) dist = config["value"].as_random_distribution() rvs = dist(size=1000) assert len(rvs) == 1000 diff --git a/postprocessing/postprocess_auto.py b/postprocessing/postprocess_auto.py index aaf4a0bff..13f0849d6 100644 --- a/postprocessing/postprocess_auto.py +++ b/postprocessing/postprocess_auto.py @@ -26,7 +26,13 @@ def __init__(self, run_id, config_filepath=None, folder_path=None): self.folder_path = folder_path -def get_all_filenames(file_type, all_runs, finals_only=False, intermediates_only=False, ignore_chimeric=True) -> dict: +def get_all_filenames( + file_type, + all_runs, + finals_only=False, + intermediates_only=False, + ignore_chimeric=True, +) -> dict: """ return dictionanary for each run name """ @@ -159,7 +165,14 @@ def slack_multiple_files_v2(slack_token, message, file_list, channel): help="Maximum number of files to load for in depth plot and individual sim plot", ) def generate_pdf( - config_filepath, run_id, job_name, fs_results_path, slack_token, slack_channel, max_files, max_files_deep + config_filepath, + run_id, + job_name, + fs_results_path, + slack_token, + slack_channel, + max_files, + max_files_deep, ): print("Generating plots") print(f">> config {config_filepath} for run_id {run_id}") @@ -217,10 +230,14 @@ def generate_pdf( for filename in file_list: slot = int(filename.split("/")[-1].split(".")[0]) block = int(filename.split("/")[-1].split(".")[1]) - sim_str = filename.split("/")[-1].split(".")[2] # not necessarily a sim number now + sim_str = filename.split("/")[-1].split(".")[ + 2 + ] # not necessarily a sim number now if sim_str.isdigit(): sim = int(sim_str) - if block == 1 and (sim == 1 or sim % 5 == 0): ## first block, only one + if block == 1 and ( + sim == 1 or sim % 5 == 0 + ): ## first block, only one df_raw = pq.read_table(filename).to_pandas() df_raw["slot"] = slot df_raw["sim"] = sim @@ -238,7 +255,9 @@ def generate_pdf( # In[23]: - fig, axes = plt.subplots(len(node_names) + 1, 4, figsize=(4 * 4, len(node_names) * 3), sharex=True) + fig, axes = plt.subplots( + len(node_names) + 1, 4, figsize=(4 * 4, len(node_names) * 3), sharex=True + ) colors = ["b", "r", "y", "c"] icl = 0 @@ -255,32 +274,68 @@ def generate_pdf( lls = lls.cumsum() feature = "accepts, cumulative" axes[idp, ift].fill_between( - lls.index, lls.quantile(0.025, axis=1), lls.quantile(0.975, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.025, axis=1), + lls.quantile(0.975, axis=1), + alpha=0.1, + color=colors[icl], ) axes[idp, ift].fill_between( - lls.index, lls.quantile(0.25, axis=1), lls.quantile(0.75, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.25, axis=1), + lls.quantile(0.75, axis=1), + alpha=0.1, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, + lls.median(axis=1), + marker="o", + label=run_id, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3 ) - axes[idp, ift].plot(lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl]) - axes[idp, ift].plot(lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3) axes[idp, ift].set_title(f"National, {feature}") axes[idp, ift].grid() for idp, nn in enumerate(node_names): idp = idp + 1 - all_nn = full_df[full_df["subpop"] == nn][["sim", "slot", "ll", "accept", "accept_avg", "accept_prob"]] - for ift, feature in enumerate(["ll", "accept", "accept_avg", "accept_prob"]): + all_nn = 
full_df[full_df["subpop"] == nn][ + ["sim", "slot", "ll", "accept", "accept_avg", "accept_prob"] + ] + for ift, feature in enumerate( + ["ll", "accept", "accept_avg", "accept_prob"] + ): lls = all_nn.pivot(index="sim", columns="slot", values=feature) if feature == "accept": lls = lls.cumsum() feature = "accepts, cumulative" axes[idp, ift].fill_between( - lls.index, lls.quantile(0.025, axis=1), lls.quantile(0.975, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.025, axis=1), + lls.quantile(0.975, axis=1), + alpha=0.1, + color=colors[icl], ) axes[idp, ift].fill_between( - lls.index, lls.quantile(0.25, axis=1), lls.quantile(0.75, axis=1), alpha=0.1, color=colors[icl] + lls.index, + lls.quantile(0.25, axis=1), + lls.quantile(0.75, axis=1), + alpha=0.1, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, + lls.median(axis=1), + marker="o", + label=run_id, + color=colors[icl], + ) + axes[idp, ift].plot( + lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3 ) - axes[idp, ift].plot(lls.index, lls.median(axis=1), marker="o", label=run_id, color=colors[icl]) - axes[idp, ift].plot(lls.index, lls.iloc[:, 0:max_files_deep], color="k", lw=0.3) axes[idp, ift].set_title(f"{nn}, {feature}") axes[idp, ift].grid() if idp == len(node_names) - 1: @@ -292,8 +347,11 @@ def generate_pdf( pass import gempyor.utils - llik_filenames = gempyor.utils.list_filenames(folder="model_output/", filters=["final", "llik" , ".parquet"]) - #get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False) + + llik_filenames = gempyor.utils.list_filenames( + folder="model_output/", filters=["final", "llik", ".parquet"] + ) + # get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False) # In[7]: resultST = [] for filename in llik_filenames: diff --git a/utilities/clean_s3.py b/utilities/clean_s3.py index 2998c65c2..08982b2b2 100644 --- a/utilities/clean_s3.py +++ b/utilities/clean_s3.py @@ -9,7 +9,9 @@ s3 = boto3.client("s3") paginator = s3.get_paginator("list_objects_v2") -pages = paginator.paginate(Bucket=bucket, Prefix="", Delimiter="/") # needs paginator cause more than 1000 files +pages = paginator.paginate( + Bucket=bucket, Prefix="", Delimiter="/" +) # needs paginator cause more than 1000 files to_prun = [] # folders: diff --git a/utilities/copy_for_continuation.py b/utilities/copy_for_continuation.py index 33d9da40f..05b7a803a 100644 --- a/utilities/copy_for_continuation.py +++ b/utilities/copy_for_continuation.py @@ -77,10 +77,14 @@ def detect_old_run_id(fp): fn = files[0] old_run_id = detect_old_run_id(fn) new_name = ( - fn.replace("seir", "cont").replace(f"{input_folder}/model_output", "model_output").replace(old_run_id, run_id) + fn.replace("seir", "cont") + .replace(f"{input_folder}/model_output", "model_output") + .replace(old_run_id, run_id) ) - print(f"detected old_run_id: {old_run_id} which will be replaced by user provided run_id: {run_id}") + print( + f"detected old_run_id: {old_run_id} which will be replaced by user provided run_id: {run_id}" + ) empty_str = "°" * len(input_folder) print(f"file: \n OLD NAME: {fn}\n NEW NAME: {empty_str}{new_name}") for fn in tqdm.tqdm(files): diff --git a/utilities/prune_by_llik.py b/utilities/prune_by_llik.py index 5b1f3224b..c783591b1 100644 --- a/utilities/prune_by_llik.py +++ b/utilities/prune_by_llik.py @@ -11,7 +11,11 @@ def get_all_filenames( - file_type, fs_results_path="to_prune/", finals_only=False, intermediates_only=True, ignore_chimeric=True + file_type, + 
fs_results_path="to_prune/", + finals_only=False, + intermediates_only=True, + ignore_chimeric=True, ) -> dict: """ return dictionary for each run name @@ -113,14 +117,18 @@ def get_all_filenames( if fill_missing: # Extract the numbers from the filenames numbers = [int(os.path.basename(filename).split(".")[0]) for filename in all_files] - missing_numbers = [num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers] + missing_numbers = [ + num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers + ] if missing_numbers: missing_filenames = [] for num in missing_numbers: filename = os.path.basename(all_files[0]) filename_prefix = re.search(r"^.*?(\d+)", filename).group() filename_suffix = re.search(r"(\..*?)$", filename).group() - missing_filename = os.path.join(os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}") + missing_filename = os.path.join( + os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}" + ) missing_filenames.append(missing_filename) print("The missing filenames with full paths are:") for missing_filename in missing_filenames: @@ -143,7 +151,7 @@ def copy_path(src, dst): file_types = [ "llik", - #"seed", + # "seed", "init", "snpi", "hnpi", @@ -160,7 +168,9 @@ def copy_path(src, dst): if fn in files_to_keep: for file_type in file_types: src = fn.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") @@ -169,7 +179,9 @@ def copy_path(src, dst): file_to_keep = np.random.choice(files_to_keep) for file_type in file_types: src = file_to_keep.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") diff --git a/utilities/prune_by_llik_and_proj.py b/utilities/prune_by_llik_and_proj.py index 53e623224..8bc50b163 100644 --- a/utilities/prune_by_llik_and_proj.py +++ b/utilities/prune_by_llik_and_proj.py @@ -11,7 +11,11 @@ def get_all_filenames( - file_type, fs_results_path="to_prune/", finals_only=False, intermediates_only=True, ignore_chimeric=True + file_type, + fs_results_path="to_prune/", + finals_only=False, + intermediates_only=True, + ignore_chimeric=True, ) -> dict: """ return dictionary for each run name @@ -23,7 +27,7 @@ def get_all_filenames( l = [] for f in Path(str(fs_results_path + "model_output")).rglob(f"*.{ext}"): f = str(f) - + if file_type in f: print(f) if ( @@ -61,7 +65,9 @@ def get_all_filenames( fs_results_path = "to_prune/" best_n = 200 -llik_filenames = get_all_filenames("llik", fs_results_path, finals_only=True, intermediates_only=False) +llik_filenames = get_all_filenames( + "llik", fs_results_path, finals_only=True, intermediates_only=False +) # In[7]: resultST = [] for filename in llik_filenames: @@ -100,7 +106,6 @@ def get_all_filenames( print(f" - {slot:4}, llik: {sorted_llik.loc[slot]['ll']:0.3f}") - #### RERUN FROM HERE TO CHANGE THE REGULARIZATION files_to_keep = list(full_df.loc[best_slots]["filename"].unique()) # important to sort by llik @@ -109,8 +114,9 @@ def get_all_filenames( files_to_keep = [] for fn in all_files: if fn in files_to_keep3: - outcome_fn = fn.replace("llik", "hosp") + outcome_fn = fn.replace("llik", "hosp") import 
gempyor.utils + outcomes_df = gempyor.utils.read_df(outcome_fn) outcomes_df = outcomes_df.set_index("date") reg = 1.5 @@ -118,19 +124,27 @@ def get_all_filenames( this_bad = 0 bad_subpops = [] for sp in outcomes_df["subpop"].unique(): - max_fit = outcomes_df[outcomes_df["subpop"]==sp]["incidC"][:"2024-04-08"].max() - max_summer = outcomes_df[outcomes_df["subpop"]==sp]["incidC"]["2024-04-08":"2024-09-30"].max() - if max_summer > max_fit*reg: + max_fit = outcomes_df[outcomes_df["subpop"] == sp]["incidC"][ + :"2024-04-08" + ].max() + max_summer = outcomes_df[outcomes_df["subpop"] == sp]["incidC"][ + "2024-04-08":"2024-09-30" + ].max() + if max_summer > max_fit * reg: this_bad += 1 - max_reg = max(max_reg, max_summer/max_fit) + max_reg = max(max_reg, max_summer / max_fit) bad_subpops.append(sp) - #print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%") - #print(f">>> MULT BY {max_summer/max_fit*mult:2f}") - #outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult - if this_bad>4 or max_reg>4: - print(f"{outcome_fn.split('/')[-1].split('.')[0]} >>> BAAD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}") + # print(f"changing {sp} because max_summer max_summer={max_summer:.1f} > reg*max_fit={max_fit:.1f}, diff {max_fit/max_summer*100:.1f}%") + # print(f">>> MULT BY {max_summer/max_fit*mult:2f}") + # outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]] = outcomes_df.loc[outcomes_df["subpop"]==sp, ["incidH", "incidD"]]*max_summer/max_fit*mult + if this_bad > 4 or max_reg > 4: + print( + f"{outcome_fn.split('/')[-1].split('.')[0]} >>> BAAD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}" + ) else: - print(f"{outcome_fn.split('/')[-1].split('.')[0]} >>> GOOD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}") + print( + f"{outcome_fn.split('/')[-1].split('.')[0]} >>> GOOD: {this_bad} subpops AND max_ratio={max_reg:.1f}, sp with max_summer > max_fit*{reg} {bad_subpops}" + ) files_to_keep.append(fn) print(len(files_to_keep)) ### END OF CODE @@ -146,14 +160,18 @@ def get_all_filenames( if fill_missing: # Extract the numbers from the filenames numbers = [int(os.path.basename(filename).split(".")[0]) for filename in all_files] - missing_numbers = [num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers] + missing_numbers = [ + num for num in range(fill_from_min, fill_from_max + 1) if num not in numbers + ] if missing_numbers: missing_filenames = [] for num in missing_numbers: filename = os.path.basename(all_files[0]) filename_prefix = re.search(r"^.*?(\d+)", filename).group() filename_suffix = re.search(r"(\..*?)$", filename).group() - missing_filename = os.path.join(os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}") + missing_filename = os.path.join( + os.path.dirname(all_files[0]), f"{num:09d}{filename_suffix}" + ) missing_filenames.append(missing_filename) print("The missing filenames with full paths are:") for missing_filename in missing_filenames: @@ -191,7 +209,9 @@ def copy_path(src, dst): if fn in files_to_keep: for file_type in file_types: src = fn.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) 
if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv") @@ -200,7 +220,9 @@ def copy_path(src, dst): file_to_keep = np.random.choice(files_to_keep) for file_type in file_types: src = file_to_keep.replace("llik", file_type) - dst = fn.replace(fs_results_path, output_folder).replace("llik", file_type) + dst = fn.replace(fs_results_path, output_folder).replace( + "llik", file_type + ) if file_type == "seed": src = src.replace(".parquet", ".csv") dst = dst.replace(".parquet", ".csv")
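Editor's note: the hunks in the test and utility files above appear to be formatting-only: long calls and signatures are broken one argument per line, long strings and comprehensions are wrapped in parentheses, comment spacing is normalized, and trailing whitespace is removed, all consistent with Black's default 88-character style, with no intended behavior change. As a quick local check of that claim, the sketch below (assuming the black package is installed; the sample line is copied from one of the hunks above) asks Black to re-wrap a single over-long statement via its Python API.

import black

# One of the long lines re-wrapped in this patch, held here as a plain string;
# format_str only parses it, so the undefined names (paginator, bucket) are fine.
src = (
    'pages = paginator.paginate(Bucket=bucket, Prefix="", Delimiter="/")'
    "  # needs paginator cause more than 1000 files\n"
)

# black.Mode() uses the default 88-column limit, so the printed result should be
# wrapped the same way as the corresponding hunk above.
print(black.format_str(src, mode=black.Mode()))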