diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ec5f14..5898525 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,8 +2,70 @@
All notable changes to this project will be documented in this file.
+
+## [1.0.0.a5] - 2024-02-06
-## [Unreleased]
+In this alpha release, we fixed further bugs and issues that emerged during more rigorous testing.
+
+Most notably, we backed away from storing the transition matrix on a model instance, because doing so required opaque and confusing calls to functions that deleted it whenever parameters were updated.
+
+Instead, the function computing the transition matrix is now globally cached, using a hash computed from the graph representation as the key. This costs slightly more computation time for calculating the hash, but has the advantage that, e.g., in a bilaterally symmetric model, the transition matrix of the two sides is only ever computed once when the (synced) parameters are updated.
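+
+The idea, as a minimal standalone sketch (hypothetical names, not this package's actual API):
+
+```python
+from functools import lru_cache
+
+import numpy as np
+
+
+def parameter_hash(spread_probs: tuple[float, ...]) -> int:
+    """Hash the parameters that the transition matrix depends on."""
+    return hash(np.array(spread_probs).tobytes())
+
+
+@lru_cache(maxsize=128)
+def transition_matrix(param_hash: int, num_states: int) -> np.ndarray:
+    # the hash is only the cache key: two sides of a symmetric model whose
+    # synced parameters hash equally share one cached matrix
+    return np.eye(num_states)  # stand-in for the actual computation
+```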
+
+### Bug Fixes
+
+- (**graph**) Assume `nodes` is a dictionary, not a list
+- (**uni**) Update `draw_patients()` method to output LyProX style data
+- (**bi**) Update bilateral data generation method to also generate LyProX style data
+- (**bi**) Syntax error in `init_synchronization`
+- (**uni**) Remove need for transition matrix deletion by using a global cache
+- (**uni**) Use cached matrices & simplify related code
+- (**uni**) Make observation matrix a plain property, no longer cached
+
+### Documentation
+
+- Fix typos & formatting errors in docstrings
+
+### Features
+
+- (**graph**) Implement graph hash for global cache of transition matrix
+- (**helper**) Add an `arg0` cache decorator that caches based on the first argument only
+- (**matrix**) Use cache for observation & diagnose matrices
+
+### Miscellaneous Tasks
+
+- Update dependencies & classifiers
+
+### Refactor
+
+- Variables inside `generate_transition()`
+
+### Testing
+
+- Make doctests discoverable by unittest
+- Update tests to changed API
+- (**uni**) Assert format & distribution of drawn patients
+- (**uni**) Allow larger delta for synthetic data distribution
+- (**bi**) Check bilateral data generation method
+- Check the bilateral model with symmetric tumor spread
+- Make sure deleting & recomputing synced edges' tensors works
+- Adapt tests to changed `Edge` API
+- (**bi**) Evaluate transition matrix recomputation
+- Update tests to match new transition matrix code
+- Update trinary unilateral tests
+
+### Change
+
+- ⚠ **BREAKING** Compute transition tensor globally
+- ⚠ **BREAKING** Make transition matrix a method instead of a property
+- ⚠ **BREAKING** Make observation matrix a method instead of a property
+
+### Ci
+
+- Add coverage test dependency back into project
+
+### Remove
+
+- Unused files and directories
diff --git a/docs/source/_data/bilateral.csv b/docs/source/_data/bilateral.csv
deleted file mode 100644
index e87deac..0000000
--- a/docs/source/_data/bilateral.csv
+++ /dev/null
@@ -1,203 +0,0 @@
-MRI,MRI,MRI,MRI,MRI,MRI,MRI,MRI,PET,PET,PET,PET,PET,PET,PET,PET,info
-contra,contra,contra,contra,ipsi,ipsi,ipsi,ipsi,contra,contra,contra,contra,ipsi,ipsi,ipsi,ipsi,tumor
-I,II,III,IV,I,II,III,IV,I,II,III,IV,I,II,III,IV,t_stage
-True,True,False,True,False,False,False,True,False,False,False,False,False,True,True,False,late
-False,False,False,False,False,True,False,True,False,False,False,False,False,True,False,False,early
-False,False,True,True,True,False,True,True,False,False,False,False,False,True,False,False,early
-True,False,True,False,True,False,False,False,False,False,True,False,False,False,False,False,early
-True,False,False,False,False,True,True,True,True,False,False,False,True,True,False,True,late
-False,False,True,False,True,False,False,True,False,False,False,False,False,True,True,True,late
-True,False,False,False,False,False,True,False,False,False,False,False,False,True,False,False,early
-False,True,False,False,True,False,True,False,True,False,True,False,False,False,False,False,late
-False,False,True,True,False,False,True,True,False,True,True,True,False,False,False,False,late
-True,True,True,True,False,False,True,False,False,True,True,False,False,True,True,True,early
-True,True,True,True,False,True,True,True,False,False,False,True,False,True,True,False,late
-False,True,True,False,True,False,False,False,False,True,True,False,False,True,False,True,early
-True,True,True,False,False,True,False,True,False,False,True,False,False,False,True,False,early
-True,False,True,False,False,True,True,False,False,True,True,False,False,True,True,True,late
-True,True,False,False,True,True,False,True,False,True,False,True,False,True,True,False,early
-True,False,False,False,False,False,True,True,False,True,True,True,False,False,False,False,early
-True,True,True,False,True,True,True,True,False,False,False,False,False,True,False,True,late
-True,True,True,False,True,True,True,True,False,False,False,False,True,True,True,False,late
-False,True,False,True,True,False,False,False,False,False,False,False,False,False,False,False,early
-True,False,False,False,False,True,False,False,False,False,True,False,False,False,False,True,early
-True,True,True,True,True,False,True,True,False,True,True,True,False,True,True,True,late
-True,True,True,False,False,True,False,True,False,False,False,False,False,False,False,False,early
-False,True,True,False,True,True,True,True,False,False,True,False,False,True,True,False,early
-False,False,True,False,False,True,False,True,False,False,False,False,False,False,True,False,late
-True,False,False,True,True,False,False,False,False,False,False,False,True,False,False,False,early
-True,False,False,True,True,True,False,True,False,False,True,False,True,False,True,False,late
-False,True,False,False,True,True,True,True,False,True,False,False,False,False,False,True,early
-False,True,False,True,True,False,False,False,True,False,False,False,True,True,False,True,early
-False,False,True,True,True,True,False,False,False,True,True,False,False,True,False,True,early
-False,True,True,False,True,False,True,True,True,False,False,False,False,False,False,False,early
-True,True,True,True,True,False,True,True,False,False,True,False,True,False,True,True,late
-False,True,False,False,True,False,False,False,False,True,False,False,True,False,True,False,early
-True,True,False,False,True,False,False,True,False,False,False,False,True,False,False,False,early
-False,True,True,True,False,True,True,False,True,True,False,True,True,True,True,True,late
-False,True,False,True,True,False,True,True,False,True,True,True,False,True,False,True,late
-False,False,False,False,True,True,True,False,False,False,False,False,False,True,False,True,late
-True,False,False,False,True,True,False,False,True,False,False,False,False,True,False,False,early
-True,True,True,False,True,False,False,False,False,False,False,False,False,False,True,False,late
-False,False,False,True,False,False,True,False,False,False,False,False,False,False,False,True,late
-True,False,False,True,False,False,False,True,False,True,False,False,False,False,False,True,early
-False,False,False,True,True,True,True,True,False,False,False,False,False,True,False,False,early
-True,True,False,False,True,True,True,False,False,False,False,False,False,False,False,False,early
-True,True,False,True,True,False,True,True,False,True,False,True,True,False,True,True,early
-False,False,True,True,True,True,True,False,False,False,True,False,False,False,False,False,early
-True,False,False,True,False,False,False,True,False,False,False,False,False,False,False,True,early
-True,False,False,False,True,True,True,True,False,False,False,False,False,True,False,True,early
-True,False,False,True,False,True,True,False,True,False,True,False,False,False,True,False,early
-True,True,False,False,True,True,True,True,False,False,False,False,False,True,True,True,late
-False,True,False,False,False,True,True,False,True,False,False,False,False,False,True,True,late
-False,False,False,True,True,True,True,True,False,False,False,False,True,True,False,True,early
-False,False,False,True,False,False,True,False,False,True,False,False,False,False,True,False,early
-True,False,False,True,False,True,True,False,True,True,True,False,False,False,True,False,late
-False,False,True,False,False,True,True,False,False,False,True,False,True,True,True,True,late
-False,False,True,True,False,True,False,False,False,False,False,False,False,True,False,False,early
-True,False,True,False,False,False,False,True,True,True,True,False,True,True,True,True,late
-True,False,True,False,True,True,False,True,False,False,False,False,True,False,True,True,early
-True,True,True,False,True,True,True,False,False,True,True,False,False,False,False,False,early
-True,True,False,True,True,True,True,True,False,False,False,True,True,False,True,True,late
-False,False,True,False,False,True,True,True,False,True,False,False,False,False,False,True,early
-True,True,False,False,False,False,False,True,False,False,False,False,False,False,False,True,late
-False,False,False,False,True,False,False,True,False,False,True,True,True,True,True,True,late
-True,False,True,False,False,False,True,False,False,False,False,False,True,True,False,False,late
-True,True,False,True,False,True,False,True,False,False,False,False,True,True,True,True,late
-False,True,False,False,False,False,True,False,False,False,False,False,False,False,False,False,early
-True,True,False,False,False,True,False,True,True,False,False,False,False,True,False,False,early
-False,False,True,False,False,True,True,True,False,True,False,True,False,False,True,True,early
-False,False,False,True,True,True,True,False,False,False,False,False,False,False,True,False,late
-True,False,False,False,True,False,True,True,False,False,False,False,True,False,True,True,early
-True,False,False,False,True,True,False,False,False,False,False,False,False,True,False,False,early
-False,False,True,False,False,True,True,True,True,False,False,False,True,True,False,True,early
-True,False,False,False,True,True,False,False,False,False,False,False,False,False,False,False,early
-True,True,True,False,False,False,True,True,False,False,True,False,False,True,True,False,early
-False,False,True,True,True,False,True,True,False,True,False,False,True,False,True,False,early
-True,False,True,True,True,False,False,False,False,False,False,False,False,True,True,True,late
-False,True,True,True,True,True,False,True,False,True,True,True,True,False,True,False,late
-True,True,False,False,False,True,True,True,False,False,True,False,False,False,False,True,early
-False,False,False,False,False,False,False,True,False,False,False,False,False,False,True,False,early
-True,False,True,True,False,False,True,False,False,True,True,False,False,False,False,False,late
-True,False,True,True,False,True,True,False,False,False,False,False,False,False,False,False,early
-False,True,False,True,True,False,True,False,False,False,False,False,False,False,True,True,early
-False,True,False,True,False,True,False,False,False,True,False,False,False,True,True,False,late
-True,True,False,False,True,False,True,True,False,False,False,True,False,True,True,True,early
-True,False,True,True,False,False,False,False,False,False,False,False,False,False,False,False,early
-False,True,False,False,True,False,False,False,False,False,False,False,True,False,False,True,early
-False,False,False,True,False,True,False,True,False,False,False,False,False,False,False,False,late
-False,False,False,False,True,False,True,True,False,False,False,False,False,False,True,False,early
-False,False,True,False,True,True,False,True,False,True,False,False,False,True,False,True,early
-True,False,True,True,True,False,False,False,True,False,True,False,False,True,False,False,late
-False,True,False,False,False,True,True,True,False,False,False,False,False,True,True,True,late
-False,True,True,False,True,True,False,True,False,False,False,False,False,True,False,False,early
-False,False,False,True,False,True,True,True,False,True,False,True,False,True,True,False,early
-False,True,True,False,True,True,False,True,False,True,True,False,True,True,True,True,early
-True,False,True,True,False,True,False,True,False,False,False,False,False,False,False,False,early
-False,True,True,True,False,False,False,True,False,True,False,True,False,True,True,True,late
-False,False,True,False,False,True,False,False,False,False,False,False,False,True,True,False,early
-False,False,True,True,False,False,True,True,True,False,False,False,True,False,True,True,late
-False,False,False,False,True,True,False,True,False,False,False,False,False,True,False,False,late
-True,True,True,True,True,True,True,True,False,True,True,True,True,True,False,True,late
-False,False,True,False,True,True,True,True,True,False,True,False,True,True,True,True,late
-True,False,False,True,False,True,True,True,True,False,False,False,True,True,False,True,late
-False,True,False,False,True,True,True,True,False,False,False,False,True,True,False,True,early
-True,False,False,False,True,True,True,True,False,False,False,True,False,False,True,False,late
-False,False,True,False,False,False,True,False,False,False,False,False,False,True,True,False,late
-False,False,False,True,True,True,False,False,False,False,False,False,False,True,False,False,early
-True,False,False,False,False,True,True,False,False,False,False,False,False,True,True,False,early
-False,False,False,True,False,False,True,True,False,False,False,False,False,True,False,False,early
-True,True,True,False,True,False,True,True,False,False,False,True,False,True,True,True,late
-False,False,False,False,True,True,True,True,True,False,True,False,True,True,True,True,late
-False,False,False,False,True,True,False,False,False,False,False,True,False,False,False,False,early
-False,True,True,True,True,True,True,True,True,True,False,True,True,False,True,True,late
-True,False,True,True,True,False,False,True,False,False,False,False,False,True,True,True,late
-True,True,True,False,False,True,False,True,False,False,False,False,False,False,True,True,early
-False,False,False,False,True,True,True,False,False,False,False,False,False,False,False,False,early
-True,True,True,True,True,True,True,True,False,False,True,False,True,True,False,True,late
-False,True,True,False,False,True,True,False,True,True,False,True,False,True,False,False,early
-False,True,True,True,True,True,False,True,False,False,True,True,True,True,False,False,early
-True,True,False,True,True,False,False,True,False,False,False,False,False,False,True,True,early
-True,False,False,True,False,True,True,True,False,False,False,False,True,False,False,True,late
-False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,early
-True,True,False,False,False,True,True,False,False,True,False,True,False,True,False,False,late
-False,False,True,False,True,True,False,False,False,False,False,False,True,False,True,False,early
-True,True,False,False,False,False,True,True,False,False,False,False,False,False,False,False,early
-False,True,True,False,False,False,True,True,True,False,False,True,True,False,False,True,late
-True,False,True,True,False,False,True,True,True,False,True,False,False,True,True,True,early
-False,True,True,False,False,True,False,True,False,False,False,False,False,True,False,True,early
-False,True,False,True,False,True,True,False,False,True,False,True,False,True,True,False,early
-False,True,False,True,False,True,True,True,False,True,True,True,False,True,False,True,late
-False,False,False,False,True,True,False,False,False,False,False,False,True,True,False,False,early
-True,False,True,False,True,True,False,True,True,True,True,False,True,True,False,False,early
-False,False,False,False,True,True,True,False,False,False,False,False,True,False,False,False,early
-False,True,True,False,False,True,False,True,True,True,True,False,True,True,True,True,late
-True,False,True,True,True,False,False,True,False,False,False,False,True,False,False,True,late
-False,True,False,False,False,True,False,True,True,True,False,False,False,False,True,True,late
-True,True,True,False,False,True,False,False,False,False,False,False,False,False,False,True,early
-False,True,True,False,True,True,True,False,False,False,False,False,False,True,True,False,early
-False,True,True,False,True,True,True,True,True,True,True,True,True,True,True,False,late
-False,False,False,True,False,True,True,True,False,True,False,True,True,True,True,False,late
-False,True,True,False,True,False,True,False,True,False,True,False,True,False,False,False,early
-True,False,False,False,False,True,True,True,False,False,False,False,False,True,True,False,early
-False,False,False,True,False,True,True,True,False,False,False,False,False,True,False,True,early
-True,False,True,False,False,True,False,True,False,False,False,False,False,True,True,True,early
-True,False,False,False,False,False,False,True,False,False,False,False,False,True,True,False,early
-False,False,False,False,True,False,False,True,True,False,False,False,False,False,False,True,late
-False,False,True,False,False,True,True,False,False,False,False,False,True,True,False,False,early
-False,True,False,False,False,True,False,True,True,False,False,False,False,True,True,True,late
-True,True,False,True,True,True,True,False,False,True,True,True,False,False,False,False,late
-True,False,False,True,True,False,True,False,False,False,True,False,False,True,True,False,early
-False,True,True,True,True,True,True,False,False,False,False,False,False,False,True,False,early
-False,True,False,True,False,False,True,True,False,False,False,False,False,False,False,True,early
-True,False,True,True,True,False,True,False,False,False,False,False,False,False,False,False,early
-True,True,True,False,True,False,False,False,False,False,True,False,False,False,False,True,early
-False,False,True,True,True,False,False,True,False,False,False,False,True,True,False,False,early
-False,False,False,False,True,True,True,True,False,False,False,True,True,True,False,True,early
-False,False,True,False,False,True,True,False,False,False,False,False,False,True,True,False,early
-False,False,True,False,False,False,False,False,False,False,True,True,False,True,True,True,late
-True,True,False,False,True,True,True,False,False,False,False,False,False,True,True,False,early
-False,True,True,True,False,True,True,True,False,True,False,False,False,True,True,True,early
-False,False,False,True,True,False,True,True,True,False,False,True,True,False,False,False,early
-False,False,False,True,True,True,False,False,False,False,False,False,False,True,False,True,late
-True,False,False,False,True,False,True,True,False,False,False,False,False,True,False,False,early
-True,False,True,True,True,True,True,True,True,False,False,False,False,True,False,True,early
-False,True,False,False,True,False,True,True,False,True,False,False,False,False,False,True,early
-True,False,False,False,True,True,True,False,True,False,False,False,False,True,True,False,late
-False,False,True,False,False,True,True,True,False,True,True,True,False,True,True,True,late
-True,False,True,True,False,True,True,True,False,False,False,False,True,True,False,False,early
-True,True,True,True,True,False,False,False,False,True,False,False,False,False,True,False,early
-True,False,True,True,True,False,True,False,False,False,False,False,False,True,False,False,early
-False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,early
-True,False,False,False,False,False,False,False,False,False,False,True,False,True,False,False,early
-True,False,True,True,False,True,True,False,False,True,False,False,True,True,True,False,late
-True,False,True,False,True,True,True,False,False,False,False,False,True,False,False,False,early
-False,False,False,True,False,True,True,True,False,False,False,False,False,True,True,False,early
-False,False,False,True,False,False,False,False,False,False,True,False,False,False,True,False,early
-True,True,True,False,False,True,False,False,False,False,False,False,True,False,False,True,late
-True,False,False,True,False,True,True,True,True,True,False,True,False,True,True,True,late
-True,False,False,False,True,False,False,True,False,True,False,False,False,True,False,False,early
-False,True,True,True,False,True,True,True,False,True,False,False,False,True,True,True,late
-False,False,False,False,False,False,True,False,False,False,True,True,False,False,True,False,early
-True,False,True,True,False,False,False,False,False,False,False,False,False,False,False,False,early
-True,False,False,False,False,True,False,True,False,True,False,False,False,False,False,False,early
-True,True,False,True,False,False,True,True,False,True,True,False,True,True,True,True,late
-True,False,False,True,True,True,True,True,False,True,False,False,True,True,False,True,late
-True,False,False,False,True,False,True,True,False,False,False,False,False,False,True,True,late
-False,False,True,True,True,True,True,False,False,True,False,True,True,False,True,False,late
-False,False,True,True,False,True,True,True,False,False,True,False,True,True,True,False,late
-False,True,False,False,False,True,True,True,False,True,False,False,False,True,True,True,late
-False,True,False,False,False,True,False,True,True,False,False,False,False,False,False,False,early
-False,False,True,False,True,True,True,True,False,True,False,False,False,True,True,True,early
-True,False,False,True,False,False,True,False,True,False,False,True,False,False,True,False,late
-True,True,True,True,True,True,True,True,True,True,False,False,True,False,True,False,early
-False,False,False,False,False,True,False,True,False,False,False,False,False,False,False,True,early
-False,False,False,False,True,True,False,True,False,False,False,False,False,True,False,True,late
-False,False,True,False,True,True,True,True,False,False,True,False,False,True,True,True,late
-False,True,False,True,False,True,False,True,False,True,True,False,False,True,True,True,late
-False,True,True,False,True,True,True,False,True,False,True,False,False,False,True,False,early
-False,True,True,False,False,False,True,True,True,True,False,False,True,False,True,True,early
-True,False,True,False,False,True,False,False,False,True,False,False,True,True,False,False,early
-True,False,True,True,True,True,True,False,False,True,True,False,False,False,False,False,early
-False,True,False,True,True,True,False,False,False,False,False,False,True,True,True,True,late
-False,False,True,False,True,False,False,False,False,False,False,False,False,False,True,False,early
diff --git a/docs/source/_data/demo.hdf5 b/docs/source/_data/demo.hdf5
deleted file mode 100644
index da0f6d3..0000000
Binary files a/docs/source/_data/demo.hdf5 and /dev/null differ
diff --git a/docs/source/_data/example.csv b/docs/source/_data/example.csv
deleted file mode 100644
index 2536db1..0000000
--- a/docs/source/_data/example.csv
+++ /dev/null
@@ -1,149 +0,0 @@
-info,pathology,pathology,pathology,pathology
-t_stage,I,II,III,IV
-early,1,0,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,1,0,0
-early,0,0,1,0
-early,0,0,1,0
-early,0,0,1,0
-early,0,0,1,0
-early,0,0,0,1
-early,1,1,0,0
-early,1,0,0,1
-early,1,0,0,1
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,1,0
-early,0,1,0,1
-early,0,1,0,1
-early,0,1,0,1
-early,0,0,1,1
-early,1,1,1,0
-early,1,1,1,0
-early,1,1,1,0
-early,1,1,1,0
-early,1,1,1,0
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,1,1,1
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
-early,0,0,0,0
diff --git a/docs/source/_data/samples.hdf5 b/docs/source/_data/samples.hdf5
deleted file mode 100644
index 7145237..0000000
Binary files a/docs/source/_data/samples.hdf5 and /dev/null differ
diff --git a/lymph/diagnose_times.py b/lymph/diagnose_times.py
index 892c9fa..7a4b572 100644
--- a/lymph/diagnose_times.py
+++ b/lymph/diagnose_times.py
@@ -231,9 +231,21 @@ def set_params(self, **kwargs) -> None:
warnings.warn("Distribution is not updateable, skipping...")
- def draw(self) -> np.ndarray:
- """Draw sample of diagnose times from the PMF."""
- return np.random.choice(a=self.support, p=self.distribution)
+ def draw_diag_times(
+ self,
+ num: int | None = None,
+ rng: np.random.Generator | None = None,
+ seed: int = 42,
+ ) -> np.ndarray:
+ """Draw ``num`` samples of diagnose times from the stored PMF.
+
+ A random number generator can be provided as ``rng``. If ``None``, a new one
+ is initialized with the given ``seed`` (or ``42``, by default).
+ """
+ if rng is None:
+ rng = np.random.default_rng(seed)
+
+ return rng.choice(a=self.support, p=self.distribution, size=num)
class DistributionsUserDict(AbstractLookupDict):
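
The reworked `draw_diag_times()` above makes sampling reproducible via an explicit RNG. The same pattern as a self-contained sketch (toy PMF, hypothetical function name):

```python
from __future__ import annotations

import numpy as np


def draw_diag_times(
    support: np.ndarray,
    distribution: np.ndarray,
    num: int | None = None,
    rng: np.random.Generator | None = None,
    seed: int = 42,
) -> np.ndarray:
    """Draw ``num`` diagnose times from a PMF over ``support``."""
    if rng is None:
        rng = np.random.default_rng(seed)
    return rng.choice(a=support, p=distribution, size=num)


# ten draws from a toy PMF over diagnose times 0..3; same seed, same result
print(draw_diag_times(np.arange(4), np.array([0.1, 0.4, 0.3, 0.2]), num=10))
```
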
diff --git a/lymph/graph.py b/lymph/graph.py
index 3182323..d1a6cb7 100644
--- a/lymph/graph.py
+++ b/lymph/graph.py
@@ -18,7 +18,7 @@
import numpy as np
-from lymph.helper import check_unique_names, trigger
+from lymph.helper import check_unique_names, comp_transition_tensor, trigger
class AbstractNode:
@@ -31,9 +31,9 @@ def __init__(
) -> None:
"""Make a new node.
- Upon initialization, the `name` and `state` of the node must be provided. The
- `state` must be one of the `allowed_states`. The constructor makes sure that
- the `allowed_states` are a list of ints, even when, e.g., a tuple of floats
+ Upon initialization, the ``name`` and ``state`` of the node must be provided.
+ The ``state`` must be one of the ``allowed_states``. The constructor makes sure that
+ the ``allowed_states`` are a list of ints, even when, e.g., a tuple of floats
is provided.
"""
self.name = name
@@ -103,9 +103,9 @@ def comp_obs_prob(
obs_table: np.ndarray,
log: bool = False,
) -> float:
- """Compute the probability of the diagnosis `obs`, given the current state.
+ """Compute the probability of the diagnosis ``obs``, given the current state.
- The `obs_table` is a 2D array with the rows corresponding to the states and
+ The ``obs_table`` is a 2D array with the rows corresponding to the states and
the columns corresponding to the observations. It encodes for each state and
diagnosis the corresponding probability.
"""
@@ -192,7 +192,7 @@ def comp_bayes_net_prob(self, log: bool = False) -> float:
def comp_trans_prob(self, new_state: int) -> float:
- """Compute the hidden Markov model's transition probability to a `new_state`."""
+ """Compute the hidden Markov model's transition probability to a ``new_state``."""
if new_state == self.state:
stay_prob = 1.
for edge in self.inc:
@@ -225,7 +225,7 @@ def __init__(
spread to the next LNL. The ``micro_mod`` parameter is a modifier for the spread
probability in case of only a microscopic node involvement.
"""
- self.trigger_callbacks = [self.delete_transition_tensor]
+ self.trigger_callbacks = []
if callbacks is not None:
self.trigger_callbacks += callbacks
@@ -409,78 +409,21 @@ def set_params(
self.set_micro_mod(micro)
- def comp_transition_tensor(self) -> np.ndarray:
- """Compute the transition factors of the edge.
-
- The returned array is of shape (p,c,c), where p is the number of states of the
- parent node and c is the number of states of the child node.
-
- Essentially, the tensors computed here contain most of the parametrization of
- the model. They are used to compute the transition matrix.
- """
- num_parent = len(self.parent.allowed_states)
- num_child = len(self.child.allowed_states)
- tensor = np.stack([np.eye(num_child)] * num_parent)
-
- # this should allow edges from trinary nodes to binary nodes
- pad = [0.] * (num_child - 2)
-
- if self.is_tumor_spread:
- # NOTE: Here we define how tumors spread to LNLs
- tensor[0, 0, :] = np.array([1. - self.spread_prob, self.spread_prob, *pad])
- return tensor
-
- if self.is_growth:
- # In the growth case, we can assume that two things:
- # 1. parent and child state are the same
- # 2. the child node is trinary
- tensor[1, 1, :] = np.array([0., (1 - self.spread_prob), self.spread_prob])
- return tensor
-
- if self.parent.is_trinary:
- # NOTE: here we define how the micro_mod affects the spread probability
- micro_spread = self.spread_prob * self.micro_mod
- tensor[1,0,:] = np.array([1. - micro_spread, micro_spread, *pad])
-
- macro_spread = self.spread_prob
- tensor[2,0,:] = np.array([1. - macro_spread, macro_spread, *pad])
-
- return tensor
-
- tensor[1,0,:] = np.array([1. - self.spread_prob, self.spread_prob, *pad])
- return tensor
-
-
- def get_transition_tensor(self) -> np.ndarray:
- """Return the transition tensor of the edge."""
- if not hasattr(self, "_transition_tensor"):
- self._transition_tensor = self.comp_transition_tensor()
-
- return self._transition_tensor
-
-
- def delete_transition_tensor(self) -> None:
- """Delete the transition tensor of the edge."""
- if hasattr(self, "_transition_tensor"):
- del self._transition_tensor
-
-
- transition_tensor = property(
- fget=get_transition_tensor,
- fdel=delete_transition_tensor,
- doc="""
- This tensor of the shape (s,e,e) contains the transition probabilities for the
- :py:class:`~LymphNodeLevel` at this instance's end to transition from any
- starting state to any new state, given any possible state of the
- :py:class:`~AbstractNode` instance at the start of this edge.
-
- The correct term can be accessed like this:
-
- .. code-block:: python
+ @property
+ def transition_tensor(self) -> np.ndarray:
+ """Return the transition tensor of the edge.
- edge.transition_tensor[parent_state, child_state, new_child_state]
+ See Also:
+        :py:func:`lymph.helper.comp_transition_tensor`
"""
- )
+ return comp_transition_tensor(
+ num_parent=len(self.parent.allowed_states),
+ num_child=len(self.child.allowed_states),
+ is_tumor_spread=self.is_tumor_spread,
+ is_growth=self.is_growth,
+ spread_prob=self.spread_prob,
+ micro_mod=self.micro_mod,
+ )
class Representation:
@@ -640,12 +583,64 @@ def growth_edges(self) -> dict[str, Edge]:
return {n: e for n, e in self.edges.items() if e.is_growth}
+ def parameter_hash(self) -> int:
+ """Compute a hash of the graph.
+
+ Note:
+ This is used to check if the graph has changed and the transition matrix
+ needs to be recomputed. It should not be used as a replacement for the
+ ``__hash__`` method, for two reasons:
+
+ 1. It may change over the lifetime of the object, whereas ``__hash__``
+ should be constant.
+ 2. It only takes into account the ``transition_tensor`` of the edges,
+ nothing else.
+
+ Example:
+
+ >>> graph_dict = {
+ ... ('tumor', 'T'): ['II', 'III'],
+ ... ('lnl', 'II'): ['III'],
+ ... ('lnl', 'III'): [],
+ ... }
+ >>> one_graph = Representation(graph_dict)
+ >>> another_graph = Representation(graph_dict)
+ >>> rng = np.random.default_rng(42)
+ >>> for one_edge, another_edge in zip(
+ ... one_graph.edges.values(), another_graph.edges.values()
+ ... ):
+ ... params_dict = one_edge.get_params(as_dict=True)
+ ... params_to_set = {k: rng.uniform() for k in params_dict}
+ ... one_edge.set_params(**params_to_set)
+ ... another_edge.set_params(**params_to_set)
+ >>> one_graph.parameter_hash() == another_graph.parameter_hash()
+ True
+ """
+ tensor_bytes = b""
+ for edge in self.edges.values():
+ tensor_bytes += edge.transition_tensor.tobytes()
+
+ return hash(tensor_bytes)
+
+
def to_dict(self) -> dict[tuple[str, str], set[str]]:
- """Returns graph representing this instance's nodes and egdes as dictionary."""
+        """Returns graph representing this instance's nodes and edges as a dictionary.
+
+ Example:
+
+ >>> graph_dict = {
+ ... ('tumor', 'T'): ['II', 'III'],
+ ... ('lnl', 'II'): ['III'],
+ ... ('lnl', 'III'): [],
+ ... }
+ >>> graph = Representation(graph_dict)
+ >>> graph.to_dict() == graph_dict
+ True
+ """
res = {}
- for node in self.nodes:
+ for node in self.nodes.values():
node_type = "tumor" if isinstance(node, Tumor) else "lnl"
- res[(node_type, node.name)] = {o.child.name for o in node.out}
+ res[(node_type, node.name)] = [o.child.name for o in node.out]
return res
@@ -655,14 +650,14 @@ def get_mermaid(self) -> str:
Example:
>>> graph_dict = {
- ... ("tumor", "T"): ["II", "III"],
- ... ("lnl", "II"): ["III"],
- ... ("lnl", "III"): [],
+ ... ('tumor', 'T'): ['II', 'III'],
+ ... ('lnl', 'II'): ['III'],
+ ... ('lnl', 'III'): [],
... }
>>> graph = Representation(graph_dict)
- >>> graph.edge_params["spread_T_to_II"].set_param(0.1)
- >>> graph.edge_params["spread_T_to_III"].set_param(0.2)
- >>> graph.edge_params["spread_II_to_III"].set_param(0.3)
+ >>> graph.edges["T_to_II"].spread_prob = 0.1
+ >>> graph.edges["T_to_III"].spread_prob = 0.2
+ >>> graph.edges["II_to_III"].spread_prob = 0.3
>>> print(graph.get_mermaid()) # doctest: +NORMALIZE_WHITESPACE
flowchart TD
T-->|10%| II
@@ -672,8 +667,8 @@ def get_mermaid(self) -> str:
"""
mermaid_graph = "flowchart TD\n"
- for idx, node in enumerate(self.nodes):
- for edge in self.nodes[idx].out:
+ for node in self.nodes.values():
+ for edge in node.out:
mermaid_graph += f"\t{node.name}-->|{edge.spread_prob:.0%}| {edge.child.name}\n"
return mermaid_graph
@@ -741,13 +736,13 @@ def state_list(self):
state 1 and all others are in state 0, etc. Essentially, it looks like binary
counting:
- >>> model = Unilateral(graph={
+ >>> graph = Representation(graph_dict={
... ("tumor", "T"): ["I", "II" , "III"],
... ("lnl", "I"): [],
... ("lnl", "II"): ["I", "III"],
... ("lnl", "III"): [],
... })
- >>> model.state_list
+ >>> graph.state_list
array([[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
diff --git a/lymph/helper.py b/lymph/helper.py
index 1d73872..d883b30 100644
--- a/lymph/helper.py
+++ b/lymph/helper.py
@@ -7,6 +7,7 @@
from typing import Any, Callable
import numpy as np
+from cachetools import LRUCache
from pandas._libs.missing import NAType
PatternType = dict[str, bool | NAType | None]
@@ -196,6 +197,58 @@ def change_base(
return pad + result[::-1]
+@lru_cache
+def comp_transition_tensor(
+ num_parent: int,
+ num_child: int,
+ is_tumor_spread: bool,
+ is_growth: bool,
+ spread_prob: float,
+ micro_mod: float,
+) -> np.ndarray:
+ """Compute the transition factors of the edge.
+
+ The returned array is of shape (p,c,c), where p is the number of states of the
+ parent node and c is the number of states of the child node.
+
+ Essentially, the tensors computed here contain most of the parametrization of
+ the model. They are used to compute the transition matrix.
+
+ This function globally computes and caches the transition tensors, such that we
+ do not need to worry about deleting and recomputing them when the parameters of the
+ edge change.
+ """
+ tensor = np.stack([np.eye(num_child)] * num_parent)
+
+ # this should allow edges from trinary nodes to binary nodes
+ pad = [0.] * (num_child - 2)
+
+ if is_tumor_spread:
+ # NOTE: Here we define how tumors spread to LNLs
+ tensor[0, 0, :] = np.array([1. - spread_prob, spread_prob, *pad])
+ return tensor
+
+ if is_growth:
+        # In the growth case, we can assume two things:
+ # 1. parent and child state are the same
+ # 2. the child node is trinary
+ tensor[1, 1, :] = np.array([0., (1 - spread_prob), spread_prob])
+ return tensor
+
+ if num_parent == 3:
+ # NOTE: here we define how the micro_mod affects the spread probability
+ micro_spread = spread_prob * micro_mod
+ tensor[1,0,:] = np.array([1. - micro_spread, micro_spread, *pad])
+
+ macro_spread = spread_prob
+ tensor[2,0,:] = np.array([1. - macro_spread, macro_spread, *pad])
+
+ return tensor
+
+ tensor[1,0,:] = np.array([1. - spread_prob, spread_prob, *pad])
+ return tensor
+
+
def check_modality(modality: str, spsn: list):
"""Private method that checks whether all inserted values
are valid for a confusion matrix.
@@ -441,3 +494,28 @@ def __set__(self, instance: object, value: Any) -> None:
def __delete__(self, instance: object) -> None:
dict_like = self.__get__(instance)
dict_like.clear()
+
+
+def arg0_cache(maxsize: int = 128, cache_class=LRUCache) -> Callable:
+ """Cache a function only based on its first argument.
+
+ One may choose which ``cache_class`` to use. This will be created with the
+ argument ``maxsize``.
+
+ Note:
+ The first argument is not passed on to the decorated function. It is basically
+ used as a key for the cache and it trusts the user to be sure that this is
+ sufficient.
+ """
+    def decorator(func: Callable) -> Callable:
+ cache = cache_class(maxsize=maxsize)
+
+ @wraps(func)
+ def wrapper(arg0, *args, **kwargs):
+ if arg0 not in cache:
+ cache[arg0] = func(*args, **kwargs)
+ return cache[arg0]
+
+ return wrapper
+
+ return decorator
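
A short usage sketch of the `arg0_cache` decorator added above (toy function; assumes `lymph.helper.arg0_cache` is importable as defined in this diff):

```python
from lymph.helper import arg0_cache


@arg0_cache(maxsize=8)
def expensive(x: int) -> int:
    print("computing ...")
    return x**2


print(expensive("key-1", 3))  # computing ... then 9
print(expensive("key-1", 4))  # cache hit on "key-1": returns 9, func not called
print(expensive("key-2", 4))  # computing ... then 16
```

The second call shows why the docstring warns that the first argument must uniquely determine the remaining ones.
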
diff --git a/lymph/matrix.py b/lymph/matrix.py
index 5f02e35..b09a4d7 100644
--- a/lymph/matrix.py
+++ b/lymph/matrix.py
@@ -5,13 +5,16 @@
from __future__ import annotations
import warnings
+from typing import Any
import numpy as np
import pandas as pd
+from cachetools import LRUCache
from lymph import models
from lymph.helper import (
AbstractLookupDict,
+ arg0_cache,
get_state_idx_matrix,
row_wise_kron,
tile_and_repeat,
@@ -20,11 +23,12 @@
def generate_transition(instance: models.Unilateral) -> np.ndarray:
"""Compute the transition matrix of the lymph model."""
- num_lnls = len(instance.graph.lnls)
+ lnls_list = list(instance.graph.lnls.values())
+ num_lnls = len(lnls_list)
num_states = 3 if instance.graph.is_trinary else 2
transition_matrix = np.ones(shape=(num_states**num_lnls, num_states**num_lnls))
- for i, lnl in enumerate(instance.graph.lnls.values()):
+ for i, lnl in enumerate(lnls_list):
current_state_idx = get_state_idx_matrix(
lnl_idx=i,
num_lnls=num_lnls,
@@ -43,7 +47,7 @@ def generate_transition(instance: models.Unilateral) -> np.ndarray:
0, current_state_idx, new_state_idx
]
else:
- parent_node_i = list(instance.graph.lnls.values()).index(edge.parent)
+ parent_node_i = lnls_list.index(edge.parent)
parent_state_idx = get_state_idx_matrix(
lnl_idx=parent_node_i,
num_lnls=num_lnls,
@@ -70,6 +74,15 @@ def generate_transition(instance: models.Unilateral) -> np.ndarray:
return transition_matrix
+cached_generate_transition = arg0_cache(maxsize=128, cache_class=LRUCache)(generate_transition)
+"""Cached version of :py:func:`generate_transition`.
+
+This expects the first argument to be a hashable object that is used instead of the
+``instance`` argument of :py:func:`generate_transition`. It is intended to be used with
+the :py:meth:`~lymph.graph.Representation.parameter_hash` method of the graph.
+"""
+
+
def generate_observation(instance: models.Unilateral) -> np.ndarray:
"""Generate the observation matrix of the lymph model."""
num_lnls = len(instance.graph.lnls)
@@ -87,6 +100,16 @@ def generate_observation(instance: models.Unilateral) -> np.ndarray:
return observation_matrix
+cached_generate_observation = arg0_cache(maxsize=128, cache_class=LRUCache)(generate_observation)
+"""Cached version of :py:func:`generate_observation`.
+
+This expects the first argument to be a hashable object that is used instead of the
+``instance`` argument of :py:func:`generate_observation`. It is intended to be used
+with the hash of all confusion matrices of the model's modalities, which is returned
+by the method :py:meth:`~lymph.modalities.ModalitiesUserDict.confusion_matrices_hash`.
+"""
+
+
def compute_encoding(
lnls: list[str],
pattern: pd.Series | dict[str, bool | int | str],
@@ -200,7 +223,7 @@ def generate_data_encoding(model: models.Unilateral, t_stage: str) -> np.ndarray
patients_with_t_stage = model.patient_data[has_t_stage]
result = np.ones(
- shape=(model.observation_matrix.shape[1], len(patients_with_t_stage)),
+ shape=(model.observation_matrix().shape[1], len(patients_with_t_stage)),
dtype=bool,
)
@@ -250,6 +273,27 @@ def __missing__(self, t_stage: str):
return self[t_stage]
+def generate_diagnose(model: models.Unilateral, t_stage: str) -> np.ndarray:
+ """Generate the diagnose matrix for a specific T-stage.
+
+ The diagnose matrix is the product of the observation matrix and the data matrix
+ for the given ``t_stage``.
+ """
+ return model.observation_matrix() @ model.data_matrices[t_stage]
+
+
+cached_generate_diagnose = arg0_cache(maxsize=128, cache_class=LRUCache)(generate_diagnose)
+"""Cached version of :py:func:`generate_diagnose`.
+
+The decorated function expects an additional first argument that should be unique for
+the combination of modalities and patient data. It is intended to be used with the
+joint hash of the modalities
+(:py:meth:`~lymph.modalities.ModalitiesUserDict.confusion_matrices_hash`) and the
+patient data hash that is always precomputed when a new dataset is loaded into the
+model (:py:meth:`~lymph.models.Unilateral.patient_data_hash`).
+"""
+
+
class DiagnoseUserDict(AbstractLookupDict):
"""``UserDict`` that dynamically generates the diagnose matrices for each T-stage.
@@ -267,9 +311,12 @@ class DiagnoseUserDict(AbstractLookupDict):
def __setitem__(self, __key, __value) -> None:
warnings.warn("Setting the diagnose matrices is not supported.")
- def __missing__(self, t_stage: str) -> np.ndarray:
- """If the matrix for a ``t_stage`` is missing, try to generate it lazily."""
- self.data[t_stage] = (
- self.model.observation_matrix @ self.model.data_matrices[t_stage]
- )
+ def __getitem__(self, key: Any) -> Any:
+ modalities_hash = self.model.modalities.confusion_matrices_hash()
+ patient_data_hash = self.model.patient_data_hash
+ joint_hash = hash((modalities_hash, patient_data_hash, key))
+ return cached_generate_diagnose(joint_hash, self.model, key)
+
+ def __missing__(self, t_stage: str):
+ """Create the diagnose matrix for a specific T-stage if necessary."""
return self[t_stage]
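
The ``__getitem__`` above folds several sources of change into one cache key. The idea in isolation (stand-in hashes, hypothetical values):

```python
import numpy as np

# stand-ins for the model's actual state
modalities_hash = hash(np.eye(2).tobytes())       # from the confusion matrices
patient_data_hash = hash(b"patient-table-bytes")  # from the loaded dataset
t_stage = "early"

# tuples hash element-wise, so the joint key changes whenever any part does
joint_hash = hash((modalities_hash, patient_data_hash, t_stage))
print(joint_hash)
```

A stale diagnose matrix thus simply becomes a cache miss; nothing needs to be deleted explicitly.
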
diff --git a/lymph/modalities.py b/lymph/modalities.py
index bd4b5e5..615cd13 100644
--- a/lymph/modalities.py
+++ b/lymph/modalities.py
@@ -110,32 +110,33 @@ def compute_confusion_matrix(self) -> np.ndarray:
ModalityDef = Union[Modality, np.ndarray, Tuple[float, float], List[float]]
class ModalitiesUserDict(AbstractLookupDict):
- """Dictionary storing instances of a diagnostic `Modality` for a lymph model.
+ """Dictionary storing instances of :py:class:`Modality` for a lymph model.
This class allows the user to specify the diagnostic modalities of a lymph model
- in a convenient way. The user may pass an instance of `Modality` - or one of its
- subclasses - directly. Especially for trinary models, it is recommended to use the
- subclasses `Clinical` and `Pathological` to avoid ambiguities.
+ in a convenient way. The user may pass an instance of :py:class:`Modality` - or one
+ of its subclasses - directly. Especially for trinary models, it is recommended to
+ use the subclasses :py:class:`Clinical` and :py:class:`Pathological` to avoid
+ ambiguities.
Alternatively, a simple tuple or list of floats may be passed, from which the first
two entries are interpreted as the specificity and sensitivity, respectively. For
- trinary models, we assume the modality to be `Clinical`.
+ trinary models, we assume the modality to be :py:class:`Clinical`.
For completely custom confusion matrices, the user may pass a numpy array directly.
- In the binary case, a valid `Modality` instance is constructed from the array. For
- trinary models, the array must have three rows, and is not possible anymore to
- infer the type of the modality or unambiguouse values for sensitivity and
+ In the binary case, a valid :py:class:`Modality` instance is constructed from the
+    array. For trinary models, the array must have three rows, and it is no longer
+    possible to infer the type of the modality or unambiguous values for sensitivity and
specificity. This may lead to unexpected results when the confusion matrix is
recomputed accidentally at some point.
Examples:
- >>> binary_modalities = ModalityDict(is_trinary=False)
+ >>> binary_modalities = ModalitiesUserDict(is_trinary=False)
>>> binary_modalities["test"] = Modality(0.9, 0.8)
>>> binary_modalities["test"].confusion_matrix
array([[0.9, 0.1],
[0.2, 0.8]])
- >>> modalities = ModalityDict(is_trinary=True)
+ >>> modalities = ModalitiesUserDict(is_trinary=True)
>>> modalities["CT"] = Clinical(specificity=0.9, sensitivity=0.8)
>>> modalities["CT"].confusion_matrix
array([[0.9, 0.1],
@@ -211,3 +212,23 @@ def __setitem__(self, name: str, value: ModalityDef, / ) -> None:
@trigger
def __delitem__(self, key: str) -> None:
return super().__delitem__(key)
+
+
+ def confusion_matrices_hash(self) -> int:
+        """Compute a combined hash of all confusion matrices.
+
+ Note:
+ This is used to check if some modalities have changed and the observation
+ matrix needs to be recomputed. It should not be used as a replacement for
+ the ``__hash__`` method, for two reasons:
+
+ 1. It may change over the lifetime of the object, whereas ``__hash__``
+ should be constant.
+            2. It only takes into account the ``confusion_matrix`` of the modality,
+ nothing else.
+ """
+ confusion_mat_bytes = b""
+ for modality in self.values():
+ confusion_mat_bytes += modality.confusion_matrix.tobytes()
+
+ return hash(confusion_mat_bytes)
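
The bytes-based hashing above relies on `np.ndarray.tobytes()` being deterministic for equal array contents. A tiny standalone check (toy arrays, not the model's actual confusion matrices):

```python
import numpy as np

a = np.array([[0.9, 0.1], [0.2, 0.8]])
b = a.copy()

# equal contents -> equal bytes -> equal hash, even for distinct objects
assert hash(a.tobytes()) == hash(b.tobytes())

# a changed parameter yields different bytes and (practically always) a
# different hash, which causes the desired cache miss downstream
b[0, 0] = 0.95
assert hash(a.tobytes()) != hash(b.tobytes())
```
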
diff --git a/lymph/models/bilateral.py b/lymph/models/bilateral.py
index bb49cf6..0a1fa7a 100644
--- a/lymph/models/bilateral.py
+++ b/lymph/models/bilateral.py
@@ -146,15 +146,18 @@ def __init__(
The ``is_symmetric`` dictionary defines which characteristics of the bilateral
model should be symmetric. Valid keys are:
- - ``"modalities"``: Whether the diagnostic modalities of the two neck sides
- are symmetric (default: ``True``).
- - ``"tumor_spread"``: Whether the spread probabilities from the tumor(s) to the
- LNLs are symmetric (default: ``False``). If this is set to ``True`` but
- the graphs are asymmetric, a warning is issued.
- - ``"lnl_spread"``: Whether the spread probabilities between the LNLs are
- symmetric (default: ``True`` if the graphs are symmetric, otherwise
- ``False``). If this is set to ``True`` but the graphs are asymmetric, a
- warning is issued.
+
+ - ``"modalities"``:
+ Whether the diagnostic modalities of the two neck sides are symmetric
+ (default: ``True``).
+ - ``"tumor_spread"``:
+ Whether the spread probabilities from the tumor(s) to the LNLs are
+ symmetric (default: ``False``). If this is set to ``True`` but the graphs
+ are asymmetric, a warning is issued.
+ - ``"lnl_spread"``:
+ Whether the spread probabilities between the LNLs are symmetric
+ (default: ``True`` if the graphs are symmetric, otherwise ``False``). If
+ this is set to ``True`` but the graphs are asymmetric, a warning is issued.
The ``unilateral_kwargs`` are passed to both instances of the unilateral model,
while the ``ipsilateral_kwargs`` and ``contralateral_kwargs`` are passed to the
@@ -234,14 +237,14 @@ def init_synchronization(self) -> None:
ipsi_tumor_edges = list(self.ipsi.graph.tumor_edges.values())
ipsi_lnl_edges = list(self.ipsi.graph.lnl_edges.values())
ipsi_edges = (
- ipsi_tumor_edges if self.is_symmetric["tumor_spread"] else []
- + ipsi_lnl_edges if self.is_symmetric["lnl_spread"] else []
+ (ipsi_tumor_edges if self.is_symmetric["tumor_spread"] else [])
+ + (ipsi_lnl_edges if self.is_symmetric["lnl_spread"] else [])
)
contra_tumor_edges = list(self.contra.graph.tumor_edges.values())
contra_lnl_edges = list(self.contra.graph.lnl_edges.values())
contra_edges = (
- contra_tumor_edges if self.is_symmetric["tumor_spread"] else []
- + contra_lnl_edges if self.is_symmetric["lnl_spread"] else []
+ (contra_tumor_edges if self.is_symmetric["tumor_spread"] else [])
+ + (contra_lnl_edges if self.is_symmetric["lnl_spread"] else [])
)
init_edge_sync(
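
The bug fixed in this hunk is an operator-precedence pitfall: `+` binds tighter than the conditional expression, so without parentheses the second list was swallowed by the first `else` branch. A minimal demonstration (toy lists):

```python
a, b = [1], [2]
sym_tumor = sym_lnl = True

# parses as: a if sym_tumor else (([] + b) if sym_lnl else [])
broken = a if sym_tumor else [] + b if sym_lnl else []

# the parenthesized version concatenates both conditional results
fixed = (a if sym_tumor else []) + (b if sym_lnl else [])

print(broken)  # [1]    -> b silently dropped
print(fixed)   # [1, 2] -> intended behavior
```
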
@@ -384,7 +387,7 @@ def modalities(self) -> modalities.ModalitiesUserDict:
See Also:
:py:attr:`lymph.models.Unilateral.modalities`
The corresponding unilateral attribute.
- :py:class:`~lymph.descriptors.ModalitiesUserDict`
+ :py:class:`~lymph.modalities.ModalitiesUserDict`
The implementation of the descriptor class.
"""
if not self.is_symmetric["modalities"]:
@@ -471,9 +474,9 @@ def comp_joint_obs_dist(
"""
joint_state_dist = self.comp_joint_state_dist(t_stage=t_stage, mode=mode)
return (
- self.ipsi.observation_matrix.T
+ self.ipsi.observation_matrix().T
@ joint_state_dist
- @ self.contra.observation_matrix
+ @ self.contra.observation_matrix()
)
@@ -626,7 +629,7 @@ def comp_posterior_joint_state_dist(
diagnose_encoding = getattr(self, side).comp_diagnose_encoding(
given_diagnoses.get(side, {})
)
- observation_matrix = getattr(self, side).observation_matrix
+ observation_matrix = getattr(self, side).observation_matrix()
# vector with P(Z=z|X) for each state X. A data matrix for one "patient"
diagnose_given_state[side] = diagnose_encoding @ observation_matrix.T
@@ -696,36 +699,57 @@ def risk(
)
- def generate_dataset(
+ def draw_patients(
self,
- num_patients: int,
- stage_dist: dict[str, float],
+ num: int,
+ stage_dist: Iterable[float],
+ rng: np.random.Generator | None = None,
+ seed: int = 42,
+ **_kwargs,
) -> pd.DataFrame:
- """Generate/sample a pandas :class:`DataFrame` from the defined network.
+ """Draw ``num`` random patients from the parametrized model.
- Args:
- num_patients: Number of patients to generate.
- stage_dist: Probability to find a patient in a certain T-stage.
+ See Also:
+ :py:meth:`lymph.diagnose_times.Distribution.draw_diag_times`
+ Method to draw diagnose times from a distribution.
+ :py:meth:`lymph.models.Unilateral.draw_diagnoses`
+ Method to draw individual diagnoses from a unilateral model.
+ :py:meth:`lymph.models.Unilateral.draw_patients`
+ The unilateral method to draw a synthetic dataset.
"""
- # TODO: check if this still works
- drawn_t_stages, drawn_diag_times = self.diag_time_dists.draw(
- dist=stage_dist, size=num_patients
+ if rng is None:
+ rng = np.random.default_rng(seed)
+
+        if not np.isclose(sum(stage_dist), 1.):
+ warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
+ stage_dist = np.array(stage_dist) / sum(stage_dist)
+
+ drawn_t_stages = rng.choice(
+ a=list(self.diag_time_dists.keys()),
+ p=stage_dist,
+ size=num,
)
+ drawn_diag_times = [
+ self.diag_time_dists[t_stage].draw_diag_times(rng=rng)
+ for t_stage in drawn_t_stages
+ ]
- drawn_obs_ipsi = self.ipsi._draw_patient_diagnoses(drawn_diag_times)
- drawn_obs_contra = self.contra._draw_patient_diagnoses(drawn_diag_times)
+ drawn_obs_ipsi = self.ipsi.draw_diagnoses(drawn_diag_times, rng=rng)
+ drawn_obs_contra = self.contra.draw_diagnoses(drawn_diag_times, rng=rng)
drawn_obs = np.concatenate([drawn_obs_ipsi, drawn_obs_contra], axis=1)
- # construct MultiIndex for dataset from stored modalities
+ # construct MultiIndex with "ipsi" and "contra" at top level to allow
+ # concatenation of the two separate drawn diagnoses
sides = ["ipsi", "contra"]
- modalities = list(self.modalities.keys())
- lnl_names = [lnl.name for lnl in self.ipsi.graph._lnls]
- multi_cols = pd.MultiIndex.from_product([sides, modalities, lnl_names])
+ modality_names = list(self.modalities.keys())
+        lnl_names = list(self.ipsi.graph.lnls.keys())
+ multi_cols = pd.MultiIndex.from_product([sides, modality_names, lnl_names])
- # create DataFrame
+ # reorder the column levels and thus also the individual columns to match the
+ # LyProX format without mixing up the data
dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
dataset = dataset.reorder_levels(order=[1, 0, 2], axis="columns")
dataset = dataset.sort_index(axis="columns", level=0)
- dataset[('info', 'tumor', 't_stage')] = drawn_t_stages
+ dataset[('tumor', '1', 't_stage')] = drawn_t_stages
return dataset
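
The stage-distribution handling in `draw_patients()` warns and renormalizes before sampling T-stages. Standalone, with toy numbers:

```python
import warnings

import numpy as np

stage_dist = [0.5, 0.6]  # sums to 1.1, not 1

if not np.isclose(sum(stage_dist), 1.0):
    warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
    stage_dist = np.array(stage_dist) / sum(stage_dist)

rng = np.random.default_rng(42)
print(rng.choice(a=["early", "late"], p=stage_dist, size=5))
```
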
diff --git a/lymph/models/unilateral.py b/lymph/models/unilateral.py
index e5ea940..93797b4 100644
--- a/lymph/models/unilateral.py
+++ b/lymph/models/unilateral.py
@@ -95,7 +95,6 @@ def __init__(
graph_dict=graph_dict,
tumor_state=tumor_state,
allowed_states=allowed_states,
- on_edge_change=[self.delete_transition_matrix],
)
if 0 >= max_time:
@@ -383,16 +382,6 @@ def comp_diagnose_prob(
return prob
- def _gen_obs_list(self):
- """Generates the list of possible observations."""
- possible_obs_list = []
- for modality in self.modalities.values():
- possible_obs = np.arange(modality.confusion_matrix.shape[1])
- for _ in self.graph.lnls:
- possible_obs_list.append(possible_obs.copy())
-
- self._obs_list = np.array(list(product(*possible_obs_list)))
-
@property
def obs_list(self):
"""Return the list of all possible observations.
@@ -424,20 +413,15 @@ def obs_list(self):
modality CT, the second two columns correspond to the same LNLs under the
pathology modality.
"""
- try:
- return self._obs_list
- except AttributeError:
- self._gen_obs_list()
- return self._obs_list
+ possible_obs_list = []
+ for modality in self.modalities.values():
+ possible_obs = np.arange(modality.confusion_matrix.shape[1])
+ for _ in self.graph.lnls:
+ possible_obs_list.append(possible_obs.copy())
- @obs_list.deleter
- def obs_list(self):
- """Delete the observation list. Necessary to pass as callback."""
- if hasattr(self, "_obs_list"):
- del self._obs_list
+ return np.array(list(product(*possible_obs_list)))
- @cached_property
def transition_matrix(self) -> np.ndarray:
"""Matrix encoding the probabilities to transition from one state to another.
@@ -448,9 +432,6 @@ def transition_matrix(self) -> np.ndarray:
transition from the :math:`i`-th state to the :math:`j`-th state. The states
are ordered as in the :py:attr:`lymph.graph.state_list`.
- This matrix is deleted every time the parameters along the edges of the graph
- are changed. It is lazily computed when it is next accessed.
-
See Also:
:py:func:`~lymph.descriptors.matrix.generate_transition`
The function actually computing the transition matrix.
@@ -462,21 +443,15 @@ def transition_matrix(self) -> np.ndarray:
... ("lnl", "II"): ["III"],
... ("lnl", "III"): [],
... })
- >>> model.assign_params(0.7, 0.3, 0.2)
- >>> model.transition_matrix
+ >>> model.assign_params(0.7, 0.3, 0.2) # doctest: +ELLIPSIS
+ (..., {})
+ >>> model.transition_matrix()
array([[0.21, 0.09, 0.49, 0.21],
- [0. , 0.3 , 0. , 0.7 ],
- [0. , 0. , 0.56, 0.44],
- [0. , 0. , 0. , 1. ]])
+ [0. , 0.3 , 0. , 0.7 ],
+ [0. , 0. , 0.56, 0.44],
+ [0. , 0. , 0. , 1. ]])
"""
- return matrix.generate_transition(self)
-
- def delete_transition_matrix(self):
- """Delete the transition matrix. Necessary to pass as callback."""
- try:
- del self.transition_matrix
- except AttributeError:
- pass
+ return matrix.cached_generate_transition(self.graph.parameter_hash(), self)
@smart_updating_dict_cached_property
@@ -497,13 +472,9 @@ def modalities(self) -> modalities.ModalitiesUserDict:
:py:class:`~lymph.descriptors.modalities.ModalitiesUserDict`
:py:class:`~lymph.descriptors.modalities.Modality`
"""
- return modalities.ModalitiesUserDict(
- is_trinary=self.is_trinary,
- trigger_callbacks=[self.delete_obs_list_and_matrix],
- )
+ return modalities.ModalitiesUserDict(is_trinary=self.is_trinary)
- @cached_property
def observation_matrix(self) -> np.ndarray:
"""The matrix encoding the probabilities to observe a certain diagnosis.
@@ -517,16 +488,9 @@ def observation_matrix(self) -> np.ndarray:
:py:func:`~lymph.descriptors.matrix.generate_observation`
The function actually computing the observation matrix.
"""
- return matrix.generate_observation(self)
-
- def delete_obs_list_and_matrix(self):
- """Delete the observation matrix. Necessary to pass as callback."""
- try:
- del self.observation_matrix
- except AttributeError:
- pass
-
- del self.obs_list
+ return matrix.cached_generate_observation(
+ self.modalities.confusion_matrices_hash(), self
+ )
@smart_updating_dict_cached_property
@@ -621,7 +585,7 @@ def load_patient_data(
if side not in patient_data[modality_name]:
raise ValueError(f"{side}lateral involvement data not found.")
- for name in self.graph.lnls:
+ for name in self.graph.lnls.keys():
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
modality_side_data = patient_data[modality_name, side]
@@ -639,12 +603,27 @@ def load_patient_data(
if t_stage not in patient_data["_model", "#", "t_stage"].values:
warnings.warn(f"No data for T-stage {t_stage} found.")
+ self._patient_data = patient_data
# Changes to the patient data require a recomputation of the data and
- # diagnose matrices. Clearing them will trigger this when they are next
- # accessed.
+ # diagnose matrices. For the data matrix, it is enough to clear the respective
+ # ``UserDict``. For the diagnose matrices, we need to delete the hash value of
+ # the patient data, so that the next time it is requested, a cache miss occurs
+ # and they are recomputed.
self.data_matrices.clear()
- self.diagnose_matrices.clear()
- self._patient_data = patient_data
+ try:
+ del self.patient_data_hash
+ except AttributeError:
+ pass
+
+
+ @cached_property
+ def patient_data_hash(self) -> int:
+ """Hash of the patient data.
+
+ This is used to check if the patient data has changed since the last time
+ the data and diagnose matrices were computed. If so, they are recomputed.
+ """
+ return hash(self.patient_data.to_numpy().tobytes())
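The `try`/`del` idiom in `load_patient_data` relies on `functools.cached_property` storing its value in the instance's `__dict__`: deleting the attribute forces recomputation on the next access, and the `except AttributeError` guard covers the case where it was never accessed. A toy illustration (class and names invented for the example):

```python
from functools import cached_property

class PatientTable:
    def __init__(self, rows):
        self.rows = rows

    @cached_property
    def table_hash(self) -> int:
        return hash(tuple(self.rows))

data = PatientTable([1, 2, 3])
before = data.table_hash    # computed once and stored on the instance
data.rows = [4, 5, 6]
try:
    del data.table_hash     # force a cache miss on the next access
except AttributeError:
    pass                    # never accessed yet, nothing to delete
after = data.table_hash     # recomputed from the new rows
assert before != after
```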
@property
@@ -675,7 +654,7 @@ def evolve_dist(self, state_dist: np.ndarray, num_steps: int) -> np.ndarray:
is the number of steps ``num_steps``.
"""
for _ in range(num_steps):
- state_dist = state_dist @ self.transition_matrix
+ state_dist = state_dist @ self.transition_matrix()
return state_dist
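With the transition matrix from the doctest above, the evolution loop is plain repeated vector-matrix multiplication; each step shifts probability mass toward the absorbing all-involved state:

```python
import numpy as np

transition = np.array([[0.21, 0.09, 0.49, 0.21],
                       [0.  , 0.3 , 0.  , 0.7 ],
                       [0.  , 0.  , 0.56, 0.44],
                       [0.  , 0.  , 0.  , 1.  ]])

state_dist = np.array([1., 0., 0., 0.])  # start in the fully healthy state
for _ in range(3):
    state_dist = state_dist @ transition

print(state_dist)  # still sums to 1; the last entry has grown each step
```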
@@ -742,7 +721,7 @@ def comp_obs_dist(self, t_stage: str = "early", mode: str = "HMM") -> np.ndarray
the :py:attr:`~data_matrices` and use these to compute the likelihood.
"""
state_dist = self.comp_state_dist(t_stage=t_stage, mode=mode)
- return state_dist @ self.observation_matrix
+ return state_dist @ self.observation_matrix()
def _bn_likelihood(self, log: bool = True) -> float:
@@ -887,7 +866,7 @@ def comp_posterior_state_dist(
diagnose_encoding = self.comp_diagnose_encoding(given_diagnoses)
# vector containing P(Z=z|X). Essentially a data matrix for one patient
- diagnose_given_state = diagnose_encoding @ self.observation_matrix.T
+ diagnose_given_state = diagnose_encoding @ self.observation_matrix().T
# vector P(X=x) of probabilities of arriving in state x (marginalized over time)
state_dist = self.comp_state_dist(t_stage, mode=mode)
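These two lines are the likelihood part of Bayes' rule. Reduced to a single binary LNL with one modality (all numbers invented for illustration), the full posterior computation looks like this:

```python
import numpy as np

# P(Z | X): rows are hidden states, columns are possible observations
observation_matrix = np.array([[0.9, 0.1],   # healthy: mostly negative tests
                               [0.2, 0.8]])  # involved: mostly positive tests
state_dist = np.array([0.7, 0.3])            # prior P(X)

diagnose_encoding = np.array([0., 1.])       # we observed a positive test
diagnose_given_state = diagnose_encoding @ observation_matrix.T  # P(Z=z | X)

joint = diagnose_given_state * state_dist    # P(Z=z, X)
posterior = joint / joint.sum()              # P(X | Z=z)
print(posterior)                             # [0.2258... 0.7741...]
```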
@@ -945,58 +924,77 @@ def risk(
return marginalize_over_states @ posterior_state_dist
- def _draw_patient_diagnoses(
+ def draw_diagnoses(
self,
diag_times: list[int],
+ rng: np.random.Generator | None = None,
+ seed: int = 42,
) -> np.ndarray:
- """Draw random possible observations for a list of T-stages and
- diagnose times.
+ """Given some ``diag_times``, draw diagnoses for each LNL."""
+ if rng is None:
+ rng = np.random.default_rng(seed)
- Args:
- diag_times: List of diagnose times for each patient who's diagnose
- is supposed to be drawn.
- """
- # use the drawn diagnose times to compute probabilities over states and
- # diagnoses
- per_time_state_probs = self.comp_dist_evolution()
- per_patient_state_probs = per_time_state_probs[diag_times]
- per_patient_obs_probs = per_patient_state_probs @ self.observation_matrix
-
- # then, draw a diagnose from the possible ones
- obs_idx = np.arange(len(self.obs_list))
+ state_probs_given_time = self.comp_dist_evolution()[diag_times]
+ obs_probs_given_time = state_probs_given_time @ self.observation_matrix()
+
+ obs_indices = np.arange(len(self.obs_list))
drawn_obs_idx = [
- np.random.choice(obs_idx, p=obs_prob)
- for obs_prob in per_patient_obs_probs
+            rng.choice(obs_indices, p=obs_prob)
+ for obs_prob in obs_probs_given_time
]
+
return self.obs_list[drawn_obs_idx].astype(bool)
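`draw_diagnoses` thus samples one index per patient from that patient's distribution over possible observations and looks the rows up in `obs_list`. The same steps standalone, with a toy observation space:

```python
import numpy as np

rng = np.random.default_rng(42)

# four possible diagnoses over two LNLs, one probability row per patient
obs_list = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
obs_probs_given_time = np.array([[0.5, 0.2, 0.2, 0.1],
                                 [0.1, 0.2, 0.2, 0.5]])

obs_indices = np.arange(len(obs_list))
drawn_obs_idx = [rng.choice(obs_indices, p=obs_prob)
                 for obs_prob in obs_probs_given_time]
drawn_obs = obs_list[drawn_obs_idx].astype(bool)
```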
- def generate_dataset(
+ def draw_patients(
self,
- num_patients: int,
- stage_dist: dict[str, float],
+ num: int,
+ stage_dist: Iterable[float],
+ rng: np.random.Generator | None = None,
+ seed: int = 42,
**_kwargs,
) -> pd.DataFrame:
- """Generate/sample a pandas :class:`DataFrame` from the defined network
- using the samples and diagnostic modalities that have been set.
+ """Draw ``num`` random patients from the model.
- Args:
- num_patients: Number of patients to generate.
- stage_dist: Probability to find a patient in a certain T-stage.
+ For this, a ``stage_dist``, i.e., a distribution over the T-stages, needs to
+ be defined. This must be an iterable of probabilities with as many elements as
+ there are defined T-stages in the model's :py:attr:`diag_time_dists` attribute.
+
+ A random number generator can be provided as ``rng``. If ``None``, a new one
+ is initialized with the given ``seed`` (or ``42``, by default).
+
+ See Also:
+ :py:meth:`lymph.diagnose_times.Distribution.draw_diag_times`
+ Method to draw diagnose times from a distribution.
+ :py:meth:`lymph.models.Unilateral.draw_diagnoses`
+ Method to draw individual diagnoses.
+ :py:meth:`lymph.models.Bilateral.draw_patients`
+ The corresponding bilateral method.
"""
- drawn_t_stages, drawn_diag_times = self.diag_time_dists.draw(
- prob_of_t_stage=stage_dist, size=num_patients
+ if rng is None:
+ rng = np.random.default_rng(seed)
+
+        if not np.isclose(sum(stage_dist), 1.):
+ warnings.warn("Sum of stage distribution is not 1. Renormalizing.")
+ stage_dist = np.array(stage_dist) / sum(stage_dist)
+
+ drawn_t_stages = rng.choice(
+ a=list(self.diag_time_dists.keys()),
+ p=stage_dist,
+ size=num,
)
+ drawn_diag_times = [
+ self.diag_time_dists[t_stage].draw_diag_times(rng=rng)
+ for t_stage in drawn_t_stages
+ ]
- drawn_obs = self._draw_patient_diagnoses(drawn_diag_times)
+ drawn_obs = self.draw_diagnoses(drawn_diag_times, rng=rng)
- # construct MultiIndex for dataset from stored modalities
modality_names = list(self.modalities.keys())
- lnl_names = self.graph.lnls.keys()
- multi_cols = pd.MultiIndex.from_product([modality_names, lnl_names])
+ lnl_names = list(self.graph.lnls.keys())
+ multi_cols = pd.MultiIndex.from_product([modality_names, ["ipsi"], lnl_names])
- # create DataFrame
dataset = pd.DataFrame(drawn_obs, columns=multi_cols)
- dataset[('info', 't_stage')] = drawn_t_stages
+ dataset[("tumor", "1", "t_stage")] = drawn_t_stages
return dataset
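The two-step draw that `draw_patients` performs — first a T-stage from `stage_dist`, then a diagnose time from that stage's distribution — can be sketched standalone. The binomial time prior here is an assumption for illustration, not the model's fixed choice:

```python
import numpy as np

rng = np.random.default_rng(42)

t_stages = ["early", "late"]
stage_dist = np.array([0.6, 0.4])

drawn_t_stages = rng.choice(a=t_stages, p=stage_dist, size=200)

# one diagnose time per patient, drawn from that stage's time prior
drawn_diag_times = [
    rng.binomial(n=10, p=0.3 if t_stage == "early" else 0.5)
    for t_stage in drawn_t_stages
]
```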
diff --git a/notebook/data/2009_sanguineti.csv b/notebook/data/2009_sanguineti.csv
deleted file mode 100644
index 6bf59d8..0000000
--- a/notebook/data/2009_sanguineti.csv
+++ /dev/null
@@ -1,147 +0,0 @@
-1,0,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,1,0,0
-0,0,1,0
-0,0,1,0
-0,0,1,0
-0,0,1,0
-0,0,0,1
-1,1,0,0
-1,0,0,1
-1,0,0,1
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,1,0
-0,1,0,1
-0,1,0,1
-0,1,0,1
-0,0,1,1
-1,1,1,0
-1,1,1,0
-1,1,1,0
-1,1,1,0
-1,1,1,0
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,1,1,1
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
-0,0,0,0
diff --git a/notebook/data/cross-validation_set1.csv b/notebook/data/cross-validation_set1.csv
deleted file mode 100644
index 8ae2a2f..0000000
--- a/notebook/data/cross-validation_set1.csv
+++ /dev/null
@@ -1,51 +0,0 @@
-,info,path,path,path,path
-,t_stage,I,II,III,IV
-95,early,0,1,1,1
-125,early,0,0,0,0
-59,early,1,1,0,0
-76,early,0,1,1,0
-57,early,0,0,1,0
-118,early,0,0,0,0
-14,early,0,1,0,0
-47,early,0,1,0,0
-31,early,0,1,0,0
-94,early,0,1,1,1
-6,early,0,1,0,0
-23,early,0,1,0,0
-41,early,0,1,0,0
-102,early,0,1,1,1
-81,early,0,1,1,0
-2,early,0,1,0,0
-89,early,1,1,1,0
-21,early,0,1,0,0
-82,early,0,1,1,0
-65,early,0,1,1,0
-97,early,0,1,1,1
-74,early,0,1,1,0
-10,early,0,1,0,0
-137,early,0,0,0,0
-35,early,0,1,0,0
-20,early,0,1,0,0
-40,early,0,1,0,0
-12,early,0,1,0,0
-73,early,0,1,1,0
-88,early,1,1,1,0
-91,early,1,1,1,0
-108,early,0,0,0,0
-8,early,0,1,0,0
-109,early,0,0,0,0
-133,early,0,0,0,0
-115,early,0,0,0,0
-62,early,0,1,1,0
-67,early,0,1,1,0
-139,early,0,0,0,0
-15,early,0,1,0,0
-143,early,0,0,0,0
-18,early,0,1,0,0
-70,early,0,1,1,0
-134,early,0,0,0,0
-3,early,0,1,0,0
-45,early,0,1,0,0
-24,early,0,1,0,0
-53,early,0,1,0,0
-26,early,0,1,0,0
diff --git a/notebook/data/cross-validation_set2.csv b/notebook/data/cross-validation_set2.csv
deleted file mode 100644
index 85347d3..0000000
--- a/notebook/data/cross-validation_set2.csv
+++ /dev/null
@@ -1,51 +0,0 @@
-,info,path,path,path,path
-,t_stage,I,II,III,IV
-122,early,0,0,0,0
-135,early,0,0,0,0
-83,early,0,1,0,1
-116,early,0,0,0,0
-63,early,0,1,1,0
-0,early,1,0,0,0
-123,early,0,0,0,0
-106,early,0,0,0,0
-99,early,0,1,1,1
-39,early,0,1,0,0
-85,early,0,1,0,1
-140,early,0,0,0,0
-64,early,0,1,1,0
-136,early,0,0,0,0
-66,early,0,1,1,0
-90,early,1,1,1,0
-60,early,1,0,0,1
-34,early,0,1,0,0
-113,early,0,0,0,0
-86,early,0,0,1,1
-51,early,0,1,0,0
-145,early,0,0,0,0
-11,early,0,1,0,0
-101,early,0,1,1,1
-50,early,0,1,0,0
-93,early,0,1,1,1
-30,early,0,1,0,0
-1,early,0,1,0,0
-32,early,0,1,0,0
-49,early,0,1,0,0
-61,early,1,0,0,1
-17,early,0,1,0,0
-84,early,0,1,0,1
-13,early,0,1,0,0
-138,early,0,0,0,0
-127,early,0,0,0,0
-48,early,0,1,0,0
-44,early,0,1,0,0
-98,early,0,1,1,1
-27,early,0,1,0,0
-54,early,0,0,1,0
-141,early,0,0,0,0
-37,early,0,1,0,0
-25,early,0,1,0,0
-87,early,1,1,1,0
-5,early,0,1,0,0
-80,early,0,1,1,0
-130,early,0,0,0,0
-56,early,0,0,1,0
diff --git a/notebook/data/cross-validation_set3.csv b/notebook/data/cross-validation_set3.csv
deleted file mode 100644
index 4e6fa65..0000000
--- a/notebook/data/cross-validation_set3.csv
+++ /dev/null
@@ -1,51 +0,0 @@
-,info,path,path,path,path
-,t_stage,I,II,III,IV
-4,early,0,1,0,0
-7,early,0,1,0,0
-9,early,0,1,0,0
-16,early,0,1,0,0
-19,early,0,1,0,0
-22,early,0,1,0,0
-28,early,0,1,0,0
-29,early,0,1,0,0
-33,early,0,1,0,0
-36,early,0,1,0,0
-38,early,0,1,0,0
-42,early,0,1,0,0
-43,early,0,1,0,0
-46,early,0,1,0,0
-52,early,0,1,0,0
-55,early,0,0,1,0
-58,early,0,0,0,1
-68,early,0,1,1,0
-69,early,0,1,1,0
-71,early,0,1,1,0
-72,early,0,1,1,0
-75,early,0,1,1,0
-77,early,0,1,1,0
-78,early,0,1,1,0
-79,early,0,1,1,0
-92,early,0,1,1,1
-96,early,0,1,1,1
-100,early,0,1,1,1
-103,early,0,0,0,0
-104,early,0,0,0,0
-105,early,0,0,0,0
-107,early,0,0,0,0
-110,early,0,0,0,0
-111,early,0,0,0,0
-112,early,0,0,0,0
-114,early,0,0,0,0
-117,early,0,0,0,0
-119,early,0,0,0,0
-120,early,0,0,0,0
-121,early,0,0,0,0
-124,early,0,0,0,0
-126,early,0,0,0,0
-128,early,0,0,0,0
-129,early,0,0,0,0
-131,early,0,0,0,0
-132,early,0,0,0,0
-142,early,0,0,0,0
-144,early,0,0,0,0
-146,early,0,0,0,0
diff --git a/notebook/figures/HMM_BN_risk_comparison.png b/notebook/figures/HMM_BN_risk_comparison.png
deleted file mode 100644
index df09a84..0000000
Binary files a/notebook/figures/HMM_BN_risk_comparison.png and /dev/null differ
diff --git a/notebook/figures/HMM_BN_risk_comparison.svg b/notebook/figures/HMM_BN_risk_comparison.svg
deleted file mode 100644
index bff16f3..0000000
--- a/notebook/figures/HMM_BN_risk_comparison.svg
+++ /dev/null
@@ -1,10432 +0,0 @@
diff --git a/notebook/figures/HMM_evo_matrix.png b/notebook/figures/HMM_evo_matrix.png
deleted file mode 100644
index 7188555..0000000
Binary files a/notebook/figures/HMM_evo_matrix.png and /dev/null differ
diff --git a/notebook/figures/HMM_evo_matrix.svg b/notebook/figures/HMM_evo_matrix.svg
deleted file mode 100644
index f28061d..0000000
--- a/notebook/figures/HMM_evo_matrix.svg
+++ /dev/null
@@ -1,3783 +0,0 @@
diff --git a/notebook/figures/HMM_evolution.png b/notebook/figures/HMM_evolution.png
deleted file mode 100644
index e4c3bb9..0000000
Binary files a/notebook/figures/HMM_evolution.png and /dev/null differ
diff --git a/notebook/figures/HMM_evolution.svg b/notebook/figures/HMM_evolution.svg
deleted file mode 100644
index faec6bb..0000000
--- a/notebook/figures/HMM_evolution.svg
+++ /dev/null
@@ -1,3975 +0,0 @@
diff --git a/notebook/figures/HMM_risk_increaseP.png b/notebook/figures/HMM_risk_increaseP.png
deleted file mode 100644
index 6d68223..0000000
Binary files a/notebook/figures/HMM_risk_increaseP.png and /dev/null differ
diff --git a/notebook/figures/HMM_risk_increaseP.svg b/notebook/figures/HMM_risk_increaseP.svg
deleted file mode 100644
index 7ba78f1..0000000
--- a/notebook/figures/HMM_risk_increaseP.svg
+++ /dev/null
@@ -1,2505 +0,0 @@
diff --git a/notebook/figures/corner_BN.png b/notebook/figures/corner_BN.png
deleted file mode 100644
index 21315d9..0000000
Binary files a/notebook/figures/corner_BN.png and /dev/null differ
diff --git a/notebook/figures/corner_BN.svg b/notebook/figures/corner_BN.svg
deleted file mode 100644
index 44c07b5..0000000
--- a/notebook/figures/corner_BN.svg
+++ /dev/null
@@ -1,65852 +0,0 @@
diff --git a/notebook/figures/corner_HMM.png b/notebook/figures/corner_HMM.png
deleted file mode 100644
index b5db758..0000000
Binary files a/notebook/figures/corner_HMM.png and /dev/null differ
diff --git a/notebook/figures/corner_HMM.svg b/notebook/figures/corner_HMM.svg
deleted file mode 100644
index 332864e..0000000
--- a/notebook/figures/corner_HMM.svg
+++ /dev/null
@@ -1,65993 +0,0 @@
diff --git a/notebook/figures/corner_simultaneous.png b/notebook/figures/corner_simultaneous.png
deleted file mode 100644
index 8a1e79f..0000000
Binary files a/notebook/figures/corner_simultaneous.png and /dev/null differ
diff --git a/notebook/figures/corner_simultaneous.svg b/notebook/figures/corner_simultaneous.svg
deleted file mode 100644
index 0a57db7..0000000
--- a/notebook/figures/corner_simultaneous.svg
+++ /dev/null
@@ -1,87433 +0,0 @@
diff --git a/notebook/figures/multi_length_risk.png b/notebook/figures/multi_length_risk.png
deleted file mode 100644
index 74ae167..0000000
Binary files a/notebook/figures/multi_length_risk.png and /dev/null differ
diff --git a/notebook/figures/multi_length_risk.svg b/notebook/figures/multi_length_risk.svg
deleted file mode 100644
index 0cf89b6..0000000
--- a/notebook/figures/multi_length_risk.svg
+++ /dev/null
@@ -1,13973 +0,0 @@
diff --git a/notebook/figures/rate_decay_theory_vs_sampled.png b/notebook/figures/rate_decay_theory_vs_sampled.png
deleted file mode 100644
index d354dbb..0000000
Binary files a/notebook/figures/rate_decay_theory_vs_sampled.png and /dev/null differ
diff --git a/notebook/figures/rate_decay_theory_vs_sampled.svg b/notebook/figures/rate_decay_theory_vs_sampled.svg
deleted file mode 100644
index e32d74b..0000000
--- a/notebook/figures/rate_decay_theory_vs_sampled.svg
+++ /dev/null
@@ -1,3074 +0,0 @@
diff --git a/notebook/figures/simple_rate_decay.png b/notebook/figures/simple_rate_decay.png
deleted file mode 100644
index a0531ff..0000000
Binary files a/notebook/figures/simple_rate_decay.png and /dev/null differ
diff --git a/notebook/figures/simple_rate_decay.svg b/notebook/figures/simple_rate_decay.svg
deleted file mode 100644
index 2e9ca86..0000000
--- a/notebook/figures/simple_rate_decay.svg
+++ /dev/null
@@ -1,1644 +0,0 @@
diff --git a/notebook/figures/simultaneous_learnedP.png b/notebook/figures/simultaneous_learnedP.png
deleted file mode 100644
index 5011ada..0000000
Binary files a/notebook/figures/simultaneous_learnedP.png and /dev/null differ
diff --git a/notebook/figures/simultaneous_learnedP.svg b/notebook/figures/simultaneous_learnedP.svg
deleted file mode 100644
index 3835572..0000000
--- a/notebook/figures/simultaneous_learnedP.svg
+++ /dev/null
@@ -1,3187 +0,0 @@
diff --git a/notebook/figures/simultaneous_risk.png b/notebook/figures/simultaneous_risk.png
deleted file mode 100644
index bc17a29..0000000
Binary files a/notebook/figures/simultaneous_risk.png and /dev/null differ
diff --git a/notebook/figures/simultaneous_risk.svg b/notebook/figures/simultaneous_risk.svg
deleted file mode 100644
index 655435c..0000000
--- a/notebook/figures/simultaneous_risk.svg
+++ /dev/null
@@ -1,4117 +0,0 @@
diff --git a/notebook/figures/transition_matrix.png b/notebook/figures/transition_matrix.png
deleted file mode 100644
index 9b36bdc..0000000
Binary files a/notebook/figures/transition_matrix.png and /dev/null differ
diff --git a/notebook/figures/transition_matrix.svg b/notebook/figures/transition_matrix.svg
deleted file mode 100644
index d0cc8ad..0000000
--- a/notebook/figures/transition_matrix.svg
+++ /dev/null
@@ -1,2052 +0,0 @@
diff --git a/notebook/lymph.mplstyle b/notebook/lymph.mplstyle
deleted file mode 100644
index afc470a..0000000
--- a/notebook/lymph.mplstyle
+++ /dev/null
@@ -1,81 +0,0 @@
-#### COLORS
-# usz_blue = '005ea8'
-# usz_green = '00afa5'
-# usz_red = 'ae0060'
-# usz_orange = 'f17900'
-# usz_gray = 'c5d5db'
-
-#### LATEX
-mathtext.cal : cursive
-mathtext.rm : serif
-mathtext.tt : monospace
-mathtext.it : serif:italic
-mathtext.bf : serif:bold
-mathtext.sf : serif
-mathtext.fontset : custom
-
-#### FONT
-font.size : 9.0
-font.serif : Cambria
-font.sans-serif : Calibri
-font.monospace : Ubuntu Mono
-font.cursive : Lucida Calligraphy
-
-#### LINES
-lines.linewidth : 1
-lines.markersize : 3
-
-#### AXES
-axes.grid : True
-axes.titlesize : large
-axes.titleweight : bold
-axes.titlepad : 8.0
-axes.labelsize : medium
-axes.labelweight : regular
-axes.labelpad : 2.0
-axes.linewidth : 0.5
-axes.formatter.limits : -3, 4
-axes.prop_cycle : cycler(linestyle=['-', '--', '-.']) * cycler(color=['005ea8', 'f17900', 'ae0060', '00afa5'])
-
-### GRIDS
-grid.color : k # grid color
-grid.linestyle     :   -    # solid
-grid.linewidth : 0.5 # in points
-grid.alpha : 0.1 # transparency, between 0.0 and 1.0
-
-#### TICKS
-xtick.direction : in
-xtick.major.width : 0.5
-xtick.major.size : 3.0
-xtick.minor.visible : False
-xtick.color : black
-xtick.labelsize : x-small
-xtick.major.pad : 1.5 ## distance to major tick label in points
-xtick.minor.pad : 1.4 ## distance to the minor tick label in points
-
-ytick.direction : in
-ytick.major.width : 0.5
-ytick.major.size : 3.0
-ytick.minor.visible : False
-ytick.color : black
-ytick.labelsize : x-small
-ytick.major.pad : 1.5 ## distance to major tick label in points
-ytick.minor.pad : 1.4 ## distance to the minor tick label in points
-
-### FIGURE
-figure.dpi : 150
-figure.titlesize : large
-figure.titleweight : bold
-figure.subplot.wspace : 0.2
-figure.subplot.hspace : 0.2
-figure.autolayout : False
-
-#### SAVE
-savefig.format : svg
-
-#### LEGEND
-legend.markerscale : 0.75
-legend.title_fontsize : small
-legend.fontsize : x-small
-legend.borderpad : 0.2
-legend.labelspacing : 0.1
diff --git a/notebook/notebook_requirements.txt b/notebook/notebook_requirements.txt
deleted file mode 100644
index fd857bc..0000000
--- a/notebook/notebook_requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-jupyter
-tqdm
-emcee
-matplotlib
-git+https://github.com/rmnldwg/corner.py#egg=corner
diff --git a/notebook/results_and_plots.ipynb b/notebook/results_and_plots.ipynb
deleted file mode 100644
index b167b34..0000000
--- a/notebook/results_and_plots.ipynb
+++ /dev/null
@@ -1,2154 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# A hidden Markov model for lymphatic tumour progression in head and neck cancer\n",
- "\n",
- "Roman Ludwig¹*, Bertrand Pouymayou¹, Panagiotis Balermpas¹ and Jan Unkelbach¹\n",
- "\n",
- "¹ Departement of Radiation Oncology, University Hospital Zurich, Switzerland \\\n",
- "\\* [roman.ludwig@usz.ch](mailto:roman.ludwig@usz.ch)\n",
- "\n",
- "***\n",
- "\n",
- "## Abstract [🔗](https://www.nature.com/articles/s41598-021-91544-1)\n",
- "\n",
- "Currently , elective clinical target volume (CTV-N) definition for head & neck squamous cell carcinoma (HNSCC) is mostly based on the prevalence of nodal involvement for a given tumor location. In this work, we propose a probabilistic model for lymphatic metastatic spread that can quantify the risk of microscopic involvement in lymph node levels (LNL) given the location of macroscopic metastases and T-stage. This may allow for further personalized CTV-N definition based on an individual patient’s state of disease. \\\n",
- "We model the patient's state of metastatic lymphatic progression as a collection of hidden binary random variables that indicate the involvement of LNLs. In addition, each LNL is associated with observed binary random variables that indicate whether macroscopic metastases are detected. A hidden Markov model (HMM) is used to compute the probabilities of transitions between states over time. The underlying graph of the HMM represents the anatomy of the lymphatic drainage system. Learning of the transition probabilities is done via Markov chain Monte Carlo sampling and is based on a dataset of HNSCC patients in whom involvement of individual LNLs was report-ed. \\\n",
- "The model is demonstrated for ipsilateral metastatic spread in oropharyngeal HNSCC patients. We demonstrate the model's capability to quantify the risk of microscopic involvement in levels III and IV, depending on whether macroscopic metastases are observed in the upstream levels II and III, and depending on T-stage. \\\n",
- "In conclusion, the statistical model of lymphatic progression may inform future, more personal-ized, guidelines on which LNL to include in the elective CTV. However, larger multi-institutional datasets for model parameter learning are required for that. \n",
- "\n",
- "***\n",
- "\n",
- "## Introduction\n",
- "\n",
- "This notebook contains all the code we ran to produce results and plots for our paper. It is intended to be read alongside the paper if one wants to understand better how exactly we implemented and used the methodology introduced in the paper. However, this notebook is NOT a stand-alone work but only a supplement.\n",
- "\n",
- "### Imports\n",
- "\n",
- "First, we import some libraries that are necessary for our implementation. [`lymph`](https://lymph-model.readthedocs.io/en/latest/) is the package we wrote, while [`corner`](https://corner.readthedocs.io/en/latest/) and [`emcee`](https://emcee.readthedocs.io/en/stable/) (see also the corresponding [arXiv paper](https://arxiv.org/abs/1202.3665) on this package) are both packages by Dan Foreman-Mackey & contributors. [`matplotlib`](https://matplotlib.org/stable/index.html) is an extensive and powerful plotting library. All other packages are standard and included with any default python installation."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Additional requirements to install for running this notebook\n",
- "!pip install -r notebook_requirements.txt"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# basic stuff\n",
- "import numpy as np\n",
- "import scipy as sp\n",
- "from scipy import stats\n",
- "import pandas as pd\n",
- "from multiprocessing import Pool\n",
- "import datetime as dt\n",
- "\n",
- "# plotting\n",
- "import matplotlib.pyplot as plt\n",
- "import matplotlib.gridspec as gs\n",
- "from matplotlib import font_manager\n",
- "from matplotlib.colors import LinearSegmentedColormap\n",
- "from matplotlib.colors import ListedColormap\n",
- "from cycler import cycler\n",
- "import corner\n",
- "\n",
- "# sampling\n",
- "import emcee\n",
- "\n",
- "# our package\n",
- "import lymph"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Settings\n",
- "\n",
- "The variables below are meant to be constants. `MAX_T` is the length of the used binomial time-prior. `DRAW_SAMPLES` is a `bool` that defines whether new samples will be drawn for the plots and computations or alreadz drawn samples should be loaded from file. `SEED` is the seed for the random number generator in `numpy` and aims at reproducability. `SAVE_FIGURES` defines whether or not generated figures should be saved to disk."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "MAX_T = 10\n",
- "DRAW_SAMPLES = True\n",
- "SEED = 42\n",
- "SAVE_FIGURES = True"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Colors\n",
- "\n",
- "We chose the colors of the University Hospital Zurich's corporate design for the default colors of our plots."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# USZ colors\n",
- "usz_blue = '#005ea8'\n",
- "usz_green = '#00afa5'\n",
- "usz_red = '#ae0060'\n",
- "usz_orange = '#f17900'\n",
- "usz_gray = '#c5d5db'\n",
- "\n",
- "# colormaps\n",
- "white_to_blue = LinearSegmentedColormap.from_list(\"white_to_blue\", \n",
- " [\"#ffffff\", usz_blue], \n",
- " N=256)\n",
- "white_to_green = LinearSegmentedColormap.from_list(\"white_to_green\", \n",
- " [\"#ffffff\", usz_green], \n",
- " N=256)\n",
- "green_to_red = LinearSegmentedColormap.from_list(\"green_to_red\", \n",
- " [usz_green, usz_red], \n",
- " N=256)\n",
- "\n",
- "h = usz_gray.lstrip('#')\n",
- "gray_rgba = tuple(int(h[i:i+2], 16) / 255. for i in (0, 2, 4)) + (1.0,)\n",
- "tmp = LinearSegmentedColormap.from_list(\"tmp\", [usz_green, usz_red], N=128)\n",
- "tmp = tmp(np.linspace(0., 1., 128))\n",
- "tmp = np.vstack([np.array([gray_rgba]*128), tmp])\n",
- "halfGray_halfGreenToRed = ListedColormap(tmp)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Plot Settings\n",
- "\n",
- "Here we define a function to consistently set the size of our plots and we load the default settings regarding font size etc from an `.mplstyle` file."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.style.use(\"./lymph.mplstyle\")\n",
- "\n",
- "def set_size(width=\"single\", unit=\"cm\", ratio=\"golden\"):\n",
- " if width == \"single\":\n",
- " width = 10\n",
- " elif width == \"full\":\n",
- " width = 16\n",
- " else:\n",
- " try:\n",
- " width = width\n",
- " except:\n",
- " width = 10\n",
- " \n",
- " if unit == \"cm\":\n",
- " width = width / 2.54\n",
- " \n",
- " if ratio == \"golden\":\n",
- " ratio = 1.618\n",
- " else:\n",
- " ratio = ratio\n",
- " \n",
- " try:\n",
- " height = width / ratio\n",
- " except:\n",
- " height = width / 1.618\n",
- " \n",
- " return (width, height)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# only works on Windows when using WSL 2\n",
- "font_dirs = [\"/usr/share/fonts/truetype/dejavu\", \"/mnt/c/Windows/Fonts\"]\n",
- "font_files = font_manager.findSystemFonts(fontpaths=font_dirs)\n",
- "for font_file in font_files:\n",
- " font_manager.fontManager.addfont(font_file)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Data\n",
- "\n",
- "The dataset we use here was reconstructed by [Pouymayou et al.](#pouymayou) from [Sanguineti et al.](#sanguineti). Here, it is loaded from file. Note that $N_0$ patients were added to make up 30% of the dataset."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# parameters and risk estimation from Pouymayou et al [2]\n",
- "pouymayou_params = [0.061, 0.638, 0.094, 0.057, 0.08, 0.331, 0.242]\n",
- "pouymayou_MLrisk = np.array([[ 1.53, 1.64, 1.66, 1.56], \n",
- " [24.67, 81.55, 89.64, 39.06], \n",
- " [ 4.48, 9.97, 59.93, 38.75], \n",
- " [ 1.83, 2.25, 6.05, 4.44]]) / 100.\n",
- "\n",
- "# data reconstructed from Sanguineti et al [1] (without N0 patients)\n",
- "data = pd.read_csv(\"./data/2009_sanguineti.csv\", \n",
- " header=None, \n",
- " names=['I', 'II', 'III', 'IV'])\n",
- "\n",
- "# inserting info about the \"T-stage\"\n",
- "data.insert(0, \"t_stage\", [\"early\"] * data.shape[0])\n",
- "\n",
- "columns = pd.MultiIndex.from_arrays([['info', 'path', 'path', 'path', 'path'], \n",
- " ['t_stage', 'I', 'II', 'III', 'IV']])\n",
- "data = pd.DataFrame(data.values.tolist(), columns=columns)\n",
- "data.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Table 1:__ Rows of patients and columns of T-stage, as well as nodal involvement patterns. Reconstructed from [[1]](#sanguineti).*\n",
- "\n",
- "## Inference\n",
- "\n",
- "In this section we will set up everything necessary to perform inference on the dataset recostructed by [[2]](#pouymayou) from [[1]](#sanguineti).\n",
- "\n",
- "### Lymphatic Network\n",
- "\n",
- "Here, we need to define the underlying anatomical network of lymph node levels (LNLs), as it is also defined in [[2]](#pouymayou). \n",
- "\n",
- "The tumor and every LNL are represented by a key in a dictionary called `graph` each. The respective value in the dictionary is a list of nodes it drains to."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# abstract representation of the lymphatic network\n",
- "graph = {('tumor', 'primary') : ['I', 'II', 'III', 'IV'], \n",
- " ('lnl', 'I') : ['II'], \n",
- " ('lnl', 'II') : ['III'], \n",
- " ('lnl', 'III') : ['IV'], \n",
- " ('lnl', 'IV') : []}\n",
- "\n",
- "model = lymph.System(graph=graph)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Time Prior\n",
- "We need to choose a time prior for the parameter learning. A [Binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) was chosen for its intuitively meaningful shape and simple structure."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "fig, ax = plt.subplots(figsize=set_size());\n",
- "\n",
- "for i,p in enumerate([0.4, 0.55, 0.7]):\n",
- " ax.plot(np.arange(MAX_T+1), \n",
- " sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, p), \n",
- " \"o-\", label=f\"$p = {{{p}}}$\")\n",
- " \n",
- "ax.set_xlim([0,10]);\n",
- "ax.set_ylim([-0.02,0.35])\n",
- "ax.set_xlabel(\"time step $t$\");\n",
- "ax.set_ylabel(r\"$p(t)$\");\n",
- "ax.tick_params();\n",
- "ax.legend();"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 1:__ Binomial distribution with three different p-parameters. They represent the probability of diagnosis at time-step $t$ for three different scenarios, like T-category.*"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# creating a dictionary of time priors for the sampling process\n",
- "p_early = 0.4\n",
- "time_dists = {}\n",
- "time_dists[\"early\"] = sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, p_early)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Learning (HMM)\n",
- "\n",
- "Using the anatomical model along with the data, we can now load the data into the lymph system class. To do so we must first assign a dictionary called `modality_spsn` to the lymphatic system or pass it together with the data in the `load_data` method. **``spsn``** stands for **sp**ecificity and **s**e**n**sitivity. In that dictionary, one has to define a list ``[specificty, sensitivity]`` for each diagnostic modality one is interested in. So if we, for example, have MRI and CT data, then we would pass a dictionary like this:\n",
- "\n",
- "```python\n",
- "modality_spsn = {\n",
- " \"MRI\": [spec_MRI, sens_MRI], \n",
- " \"CT\" : [spec_CT , sens_CT ]\n",
- "}\n",
- "```\n",
- "\n",
- "Note however, that the keys of this dictionary must be diagnostic modalities that are also present in the dataset. More precisely, they must be the overarching categories in the ``MultiIndex`` under which one then finds the individual LNLs.\n",
- "\n",
- "Finally, we can use the likelihood function that is built into the `lymph` package together with the sampling implementation `emcee` to infer the base probabilities $b_{v}$ and transition probabilities $t_{\\operatorname{pa}(v) \\rightarrow v}$."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# define specificity and sensitivity for diagnostic modalities\n",
- "modality_spsn = {\"path\": [1., 1.]}\n",
- "model.modalities = modality_spsn\n",
- "\n",
- "# load data\n",
- "model.patient_data = data\n",
- "\n",
- "# check if likelihood works\n",
- "spread_probs = np.random.uniform(size=(7,))\n",
- "llh = model.log_likelihood(\n",
- " spread_probs, t_stages=[\"early\"], \n",
- " time_dists=time_dists, \n",
- " mode=\"HMM\"\n",
- ")\n",
- "print(llh)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# parameters of the sampler\n",
- "ndim, nwalker, nstep, burnin = 7, 200, 2000, 1000\n",
- "moves = [(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)]\n",
- "\n",
- "if DRAW_SAMPLES:\n",
- " # starting point\n",
- " np.random.seed(SEED)\n",
- " initial_spread_probs = np.random.uniform(low=0., high=1., size=(nwalker,ndim))\n",
- "\n",
- " # the actual sampling round\n",
- " if __name__ == \"__main__\":\n",
- " with Pool() as pool:\n",
- " sampler = emcee.EnsembleSampler(\n",
- " nwalker, ndim, \n",
- " model.log_likelihood, \n",
- " kwargs={\"t_stages\": [\"early\"], \"time_dists\": time_dists}, \n",
- " moves=moves, pool=pool\n",
- " )\n",
- " sampler.run_mcmc(initial_spread_probs, nstep, progress=True)\n",
- "\n",
- " # extracting 200,000 of the 400,000 samples\n",
- " samples_HMM = sampler.get_chain(flat=True, discard=burnin)\n",
- "\n",
- " # saving the sampled data to disk for later convenience\n",
- " np.save(\"./samples/HMM.npy\", samples_HMM)\n",
- " \n",
- "else:\n",
- " # loading in case we don't want to draw all the samples again\n",
- " samples_HMM = np.load(\"./samples/HMM.npy\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "if DRAW_SAMPLES:\n",
- " # check acceptance faction of the sampler to get an indication on whether sth\n",
- " # went wrong or not\n",
- " ar = np.mean(sampler.acceptance_fraction)\n",
- " print(f\"the HMM sampler accepted {ar * 100 :.2f} % of samples.\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "labels = [r\"$\\tilde{b}_1$\", r\"$\\tilde{b}_2$\", \n",
- " r\"$\\tilde{b}_3$\", r\"$\\tilde{b}_4$\", \n",
- " r\"$\\tilde{t}_{12}$\", r\"$\\tilde{t}_{23}$\", r\"$\\tilde{t}_{34}$\"]\n",
- "\n",
- "fig = plt.figure(figsize=set_size(width=\"full\", ratio=1))\n",
- "\n",
- "# using the corner plot package\n",
- "corner.corner(samples_HMM, labels=labels, smooth=True, fig=fig, \n",
- " hist_kwargs={'histtype': 'stepfilled', 'color': usz_blue}, \n",
- " **{'plot_datapoints': False, 'no_fill_contours': True, \n",
- " \"density_cmap\": white_to_blue.reversed(), \n",
- " \"contour_kwargs\": {\"colors\": \"k\"}, \n",
- " \"levels\": np.array([0.2, 0.5, 0.8])}, \n",
- " show_titles=True, title_kwargs={\"fontsize\": \"medium\"});\n",
- "\n",
- "axes = fig.get_axes()\n",
- "for ax in axes:\n",
- " ax.grid(False)\n",
- " \n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/corner_HMM.png\", dpi=300, bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/corner_HMM.svg\", bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 2:__ Corner plot of the sampled parameters for the HMM model parameters. The histograms on the diagonal show the 1D marginals, while the lower triangle shows all possible combinations of 2D marginals. The black lines are the isolines enclosing 20%, 50% and 80% of the sampled points respectively. Correlations between the parameters can at most be seen between $\\tilde{t}_{23}$ and $\\tilde{b}_3$.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Transition Matrix\n",
- "\n",
- "We can now set the hidden Markov model's parameters to the expected value of the inferred parameters and look at the resulting transition matrix $\\mathbf{A}$."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# set the parameters\n",
- "model.spread_probs = np.mean(samples_HMM, axis=0)\n",
- "\n",
- "# modify the transition matrix for nicer coloring\n",
- "mod_A = -1 * np.ones_like(model.transition_matrix)\n",
- "for key, nums in model.allowed_transitions.items():\n",
- " for i in nums:\n",
- " mod_A[key, i] = model.transition_matrix[key, i]\n",
- "\n",
- "# plot the transition matrix\n",
- "fig, ax = plt.subplots(figsize=set_size(ratio=1.), \n",
- " constrained_layout=True);\n",
- "\n",
- "h = ax.imshow(mod_A, cmap=halfGray_halfGreenToRed, vmin=-1., vmax=1.);\n",
- "ax.set_xticks(range(len(model.state_list)));\n",
- "ax.set_xticklabels(model.state_list, rotation=-90, fontsize=\"small\");\n",
- "ax.set_yticks(range(len(model.state_list)));\n",
- "ax.set_yticklabels(model.state_list, fontsize=\"small\");\n",
- "ax.tick_params(direction=\"out\")\n",
- "ax.grid(False)\n",
- "\n",
- "# label the non-zero entries with their probability in %\n",
- "for i in range(len(model.state_list)):\n",
- " for j in range(len(model.state_list)):\n",
- " if mod_A[i,j] > 0.:\n",
- " ax.text(j,i, f\"{mod_A[i,j]*100:.1f}\", ha=\"center\", va=\"center\", \n",
- " color=\"white\", fontsize=\"x-small\")\n",
- " \n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/transition_matrix.png\", dpi=300, bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/transition_matrix.svg\", bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 3:__ Transition matrix $\\mathbf{A}$. All gray pixels in this image correspond to entries in the matrix being zero. The colored pixels take on values $\\in [0,1]$ which are here overlayed in %. The exact values stem from the mean of the learned parameters displayed above. The exact shape of the grey “mask” depends on how one orders the states*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Evolution Plots\n",
- "\n",
- "We can also take a look at how this system evolves over the defined time steps."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# array containing the risk for each state...\n",
- "state_array = np.zeros(shape=(MAX_T+1, len(model.state_list)), dtype=float)\n",
- "# ...weighted with the probability for that time\n",
- "state_array_weighted = np.zeros_like(state_array, dtype=float)\n",
- "state_array_summed = np.zeros_like(state_array, dtype=float)\n",
- "# starting state\n",
- "start = np.zeros(shape=(len(model.state_list),))\n",
- "start[0] = 1.\n",
- "\n",
- "# manually evolving the system and storing all intermediate states\n",
- "time_dist = sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, 0.4)\n",
- "for t,p in enumerate(time_dist):\n",
- " state_array[t] = start\n",
- " state_array_weighted[t] = p * start\n",
- " state_array_summed[t] = np.ones(shape=(1,t+1)) @ state_array_weighted[:t+1]\n",
- " state_array_summed[t] = state_array_summed[t] / np.sum(state_array_summed[t])\n",
- " start = start @ model.transition_matrix"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# these arrays define which states we need to marginalize over when we are \n",
- "# interested in a particual LNL's risk of involvement\n",
- "lnl_I_arr = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], \n",
- " dtype=float)\n",
- "lnl_II_arr = np.array([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1], \n",
- " dtype=float)\n",
- "lnl_III_arr = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], \n",
- " dtype=float)\n",
- "lnl_IV_arr = np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1], \n",
- " dtype=float)\n",
- "\n",
- "lnl_I = state_array @ lnl_I_arr\n",
- "lnl_II = state_array @ lnl_II_arr \n",
- "lnl_III = state_array @ lnl_III_arr \n",
- "lnl_IV = state_array @ lnl_IV_arr\n",
- "\n",
- "lnl_concat = np.vstack([lnl_I, lnl_II, lnl_III, lnl_IV])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "marg_II = np.zeros(shape=(len(samples_HMM[::50])))\n",
- "marg_III = np.zeros(shape=(len(samples_HMM[::50])))\n",
- "marg_II_and_III = np.zeros(shape=(len(samples_HMM[::50])))\n",
- "marg_IV = np.zeros(shape=(len(samples_HMM[::50])))\n",
- "\n",
- "np.random.seed(SEED)\n",
- "for i, spread_probs in enumerate(np.random.permutation(samples_HMM[::50])):\n",
- " model.spread_probs = spread_probs\n",
- " marg_II[i] = 100 * model.risk(\n",
- " inv=np.array([None, 1, None, None]),\n",
- " diagnoses={\"path\": np.array([None, None, None, None])},\n",
- " time_dist=time_dist\n",
- " )\n",
- " marg_III[i] = 100 * model.risk(\n",
- " inv=np.array([None, None, 1, None]),\n",
- " diagnoses={\"path\": np.array([None, None, None, None])},\n",
- " time_dist=time_dist\n",
- " )\n",
- " marg_II_and_III[i] = 100 * model.risk(\n",
- " inv=np.array([None, 1, 1, None]),\n",
- " diagnoses={\"path\": np.array([None, None, None, None])},\n",
- " time_dist=time_dist\n",
- " )\n",
- " marg_IV[i] = 100 * model.risk(\n",
- " inv=np.array([None, None, None, 1]),\n",
- " diagnoses={\"path\": np.array([None, None, None, None])},\n",
- " time_dist=time_dist\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# marginalization for involvements of \"at least LNL x\" or \"only LNL x\" \n",
- "only_II = state_array[:,4]\n",
- "atleast_II = (state_array[:,4] + state_array[:,5] + state_array[:,6] \n",
- " + state_array[:,7] + state_array[:,12] + state_array[:,13] \n",
- " + state_array[:,14] + state_array[:,15])\n",
- "emp_II = 100 * np.sum(data[(\"path\", 'II')].to_numpy()) / 147\n",
- "\n",
- "only_III = state_array[:,2]\n",
- "atleast_III = (state_array[:,2] + state_array[:,3] + state_array[:,6] \n",
- " + state_array[:,7] + state_array[:,10] + state_array[:,11] \n",
- " + state_array[:,14] + state_array[:,15])\n",
- "emp_III = 100 * np.sum(data[(\"path\", 'III')].to_numpy()) / 147\n",
- "\n",
- "only_II_and_III = state_array[:,6]\n",
- "atleast_II_and_III = (state_array[:,6] + state_array[:,7] \n",
- " + state_array[:,14] + state_array[:,15])\n",
- "emp_II_and_III = 100 * len(data.loc[(data[(\"path\", 'II')]==1) \n",
- " & (data[(\"path\", 'III')]==1)].to_numpy()) / 147\n",
- "\n",
- "only_IV = state_array[:,1]\n",
- "atleast_IV = (state_array[:,1] + state_array[:,3] + state_array[:,5] \n",
- " + state_array[:,7] + state_array[:,9] + state_array[:,11] \n",
- " + state_array[:,13] + state_array[:,15])\n",
- "emp_IV = 100 * np.sum(data[(\"path\", 'IV')].to_numpy()) / 147\n",
- "\n",
- "# and now for one complicated plot...\n",
- "fig = plt.figure(figsize=set_size(width=\"full\", ratio=2*1.61), \n",
- " constrained_layout=True);\n",
- "spec = gs.GridSpec(ncols=3, nrows=1, figure=fig, width_ratios=[1., 1., 0.3]);\n",
- "\n",
- "# leftmost subplot\n",
- "ax = fig.add_subplot(spec[0,0])\n",
- "ax.plot(range(len(time_dist)), 100*only_II, 'o-', \n",
- " label=r\"$\\xi_5=[0\\ 1\\ 0\\ 0]$\");\n",
- "ax.plot(range(len(time_dist)), 100*only_III, 'o-', \n",
- " label=r\"$\\xi_3=[0\\ 0\\ 1\\ 0]$\");\n",
- "ax.plot(range(len(time_dist)), 100*only_II_and_III, 'o-', \n",
- " label=r\"$\\xi_7=[0\\ 1\\ 1\\ 0]$\");\n",
- "ax.plot(range(len(time_dist)), 100*only_IV, 'o-', \n",
- " label=r\"$\\xi_2=[0\\ 0\\ 0\\ 1]$\");\n",
- "ax.set_xlabel(\"time step $t$\");\n",
- "ax.set_ylim(ymax=50);\n",
- "ax.set_ylabel(\"Risk [%]\");\n",
- "ax.legend();\n",
- "\n",
- "# middle subplot\n",
- "ax = fig.add_subplot(spec[0,1])\n",
- "ax.plot(range(len(time_dist)), 100*atleast_II, 'o-', \n",
- " label=\"lvl II involved\");\n",
- "ax.plot(range(len(time_dist)), 100*atleast_III, 'o-', \n",
- " label=\"lvl III involved\");\n",
- "ax.plot(range(len(time_dist)), 100*atleast_II_and_III, 'o-', \n",
- " label=\"lvl II & III involved\");\n",
- "ax.plot(range(len(time_dist)), 100*atleast_IV, 'o-', \n",
- " label=\"lvl IV involved\");\n",
- "ax.set_xlabel(\"time step $t$\");\n",
- "ax.set_ylim(ymax=100);\n",
- "ax.legend();\n",
- "\n",
- "# rightmost subplot\n",
- "ax = fig.add_subplot(spec[0,2], sharey=ax);\n",
- "plt.setp(ax.get_yticklabels(), visible=False);\n",
- "ax.set_xticks([0, 1, 2, 3]);\n",
- "ax.set_xticklabels([\"II\", \"III\", \"II & III\", \"IV\"], rotation=-45);\n",
- "\n",
- "violin = ax.violinplot(marg_II, positions=[0]);\n",
- "violin[\"bodies\"][0].set_color(usz_blue);\n",
- "violin[\"cbars\"].set_color(usz_blue);\n",
- "ax.axhline(emp_II, color=usz_blue, ls=\"--\");\n",
- "\n",
- "violin = ax.violinplot(marg_III, positions=[1]);\n",
- "violin[\"bodies\"][0].set_color(usz_orange);\n",
- "violin[\"cbars\"].set_color(usz_orange);\n",
- "ax.axhline(emp_III, color=usz_orange, ls=\"--\");\n",
- "\n",
- "violin = ax.violinplot(marg_II_and_III, positions=[2]);\n",
- "violin[\"bodies\"][0].set_color(usz_red);\n",
- "violin[\"cbars\"].set_color(usz_red);\n",
- "ax.axhline(emp_II_and_III, color=usz_red, ls=\"--\");\n",
- "\n",
- "violin = ax.violinplot(marg_IV, positions=[3]);\n",
- "violin[\"bodies\"][0].set_color(usz_green);\n",
- "violin[\"cbars\"].set_color(usz_green);\n",
- "ax.axhline(emp_IV, color=usz_green, ls=\"--\");\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/HMM_evolution.png\", dpi=300, bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/HMM_evolution.svg\", bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 3:__ (left) Probability of certain hidden state vs time; (middle) Probability of LNL’s involvement marginalized over the other LNL’s involvement vs time; (right) The same probabilities as in the middle, but also marginalized over the time-prior and depicted as violin plots. The dashed lines represent the prevalence in the dataset8 that was used for training.*"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig = plt.figure(figsize=set_size(width=\"full\", ratio=1.8), \n",
- " constrained_layout=True)\n",
- "spec = gs.GridSpec(ncols=2, nrows=1, figure=fig, \n",
- " width_ratios=[1., 0.35], height_ratios=[1.])\n",
- "\n",
- "ax_im = fig.add_subplot(spec[0,0])\n",
- "ax_im.set_title(\"Probabilities [%] of states at different time steps\")\n",
- "ax_im.imshow(state_array, cmap=green_to_red);\n",
- "for i in range(len(time_dist)):\n",
- " for j in range(len(model.state_list)):\n",
- " if np.around(state_array[i,j]*100,1) >= 1.:\n",
- " ax_im.text(j,i, f\"{state_array[i,j]*100:.1f}\", \n",
- " ha=\"center\", va=\"center\", \n",
- " color=\"white\", fontsize=\"xx-small\")\n",
- "ax_im.set_xticks(range(len(model.state_list)))\n",
- "ax_im.set_xticklabels(model.state_list, rotation=45);\n",
- "ax_im.set_ylabel(\"time step $t$\")\n",
- "ax_im.set_xlabel(r\"state $\\xi$\")\n",
- "ax_im.grid(False)\n",
- "\n",
- "ax_pr = fig.add_subplot(spec[0,1], sharey=ax_im)\n",
- "ax_pr.set_title(\"Time prior (PDF)\")\n",
- "ax_pr.plot(time_dist, range(len(time_dist)), \"o-\")\n",
- "plt.setp(ax_pr.get_yticklabels(), visible=False);\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/HMM_evo_matrix.png\", dpi=300, bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/HMM_evo_matrix.svg\", bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 4:__ Probability of being in each hidden state as a function of time (left). The color indicates low (green) and high (red) probabilities, which are also written on the respective pixel in percent if larger than 1%. We used the mean of the inferred parameter samples to compute the probabilities. On the right, the used time-prior is plotted with which each column on the left will be weighted.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## (Cross-)Validation\n",
- "\n",
- "### Comparison of risk & prevalences\n",
- "\n",
- "As an attempt to validate the model with the limited data we have, I'll start by simply comparing the prevalence of certain patterns of involvement to the prediction of the model for the respective state."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "thin = 100\n",
- "np.random.seed(SEED)\n",
- "\n",
- "risks = np.zeros(shape=(len(model.obs_list), len(samples_HMM[::thin])), dtype=float)\n",
- "\n",
- "for i, sample in enumerate(np.random.permutation(samples_HMM[::thin])):\n",
- " for j, obs in enumerate(model.obs_list):\n",
- " model.spread_probs = sample\n",
- " risks[j,i] = model.risk(\n",
- " inv=obs, diagnoses={\"path\": [None, None, None, None]}, \n",
- " time_dist=time_dists[\"early\"], mode=\"HMM\"\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "occurences, _ = lymph.utils.comp_state_dist(data[\"path\"].values)\n",
- "\n",
- "validation_df = pd.DataFrame({\"state\": [str(obs) for obs in model.obs_list], \n",
- " \"occurence\": occurences, \n",
- " \"percentage\": 100 * occurences / np.sum(occurences), \n",
- " \"prediction\": 100 * np.mean(risks, axis=1)})\n",
- "validation_df.head()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Table 2:__ Prevalence of each state in the dataset (column \"percentage\") and the corresponding prediction from the model (column \"prediction\").*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 3-fold cross-validation\n",
- "\n",
- "Now I will split the dataset randomly into three equally large parts. Then I will train the model on all three combinations of two of these thirds and compare them to the respectively remaining third to see if the results are still plausible.\n",
- "\n",
- "These results are actually not in our paper, since they were already done in a similar fashion by Pouymayou et al."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "subsets = []\n",
- "\n",
- "# first third\n",
- "subsets.append(data.sample(frac=1./3.))\n",
- "rest = data.drop(subsets[0].index)\n",
- "\n",
- "# second third\n",
- "subsets.append(rest.sample(frac=0.5))\n",
- "\n",
- "# third third\n",
- "subsets.append(rest.drop(subsets[1].index))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "for i,subset in enumerate(subsets):\n",
- " subset.to_csv(f\"./data/cross-validation_set{i+1}.csv\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# parameters of the sampler\n",
- "cross_validation_samples = []\n",
- "ndim, nwalker, nstep, burnin = 7, 200, 2000, 1000\n",
- "moves = [(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)]\n",
- "np.random.seed(SEED)\n",
- "\n",
- "if DRAW_SAMPLES:\n",
- " for i,subset in enumerate(subsets):\n",
- " # load data subset (or rather the remainder)\n",
- " model.patient_data = data.drop(subset.index)\n",
- "\n",
- " # starting point\n",
- " theta0 = np.random.uniform(low=0., high=1., size=(nwalker,ndim))\n",
- "\n",
- " with Pool() as pool:\n",
- " sampler = emcee.EnsembleSampler(nwalker, ndim, model.log_likelihood, \n",
- " args=[[\"early\"], time_dists], \n",
- " moves=moves, pool=pool)\n",
- " sampler.run_mcmc(theta0, nstep, progress=True)\n",
- "\n",
- " cross_validation_samples.append(\n",
- " sampler.get_chain(flat=True, discard=burnin)\n",
- " )\n",
- " \n",
- " # saving the sampled data to disk for later convenience\n",
- " np.save(f\"./samples/cross-validation-samples{i+1}.npy\", \n",
- " cross_validation_samples[i])\n",
- " \n",
- "else:\n",
- " for i,subset in enumerate(subsets):\n",
- " # loading in case we don't want to draw all the samples again\n",
- " cross_validation_samples.append(\n",
- " np.load(f\"./samples/cross-validation-samples{i+1}.npy\")\n",
- " )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now I will plot the results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "systm.load_data(data, t_stages=[\"early\"],\n",
- " modality_spsn=modality_spsn, mode=\"HMM\",\n",
- " gen_C_kwargs={\"delete_ones\": False})\n",
- "\n",
- "C_total = systm.C_dict[\"early\"]\n",
- "f_total = systm.f_dict[\"early\"]\n",
- "\n",
- "C_matrices = []\n",
- "f_vectors = []\n",
- "\n",
- "for i,subset in enumerate(subsets):\n",
- " systm.load_data(subset, t_stages=[\"early\"],\n",
- " modality_spsn=modality_spsn, mode=\"HMM\",\n",
- " gen_C_kwargs={\"delete_ones\": False})\n",
- " C_matrices.append(systm.C_dict[\"early\"])\n",
- " f_vectors.append(systm.f_dict[\"early\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "cross_val = pd.DataFrame(columns=[\"subset 1\", \"subset 2\", \"subset 3\", \"total\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# go through all the states\n",
- "for s in range(16):\n",
- " new_row = {}\n",
- " try:\n",
- " idx = np.where(C_total[s,:])[0][0]\n",
- " new_row[\"total\"] = f_total[idx]\n",
- " except IndexError:\n",
- " new_row[\"total\"] = 0\n",
- " for i in range(3):\n",
- " try:\n",
- " idx = np.where(C_matrices[i][s,:])[0][0]\n",
- " new_row[f\"subset {i+1}\"] = f_vectors[i][idx]\n",
- " except IndexError:\n",
- " new_row[f\"subset {i+1}\"] = 0\n",
- " \n",
- " cross_val = cross_val.append(new_row, ignore_index=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "obs_spsn_dict = {\"PET\": [0.86, 0.79]}\n",
- "systm.modalities = obs_spsn_dict"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "thin = 500\n",
- "state_list = systm.state_list\n",
- "risk_matrix = np.zeros(shape=(16,4,int(samples_HMM.shape[0]/thin)), dtype=float)\n",
- "\n",
- "for s,state in enumerate(state_list):\n",
- " for i,sample in enumerate([cross_validation_samples[0], \n",
- " cross_validation_samples[1], \n",
- " cross_validation_samples[2], \n",
- " samples_HMM]):\n",
- " sample_set = np.random.permutation(sample)[::thin]\n",
- " risk_matrix[s,i] = [\n",
- " systm.risk(\n",
- " val, inv=state, \n",
- " time_dist=time_dists[\"early\"]) for val in sample_set\n",
- " ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# select the most interesting states (others have very small probabilities)\n",
- "row_selection = [0, 4, 6, 7, 14]\n",
- "titles = [\"subset 1\", \"subset 2\", \"subset 3\", \"whole dataset\"]\n",
- "\n",
- "fig, ax = plt.subplots(len(row_selection),4, figsize=(8,6), \n",
- " sharex=\"col\", sharey=\"row\")\n",
- "x = np.linspace(0., 1., 100)\n",
- "\n",
- "for s,row in enumerate(row_selection):\n",
- " for i,subset in enumerate([*subsets, data]):\n",
- " if s == 0:\n",
- " ax[s,i].set_title(titles[i])\n",
- " if i == 0:\n",
- " ax[s,i].set_ylabel(f\"$\\\\xi_{{{row+1}}}=$\\n{state_list[row]}\")\n",
- " \n",
- " ax[s,i].plot(x, sp.stats.beta.pdf(x, \n",
- " a=cross_val.iloc[row,i], \n",
- " b=len(subset)-cross_val.iloc[row,i]), \n",
- " color=usz_orange, linewidth=2)\n",
- " ax[s,i].hist(risk_matrix[row,i], bins=10, density=True, \n",
- " histtype=\"stepfilled\", color=usz_blue)\n",
- " ax[s,i].set_xlim([0., 1.])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 5:__ Histograms over predicted risk of certain states (blue) compared to the Beta distribution over the same risk, resulting from the prevalence of the respective state in the dataset (orange). This is plotted for the three subsets of the 3-fold cross-validation as well as the whole dataset.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Comparison to Bayesian Network\n",
- "\n",
- "To be able to compare our results to the Bayesian network by [Pouymayou et al.](#pouymayou) we needed to recreate it using the same sampler. To this end, the `lymph` package also supports computing the Bayesian network likelihood for a given graph and observational modality."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# abstract representation of the lymphatic network\n",
- "graph = {('tumor', 'primary') : ['I', 'II', 'III', 'IV'], \n",
- " ('lnl', 'I') : ['II'], \n",
- " ('lnl', 'II') : ['III'], \n",
- " ('lnl', 'III') : ['IV'], \n",
- " ('lnl', 'IV') : []}\n",
- "\n",
- "model = lymph.System(graph=graph)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Learning (BN)\n",
- "\n",
- "All that is different to the learning round [above](#Learning-(HMM)) is that one has to specify the `mode` to be `\"BN\"` instead of `\"HMM\"`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# define specificity and sensitivity for diagnostic modalities\n",
- "modality_spsn = {\"path\": [1., 1.]}\n",
- "\n",
- "model.modalities = modality_spsn\n",
- "\n",
- "# generate C matrix from data\n",
- "model.load_data(data, mode=\"BN\")\n",
- "\n",
- "# check if likelihood works\n",
- "llh = model.log_likelihood(np.random.uniform(size=(7,)), mode=\"BN\")\n",
- "print(llh)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# parameters of the sampler\n",
- "ndim, nwalker, nstep, burnin = 7, 200, 2000, 1000\n",
- "moves = [(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)]\n",
- "\n",
- "if DRAW_SAMPLES:\n",
- " # starting point\n",
- " np.random.seed(SEED)\n",
- " initial_spread_probs = np.random.uniform(low=0., high=1., size=(nwalker,ndim))\n",
- "\n",
- " # the actual sampling round\n",
- " if __name__ == \"__main__\":\n",
- " with Pool() as pool:\n",
- " sampler = emcee.EnsembleSampler(\n",
- " nwalker, ndim, \n",
- " model.log_likelihood, kwargs={\"mode\": \"BN\"},\n",
- " moves=moves, pool=pool\n",
- " )\n",
- " sampler.run_mcmc(initial_spread_probs, nstep, progress=True)\n",
- "\n",
- " # extracting 200,000 of the 400,000 samples\n",
- " samples_BN = sampler.get_chain(flat=True, discard=burnin)\n",
- "\n",
- " # saving the sampled data to disk for later convenience\n",
- " np.save(\"./samples/BN.npy\", samples_BN)\n",
- " \n",
- "else:\n",
- " # loading in case we don't want to draw all the samples again\n",
- " samples_BN = np.load(\"./samples/BN.npy\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "if DRAW_SAMPLES:\n",
- " # check acceptance faction of the sampler to get an indication on whether sth\n",
- " # went wrong or not\n",
- " ar = np.mean(sampler.acceptance_fraction)\n",
- " print(f\"the BN sampler accepted {ar * 100 :.2f} % of samples.\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "labels = [r\"$b_1$\", r\"$b_2$\", r\"$b_3$\", r\"$b_4$\", \n",
- " r\"$t_{12}$\", r\"$t_{23}$\", r\"$t_{34}$\"]\n",
- "\n",
- "fig = plt.figure(figsize=set_size(width=\"full\", ratio=1))\n",
- "\n",
- "# using the corner plot package\n",
- "corner.corner(samples_BN, labels=labels, smooth=True, fig=fig, \n",
- " hist_kwargs={'histtype': 'stepfilled', 'color': usz_green}, \n",
- " **{'plot_datapoints': False, 'no_fill_contours': True, \n",
- " \"density_cmap\": white_to_green.reversed(), \n",
- " \"contour_kwargs\": {\"colors\": \"k\"}, \n",
- " \"levels\": np.array([0.2, 0.5, 0.8])}, \n",
- " show_titles=True, title_kwargs={\"fontsize\": \"medium\"});\n",
- "\n",
- "axes = fig.get_axes()\n",
- "for ax in axes:\n",
- " ax.grid(False)\n",
- " \n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/corner_BN.png\", dpi=300, bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/corner_BN.svg\", bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 6:__ Corner plor of sampled parameters using the Bayesian network model. Qualitatively, this looks very similar to Figure 2, but since the HMM model works with rates instead of absolute probabilities, the values here are larger.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Risk Predictions\n",
- "\n",
- "We can now compute distributions over risks using both the HMM model, as well as the BN.\n",
- "\n",
- "### Evolving beyond \"early\" T-stage\n",
- "\n",
- "We can use the parameters inferred from the early T-stage dataset and use time priors that expect to see a patient's diagnose later on to estimate how risks of involvement might increase over time."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# abstract representation of the lymphatic network\n",
- "graph = {('tumor', 'primary') : ['I', 'II', 'III', 'IV'], \n",
- " ('lnl' , 'I') : ['II'], \n",
- " ('lnl' , 'II') : ['III'], \n",
- " ('lnl' , 'III') : ['IV'], \n",
- " ('lnl' , 'IV' ) : []}\n",
- "\n",
- "tst_model = lymph.System(graph=graph)\n",
- "\n",
- "# set specificity & sensitivity of diagnostic modality (here CT) manually\n",
- "ct_spsn = {\"CT\": [0.76, 0.81]}\n",
- "tst_model.modalities = ct_spsn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# time priors\n",
- "time_dists = {}\n",
- "time_dists['early'] = sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, 0.4)\n",
- "time_dists['mid'] = sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, 0.55)\n",
- "time_dists['late'] = sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, 0.7)\n",
- "\n",
- "# what do we want to know, what do we know?\n",
- "inv = np.array([None, None, 1, None]) # we're interested in the risk of LNL 3\n",
- "# our observation is that lvl 2 is involved\n",
- "diagnoses = {\"CT\": np.array([0, 1, 0, 0])}\n",
- "\n",
- "thin = 50\n",
- "# risk for HMM and different \"T-stages\"\n",
- "early = []\n",
- "mid = []\n",
- "late = []\n",
- "np.random.seed(SEED)\n",
- "for sample in np.random.permutation(samples_HMM)[::thin]:\n",
- " tst_model.spread_probs = sample\n",
- " early.append(\n",
- " tst_model.risk(\n",
- " inv=inv, diagnoses=diagnoses, \n",
- " time_dist=time_dists[\"early\"], \n",
- " mode=\"HMM\"\n",
- " )\n",
- " )\n",
- " mid.append(\n",
- " tst_model.risk(\n",
- " inv=inv, diagnoses=diagnoses, \n",
- " time_dist=time_dists[\"mid\"], \n",
- " mode=\"HMM\"\n",
- " )\n",
- " )\n",
- " late.append(\n",
- " tst_model.risk(\n",
- " inv=inv, diagnoses=diagnoses, \n",
- " time_dist=time_dists[\"late\"], \n",
- " mode=\"HMM\"\n",
- " )\n",
- " )\n",
- "\n",
- "# risk for BN\n",
- "bn = []\n",
- "np.random.seed(SEED)\n",
- "for sample in np.random.permutation(samples_BN)[::thin]:\n",
- " tst_model.spread_probs = sample\n",
- " bn.append(tst_model.risk(inv=inv, diagnoses=diagnoses, mode=\"BN\"))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "bins = np.linspace(0, 35, 50)\n",
- "r = (0, 35)\n",
- "fig, ax = plt.subplots(figsize=set_size())\n",
- "\n",
- "ax.hist(np.asarray(early)*100., bins=bins, density=True, \n",
- " histtype=\"stepfilled\", color=usz_green, label=\"$p = 0.4$\");\n",
- "ax.hist(np.asarray(mid)*100., bins=bins, density=True, alpha=0.8, \n",
- " histtype=\"stepfilled\", color=usz_orange, label=\"$p = 0.55$\");\n",
- "ax.hist(np.asarray(late)*100., bins=bins, density=True, alpha=0.8, \n",
- " histtype=\"stepfilled\", color=usz_red, label=\"$p = 0.7$\");\n",
- "ax.hist(np.asarray(bn)*100., bins=bins, histtype=\"step\", density=True, \n",
- " color=usz_blue, label=\"Bayesian network\");\n",
- "\n",
- "ax.set_xlim(r)\n",
- "ax.set_xlabel(\"risk $R$ [%]\");\n",
- "ax.set_ylabel(r\"$p(R)$\");\n",
- "ax.tick_params();\n",
- "ax.legend();\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/HMM_risk_increaseP.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/HMM_risk_increaseP.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 7:__ Risk prediction for LNL III, given observed positive involvement in LNL II and negative observations in all other LNLs (assuming $s_N = 81\\%$ and $s_P = 76\\%$). The Binomial parameter p was fixed to 0.4 for parameter learning (green), representing early T-category patients. Increasing this parameter results in higher risk. The blue outline shows the risk in level III obtained for the Bayesian network model. The histograms correspond to 1% of the 200,000 samples.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Comparison to [Pouymayou et al.](#pouymayou)\n",
- "\n",
- "Now we will compare how the sampled HMM's, sampled BN's and maximum likelihood BN's risk predictions compare to each other."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ndim, nwalker, nstep, burnin = 7, 200, 2000, 1000\n",
- "thin = 50\n",
- "\n",
- "# what do we want to know?\n",
- "inv = np.array([[1 , None, None, None],\n",
- " [None, 1 , None, None],\n",
- " [None, None, 1 , None],\n",
- " [None, None, None, 1 ]])\n",
- "\n",
- "# what do we know?\n",
- "obs = np.array([[0, 0, 0, 0],\n",
- " [0, 1, 0, 0],\n",
- " [0, 1, 1, 0],\n",
- " [0, 0, 1, 0]])\n",
- "\n",
- "# risk for HMM and two different \"T-stages\" (early and late)\n",
- "np.random.seed(SEED)\n",
- "hmm_risk = np.zeros(shape=(4,4,(nstep-burnin)*nwalker//thin))\n",
- "for i, sample in enumerate(np.random.permutation(samples_HMM)[::thin]):\n",
- " tst_model.spread_probs = sample\n",
- " for k in range(4):\n",
- " for l in range(4):\n",
- " hmm_risk[k,l,i] = tst_model.risk(\n",
- " inv=inv[k], diagnoses={\"CT\": obs[l]}, \n",
- " time_dist=time_dists[\"early\"], mode=\"HMM\"\n",
- " )\n",
- "\n",
- "# risk for BN\n",
- "np.random.seed(SEED)\n",
- "bn_risk = np.zeros(shape=(4,4,(nstep-burnin)*nwalker//thin))\n",
- "for i, sample in enumerate(np.random.permutation(samples_BN)[::thin]):\n",
- " tst_model.spread_probs = sample\n",
- " for k in range(4):\n",
- " for l in range(4):\n",
- " bn_risk[k,l,i] = tst_model.risk(\n",
- " inv=inv[k], diagnoses={\"CT\": obs[l]}, mode=\"BN\"\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig = plt.figure(figsize=set_size(width=\"full\"), constrained_layout=True)\n",
- "spec = gs.GridSpec(ncols=4, nrows=4, figure=fig)\n",
- "\n",
- "lvls = [\"I\", \"II\", \"III\", \"IV\"]\n",
- "txt = [\"Ø\", \"II\", \"II & III\", \"III\"]\n",
- "risk_map = LinearSegmentedColormap.from_list(\"risk_map\", [usz_green, \n",
- " usz_gray, \n",
- " usz_red, \n",
- " usz_red], N=256)\n",
- "n_bins = 25\n",
- "\n",
- "for i in range(4):\n",
- " if (i == 0) or (i == 3):\n",
- " bins = np.linspace(0., 12., 30)\n",
- " for j in range(4):\n",
- " if j == 0:\n",
- " ax = fig.add_subplot(spec[i,j])\n",
- " ax.set_ylabel(f\"{lvls[i]}\");\n",
- " else:\n",
- " ax = fig.add_subplot(spec[i,j], sharey=ax)\n",
- " plt.setp(ax.get_yticklabels(), visible=False)\n",
- " \n",
- " ax.set_xlim(bins[0], bins[-1])\n",
- " \n",
- " tmp_mean = np.mean(hmm_risk[i,j])\n",
- " hmm_color = risk_map(tmp_mean)\n",
- " ax.axvline(pouymayou_MLrisk[i,j]*100., \n",
- " color=usz_orange, label=\"Pouymayou et al\");\n",
- " _, bins, _ = ax.hist(hmm_risk[i,j]*100., bins=bins, \n",
- " histtype=\"stepfilled\", density=True, \n",
- " color=hmm_color) #, label=\"HMM sampling\");\n",
- " ax.hist(bn_risk[i,j]*100., bins=bins, density=True, \n",
- " histtype=\"step\", label=\"BN sampling\", color=usz_blue);\n",
- " ax.tick_params(labelsize=\"xx-small\")\n",
- " \n",
- " if i == 0:\n",
- " ax.set_title(f\"{txt[j]}\", \n",
- " fontsize=\"medium\", fontweight=\"normal\");\n",
- " else:\n",
- " ax.set_xlabel(\"risk [%]\");\n",
- " \n",
- " else:\n",
- " ax = fig.add_subplot(spec[i,:])\n",
- " bins = np.linspace(0., 100., 150)\n",
- " ax.set_xlim(bins[0], bins[-1])\n",
- " ax.set_ylabel(f\"{lvls[i]}\");\n",
- " \n",
- " for j in range(4):\n",
- " tmp_mean = np.mean(hmm_risk[i,j])\n",
- " hmm_color = risk_map(tmp_mean)\n",
- " ax.axvline(pouymayou_MLrisk[i,j]*100., \n",
- " color=usz_orange, label=\"Pouymayou et al\");\n",
- " n, bins, _ = ax.hist(hmm_risk[i,j]*100., bins=bins, \n",
- " histtype=\"stepfilled\", density=True, \n",
- " color=hmm_color) #, label=\"HMM sampling\");\n",
- " ax.hist(bn_risk[i,j]*100., bins=bins, density=True, \n",
- " histtype=\"step\", label=\"BN sampling\", color=usz_blue, \n",
- " linestyle='-');\n",
- " ax.set_xticks(np.linspace(0,100, 11))\n",
- " ax.tick_params(labelsize=\"xx-small\")\n",
- " \n",
- " if ((i == 2) and (j == 0)):\n",
- " ax.text(x=100*pouymayou_MLrisk[i,j]-4*(bins[1]-bins[0]), \n",
- " y=0.35, \n",
- " s=txt[j])\n",
- " elif ((i == 1) and (j == 2)):\n",
- " ax.text(x=100*pouymayou_MLrisk[i,j]+2.5*(bins[1]-bins[0]), \n",
- " y=0.2, \n",
- " s=txt[j])\n",
- " else:\n",
- " ax.text(x=100*pouymayou_MLrisk[i,j]+(bins[1]-bins[0]), \n",
- " y=np.max(n)+0.05, \n",
- " s=txt[j])\n",
- " \n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/HMM_BN_risk_comparison.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/HMM_BN_risk_comparison.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 8:__ Risk assessment for the involvement of different LNLs (rows), given positive observational findings in specified LNLs (columns or labels next to histograms). E.g. row 3 depicts the risk of involvement in LNL III, given different observed involvements (from left to right: no involvement, LNL II only, LNL III only, and LNL II and III but no others). The orange line depicts the maximum likelihood result from [[2]](#pouymayou), the blue outline histogram represents the BN sampling solutions and the solid coloured histograms are the results from the HMM. The colour goes from green (low risk) to red (high risk). Of 200,000 parameter samples, 2% were used to create this plot.*"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Simultaneous Learning\n",
- "If we learn both the system's parameters AND the center of the time prior at the same time. But the naive way just leads to overfitting and very unrealistic combinations of parameters. So, what we are doing here is fixing the time prior for the early stage learning (and use the sanguineti data with chosen N0-ratio) and learn the probability rates along with the time-prior for the late stage (where the data only consists of N0 / N+ patients)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# abstract representation of the lymphatic network\n",
- "graph = {('tumor', 'primary') : ['I', 'II', 'III', 'IV'], \n",
- " ('lnl', 'I') : ['II'], \n",
- " ('lnl', 'II') : ['III'], \n",
- " ('lnl', 'III') : ['IV'], \n",
- " ('lnl', 'IV') : []}\n",
- "\n",
- "model = lymph.System(graph=graph)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# define specificity and sensitivity for diagnostic modalities\n",
- "pathology_spsn = {\"path\": [1., 1.]}\n",
- "\n",
- "model.modalities = pathology_spsn\n",
- "model.patient_data = data\n",
- "\n",
- "# create late T-stage diagnose matrices\n",
- "n_late = len(data)\n",
- "shape = (len(model.state_list), n_late)\n",
- "model._diagnose_matrices[\"late\"] = np.zeros(shape=shape)\n",
- "\n",
- "late_n0_num_patients = int(n_late * (1 - (1. / 1.2)))\n",
- "late_nplus_num_patients = int(n_late / 1.2)\n",
- "\n",
- "# N0 patients\n",
- "model._diagnose_matrices[\"late\"][0 , :late_n0_num_patients] = 1.\n",
- "model._diagnose_matrices[\"late\"][1:, late_n0_num_patients:] = 1."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.imshow(model.diagnose_matrices[\"late\"])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# define likelihood function for simultaneous learning\n",
- "def simultaneous_log_likelihood(theta):\n",
- " len_spread_probs = len(model.spread_probs)\n",
- " spread_probs = theta[:len_spread_probs]\n",
- " early_p = 0.4\n",
- " late_p = theta[-1]\n",
- " \n",
- " if late_p > 1. or late_p < 0.:\n",
- " return -np.inf\n",
- " \n",
- " t = np.arange(MAX_T + 1)\n",
- " pt = lambda p : sp.stats.binom.pmf(t, MAX_T, p)\n",
- " \n",
- " time_dists = {}\n",
- " time_dists[\"early\"] = pt(early_p)\n",
- " time_dists[\"late\"] = pt(late_p)\n",
- " \n",
- " return model.marginal_log_likelihood(\n",
- " spread_probs, [\"early\", \"late\"], time_dists=time_dists\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "theta0 = np.random.uniform(size=(8,))\n",
- "llh = simultaneous_log_likelihood(theta0)\n",
- "print(llh)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ndim, nwalker, nstep, burnin = 7 + 1, 200, 2000, 1000\n",
- "np.random.seed(SEED)\n",
- "theta0 = np.random.uniform(low=0., high=1., size=(nwalker,ndim))\n",
- "moves = [(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)]\n",
- "\n",
- "if DRAW_SAMPLES:\n",
- " if __name__ == \"__main__\":\n",
- " with Pool() as pool:\n",
- " sampler = emcee.EnsembleSampler(\n",
- " nwalker, ndim, \n",
- " simultaneous_log_likelihood, \n",
- " moves=moves, pool=pool\n",
- " )\n",
- " sampler.run_mcmc(theta0, nstep, progress=True)\n",
- "\n",
- " # extracting 200,000 of the 400,000 samples\n",
- " samples_simultaneous = sampler.get_chain(flat=True, discard=burnin)\n",
- "\n",
- " # saving the sampled data to disk for later convenience\n",
- " np.save(\"./samples/simultaneous.npy\", samples_simultaneous)\n",
- " \n",
- "else:\n",
- " # loading in case we don't want to draw all the samples again\n",
- " samples_simultaneous = np.load(\"./samples/simultaneous.npy\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "if DRAW_SAMPLES:\n",
- " # check acceptance faction of the sampler to get an indication on whether sth\n",
- " # went wrong or not\n",
- " ar = np.mean(sampler.acceptance_fraction)\n",
- " print(f\"the simultaneous sampler accepted {ar * 100 :.2f} % of samples.\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "labels = [r\"$\\tilde{b}_1$\", r\"$\\tilde{b}_2$\", \n",
- " r\"$\\tilde{b}_3$\", r\"$\\tilde{b}_4$\", \n",
- " r\"$\\tilde{t}_{12}$\", r\"$\\tilde{t}_{23}$\", \n",
- " r\"$\\tilde{t}_{34}$\", r\"$p$\"]\n",
- "\n",
- "fig = plt.figure(figsize=set_size(width=\"full\", ratio=1))\n",
- "corner.corner(samples_simultaneous, labels=labels, smooth=True, fig=fig, \n",
- " hist_kwargs={'histtype': 'stepfilled', 'color': usz_blue}, \n",
- " **{'plot_datapoints': False, 'no_fill_contours': True, \n",
- " \"density_cmap\": white_to_blue.reversed(), \n",
- " \"contour_kwargs\": {\"colors\": \"k\"}, \n",
- " \"levels\": np.array([0.2, 0.5, 0.8])}, \n",
- " show_titles=True, title_kwargs={\"fontsize\": \"medium\"});\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/corner_simultaneous.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/corner_simultaneous.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 9:__ Corner plot of the sampled paramters during the simultaneous sampling process, were we also inferred the p-parameter of the late T-categorie's time-prior.*"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "mean_late_p = np.mean(samples_simultaneous[:,7])\n",
- "\n",
- "fig, ax = plt.subplots(1,2, \n",
- " figsize=set_size(width=\"full\", ratio=1.61**2), \n",
- " constrained_layout=True)\n",
- "\n",
- "ax[0].axvline(0.4, color=usz_blue, linewidth=2, \n",
- " label=r\"$p_{\\mathrm{early}}$ (fixed)\");\n",
- "ax[0].hist(samples_simultaneous[:,7], bins=40, density=True, \n",
- " color=usz_red, histtype=\"stepfilled\", \n",
- " label=r\"$p_{\\mathrm{late}}$\");\n",
- "ax[0].set_xlabel(\"Binomial parameter $p$\");\n",
- "# ax[0].set_ylabel(r\"$p\\left(\\theta_p^T\\right)$\", fontsize=\"large\");\n",
- "ax[0].legend();\n",
- "\n",
- "t = np.arange(MAX_T+1)\n",
- "dist_sum = np.zeros_like(t, dtype=float)\n",
- "\n",
- "for sample in np.random.permutation(samples_simultaneous[::100]):\n",
- " p = sample[-1]\n",
- " dist_sum += sp.stats.binom.pmf(t, MAX_T, p)\n",
- "\n",
- "dist_avrg = dist_sum / len(samples_simultaneous[::100])\n",
- "\n",
- "ax[1].plot(t, sp.stats.binom.pmf(t, MAX_T, 0.4), 'o-', \n",
- " label=r\"$p_{\\mathrm{early}}$ (fixed)\");\n",
- "ax[1].plot(t, sp.stats.binom.pmf(t, MAX_T, mean_late_p), 'o-', color=usz_red, \n",
- " label=r\"$\\mathbb{E}[p_{\\mathrm{late}}]$\");\n",
- "ax[1].plot(t, dist_avrg, 'o--', color=usz_red, alpha=0.5, \n",
- " label=\"avergaged binomials\");\n",
- "\n",
- "ax[1].legend();\n",
- "ax[1].set_xlabel(\"Time step $t$\");\n",
- "ax[1].set_ylabel(r\"$p_T(t)$\");\n",
- "ax[1].set_xlim([1,MAX_T]);\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/simultaneous_learnedP.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/simultaneous_learnedP.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 10:__ Sampled late T-category p parameter given an early T-category cohort and a fixed fraction of N0 patients (20%) for late T-category (left). Plots of the PMFs of the fixed early T-category binomial distribution, the distribution for the expected value of the late T-category parameter as well as the average distribution resulting from the many sampled binomials (right).*\n",
- "\n",
- "Now we can again compare risk predictions for different T-stages of disease."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# abstract representation of the lymphatic network\n",
- "graph = {('tumor', 'primary') : ['I', 'II', 'III', 'IV'], \n",
- " ('lnl' , 'I') : ['II'], \n",
- " ('lnl' , 'II') : ['III'], \n",
- " ('lnl' , 'III') : ['IV'], \n",
- " ('lnl' , 'IV' ) : []}\n",
- "\n",
- "tst_systm = lymph.System(graph=graph)\n",
- "tst_systm.modalities = {\"CT\": [0.76, 0.81]}"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "np.random.seed(SEED)\n",
- "subset = np.random.permutation(samples_simultaneous[::20])\n",
- "nsubset = len(subset)\n",
- "risk_III = np.zeros(shape=(4, nsubset))\n",
- "risk_IV = np.zeros(shape=(4, nsubset))\n",
- "\n",
- "inv_III = np.array([None, None, 1, None])\n",
- "obs_III = np.array([[0, 0, 0, 0], # no involvement in lvl II observed\n",
- " [0, 1, 0, 0]]) # involvement observed\n",
- "\n",
- "inv_IV = np.array([None, None, None, 1])\n",
- "obs_IV = np.array([[0, 0, 0, 0], \n",
- " [0, 1, 1, 0]])\n",
- "\n",
- "for i, th in enumerate(subset):\n",
- " tst_systm.spread_probs = th[:7]\n",
- " prior = np.vstack([sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, 0.4), \n",
- " sp.stats.binom.pmf(np.arange(MAX_T+1), MAX_T, th[7])])\n",
- " \n",
- " for k in range(4):\n",
- " risk_III[k,i] = tst_systm.risk(\n",
- " inv=inv_III, diagnoses={\"CT\": obs_III[k % 2]}, \n",
- " time_dist=prior[k // 2]\n",
- " )\n",
- " risk_IV[k,i] = tst_systm.risk(\n",
- " inv=inv_IV, diagnoses={\"CT\": obs_IV[k % 2]}, \n",
- " time_dist=prior[k // 2]\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, ax = plt.subplots(1,2, \n",
- " figsize=set_size(width=\"full\", ratio=2*1.5), \n",
- " constrained_layout=True)\n",
- "kwargs = [{\"histtype\": \"stepfilled\", \"alpha\": 0.5}, \n",
- " {\"histtype\": \"step\", \"linewidth\": 1.5}]\n",
- "colors = [usz_blue, usz_orange]\n",
- "time_label = [\"early\", \"late\"]\n",
- "inv_III_label = [\"no observed involvement\", \n",
- " \"only LNL II observed involved\"]\n",
- "inv_IV_label = [\"no observed involvement\", \n",
- " \"LNL II & III observed involved\"]\n",
- "\n",
- "for k in range(4):\n",
- " bins = np.linspace(0., 25., 50)\n",
- " ax[0].hist(100*risk_III[k], bins=bins, density=True, \n",
- " color=colors[k // 2], **kwargs[k % 2], \n",
- " label=f\"{time_label[k // 2]} | {inv_III_label[k % 2]}\")\n",
- " ax[0].set_xlabel(\"risk $R$ [%]\");\n",
- " ax[0].set_ylabel(r\"$p(R)$\");\n",
- " ax[0].set_xlim([bins[0], bins[-1]])\n",
- " ax[0].legend()\n",
- " ax[0].set_title(\"III\")\n",
- " \n",
- " bins = np.linspace(0., 15., 50)\n",
- " ax[1].hist(100*risk_IV[k], bins=bins, density=True, \n",
- " color=colors[k // 2], **kwargs[k % 2], \n",
- " label=f\"{time_label[k // 2]} | {inv_IV_label[k % 2]}\")\n",
- " ax[1].set_xlabel(\"risk $R$ [%]\");\n",
- " ax[1].set_xlim([bins[0], bins[-1]]);\n",
- " ax[1].legend()\n",
- " ax[1].set_title(\"IV\")\n",
- " \n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/simultaneous_risk.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/simultaneous_risk.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "> *__Figure 11:__ Distributions over risk of involvement for LNL III (left) and LNL IV (right), each for early and late T-category as well as depending on the given observed involvement. The sampled parameters displayed here are a randomly selected subset (1% of 200,000) from simultaneous learning. Comparison with Fig. 8 shows that these predictions still agree with the results from the early stage only learning.*\n",
- "\n",
- "## References\n",
- "1. Sanguineti Giuseppe [et al.] Defining the risk of involvement for each neck nodal level in patients with early T-stage node-positive oropharyngeal carcinoma. [Journal] // International Journal of Radiation Oncology Biology Physics. - 2008. - 5 : Vol. 74. - pp. 1356-1364.\n",
- "2. Pouymayou Bertrand [et al.] A Bayesian network model of lymphatic tumor progression for personalized elective CTV definition in head and neck cancers [Journal] // Physics in Medicine and Biology. - 2019. - 16 : Vol. 64. - p. 165003."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Appendix\n",
- "\n",
- "### t-Dependence of Rates\n",
- "\n",
- "Since the HMM-formalism has more parameters than the BN through its time prior, we expect the system to be somewhat overdetermined. In our case this means that we can basically choose an arbitrary number of time steps and the base and transition probability rates will essentially adapt to our choice. To see *how* the rates depend on the time prior's length we'll look at a simplistic example:\n",
- "\n",
- "For the simplest example, the time prior $p_T(t,T)=1 / T$ is uniform and we're only looking at a system with one node that has empirically an involvement probability of $p^*$ (e.g. $0.4$), the base probability rate $p$ must become smaller, as the length of the uniform prior increases. More formally,\n",
- "\n",
- "$$\n",
- "\\begin{align}\n",
- "p^* \n",
- "&= \\sum_{t=1}^{T}{\\frac{1}{T}\\begin{pmatrix} 1 & 0 \\end{pmatrix}\\begin{bmatrix} 1-p & p \\\\ 0 & 1 \\end{bmatrix}^t\\begin{pmatrix} 0 \\\\ 1 \\end{pmatrix}} \\\\\n",
- "&= \\sum_{t=1}^{T}{\\frac{1}{T}\\begin{pmatrix} 1 & 0 \\end{pmatrix}\\begin{bmatrix} (1-p)^t & 1-(1-p)^t \\\\ 0 & 1 \\end{bmatrix}\\begin{pmatrix} 0 \\\\ 1 \\end{pmatrix}} \\\\\n",
- "&= \\sum_{t=1}^{T}{\\frac{1}{T}\\sum_{k=1}^{t}{{t\\choose k}p^k(1-p)^{t-k}}}\n",
- "= \\sum_{t=1}^{T}{\\frac{1}{T}\\left[ 1 - (1-p)^t \\right]} \\\\\n",
- "&= 1 - \\frac{1}{T}\\sum_{t=1}^{T}{(1-p)^t}\n",
- "\\end{align}\n",
- "$$\n",
- "\n",
- "Let's write $q = 1 - p$ and then the sum as\n",
- "\n",
- "$$\n",
- "s = \\sum_{t=1}^{T}{q^t} = q + q^2 + \\ldots + q^T\n",
- "$$\n",
- "\n",
- "hence\n",
- "\n",
- "$$\n",
- "q\\cdot s = q^2 + q^3 + \\ldots + q^{T+1}\n",
- "$$\n",
- "\n",
- "and\n",
- "\n",
- "$$\n",
- "\\begin{align}\n",
- "s - qs &= q - q^{T+1} \\\\\n",
- "\\Rightarrow\\, s(1 - q) &= q(1 - q^T) \\\\\n",
- "\\Rightarrow\\, s &= \\frac{q(1 - q^T)}{1 - q}\n",
- "\\end{align}\n",
- "$$\n",
- "\n",
- "So we can write the empirical probability $p^*$ as\n",
- "\n",
- "$$\n",
- "p^* = 1 - \\frac{q(1 - q^T)}{T(1 - q)} = 1 - \\frac{(1-p)(1 - (1-p)^T)}{Tp}\n",
- "$$\n",
- "\n",
- "This cannot be solved analytically for $p$ in the case of arbitrary $T$, but it is easy to find numerical solutions."
- ]
- },
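- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a quick sanity check of this formula: for $T=1$ it reduces to $p^* = 1 - \\frac{(1-p)\\,p}{p} = p$, as it must, since with a single time step the base probability rate *is* the involvement probability. For $T=2$, writing $q = 1-p$, it becomes $p^* = 1 - \\frac{q(1+q)}{2}$, which is quadratic in $q$ and still solvable by hand; for $p^* = 0.4$ this yields $p \\approx 0.30$, already noticeably smaller than $p^*$ and illustrating how the rate decays with the prior's length."
- ]
- },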
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "p_star = 0.4\n",
- "max_prior_length = 15\n",
- "\n",
- "Ts = np.linspace(1, max_prior_length, max_prior_length, dtype=int)\n",
- "f = lambda p,t,p_star : 1 - (1 - p)*(1 - (1-p)**t) / (t*p) - p_star\n",
- "\n",
- "sols = []\n",
- "for t in Ts:\n",
- " # find roots of f numerically\n",
- " sols.append(sp.optimize.root_scalar(f, args=(t,p_star), bracket=[0.001, 0.999]).root)\n",
- " \n",
- "# plot the computed solutions\n",
- "fig, ax = plt.subplots(figsize=set_size())\n",
- "ax.plot(Ts, sols, 'o--');\n",
- "ax.set_xlim([0, max_prior_length+1]);\n",
- "ax.set_xticks(Ts[::2]);\n",
- "ax.set_xlabel(\"length of prior $T$\");\n",
- "ax.set_ylabel(r\"solution for $\\tilde{b}_1$\");\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/simple_rate_decay.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/simple_rate_decay.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To see if this overly simplified model applies to the full HMM as well, we trained our model for 15 different uniform time priors and plotted the learned mean parameters. We also compared the risk predictions using these different time priors."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# sampling parameters\n",
- "ndim, nwalker, nstep, burnin = 7, 200, 2000, 1000\n",
- "moves = [(emcee.moves.DEMove(), 0.8), (emcee.moves.DESnookerMove(), 0.2)]\n",
- "\n",
- "# starting point\n",
- "np.random.seed(SEED)\n",
- "theta0 = np.random.uniform(low=0., high=1., size=(nwalker,ndim))\n",
- "prior_dict = {}\n",
- "\n",
- "samples_multiT_HMM = np.zeros(shape=(max_prior_length, \n",
- " nwalker*(nstep-burnin), \n",
- " ndim))\n",
- "\n",
- "for i in range(max_prior_length):\n",
- " if DRAW_SAMPLES:\n",
- " if __name__ == \"__main__\":\n",
- " with Pool() as pool:\n",
- " # uniform time prior of varying length\n",
- " prior_dict['early'] = np.concatenate([[0.], [1./(i+1)]*(i+1)])\n",
- "\n",
- " sampler = emcee.EnsembleSampler(nwalker, ndim, systm.likelihood, \n",
- " args=[['early'], prior_dict], \n",
- " kwargs={\"mode\": \"HMM\"},\n",
- " moves=moves, pool=pool)\n",
- " sampler.run_mcmc(theta0, nstep, progress=True)\n",
- "\n",
- " # store the i-th sampling round\n",
- " samples_multiT_HMM[i] = sampler.get_chain(flat=True, \n",
- " discard=burnin)\n",
- "\n",
- " # saving the sampled data to disk for later convenience\n",
- " np.save(f\"./samples/multiT_HMM_{i}.npy\", samples_multiT_HMM[i])\n",
- " \n",
- " else:\n",
- " # loading in case we don't want to draw all the samples again\n",
- " samples_multiT_HMM[i] = np.load(f\"./samples/multiT_HMM_{i}.npy\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Plot the results vs the theoretical values. This is only applicable to the base probability rates $\\tilde{b}_1$ and $\\tilde{b}_2$, since all other LNL's probability rates are also influenced by other efferent spread pathways."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# compute mean and variance of sampled parameters\n",
- "multiT_b = [np.mean(samples_multiT_HMM[:,:,i], axis=1) for i in range(4)]\n",
- "multiT_b_var = [np.var(samples_multiT_HMM[:,:,i], axis=1) for i in range(4)]\n",
- "multiT_t = [np.mean(samples_multiT_HMM[1:,:,i], axis=1) for i in range(4,7)]\n",
- "multiT_t_var = [np.var(samples_multiT_HMM[1:,:,i], axis=1) for i in range(4,7)]\n",
- "\n",
- "# plot\n",
- "fig, ax = plt.subplots(1,2, figsize=set_size(\"full\", ratio=2.8))\n",
- "\n",
- "# plot base prob sample means for differently lengthed time priors\n",
- "for i in range(4):\n",
- " ax[0].plot(range(1,max_prior_length+1), multiT_b[i], \"o\", mfc=\"none\", ms=4, \n",
- " label=f\"$\\\\tilde{{b}}_{{{i+1}}}$ sampled\");\n",
- "\n",
- "# prevalence of involvement in LNL I and II\n",
- "prevalence_I = np.sum(data[(\"path\", \"I\")].to_numpy()) / len(data)\n",
- "prevalence_II = np.sum(data[(\"path\", \"II\")].to_numpy()) / len(data)\n",
- "theory_roots = np.zeros(shape=(2,max_prior_length))\n",
- "\n",
- "# compute roots with P^* = prevalence\n",
- "for t in range(1,max_prior_length+1):\n",
- " theory_roots[0,t-1] = sp.optimize.root_scalar(f, \n",
- " args=(t, prevalence_I), \n",
- " bracket=[0.001, 0.999]).root\n",
- " theory_roots[1,t-1] = sp.optimize.root_scalar(f, \n",
- " args=(t, prevalence_II), \n",
- " bracket=[0.001, 0.999]).root\n",
- "\n",
- "# plot theoretical roots\n",
- "ax[0].plot(range(1,max_prior_length+1), theory_roots[0], \"-\", alpha=0.5, \n",
- " label=r\"$\\tilde{b}_{1}$ theory\");\n",
- "ax[0].plot(range(1,max_prior_length+1), theory_roots[1], \"-\", alpha=0.5, \n",
- " label=r\"$\\tilde{b}_{2}$ theory\");\n",
- "\n",
- "offset = 0.5\n",
- "ax[0].set_xlim([1 - offset, max_prior_length + offset]);\n",
- "ax[0].set_xticks(np.arange(1, max_prior_length+1, 2))\n",
- "ax[0].set_xlabel(\"length of prior $T$\");\n",
- "ax[0].set_ylabel(r\"Base Probability Rate $\\tilde{b}$\");\n",
- "ax[0].legend(ncol=1);\n",
- "\n",
- "\n",
- "# plot trans prob sample means\n",
- "# plot sample means for differently lengthed time priors\n",
- "for i in range(3):\n",
- " ax[1].plot(range(2,max_prior_length+1), multiT_t[i], \"o\", mfc=\"none\", ms=4, \n",
- " label=f\"$\\\\tilde{{t}}_{{{i+1}}}$ sampled\");\n",
- " \n",
- "offset = 0.5\n",
- "ax[1].set_xlim([1 - offset, max_prior_length + offset]);\n",
- "ax[1].set_xticks(np.arange(1, max_prior_length+1, 2))\n",
- "ax[1].set_xlabel(\"length of prior $T$\");\n",
- "ax[1].set_ylabel(r\"Transition Probability Rate $\\tilde{t}$\");\n",
- "ax[1].legend(ncol=1);\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/rate_decay_theory_vs_sampled.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/rate_decay_theory_vs_sampled.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "These plots show that the probability rates of a more complex and realistic system qualitatively follow the same behaviour as the example with only one node. This serves as an argument why we can essentially choose the length of the prior as it suits us.\n",
- "\n",
- "Finally, let's check if the risk actually changes:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# abstract representation of the lymphatic network\n",
- "graph = {('tumor', 'primary') : ['I', 'II', 'III', 'IV'], \n",
- " ('lnl', 'I') : ['II'], \n",
- " ('lnl', 'II') : ['III'], \n",
- " ('lnl', 'III') : ['IV'], \n",
- " ('lnl', 'IV') : []}\n",
- "\n",
- "systm = lymph.System(graph=graph)\n",
- "\n",
- "# set specificity & sensitivity of diagnostic modality (here CT) manually\n",
- "ct_spsn = {\"CT\": [0.76, 0.81]}\n",
- "systm.modalities = ct_spsn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# what do we want to know, what do we know?\n",
- "inv = np.array([None, None, 1, None]) # we're interested in the risk of lvl 3 being involved\n",
- "# our observation is that lvl 2 is involved\n",
- "obs = {\"CT\": np.array([0, 1, 0, 0])}\n",
- "\n",
- "np.random.seed(SEED)\n",
- "ndim, nwalker, nstep, burnin = 7, 200, 2000, 1000\n",
- "thin = 200\n",
- "time_dists = {}\n",
- "hmm_risk = np.zeros(shape=(max_prior_length, (nstep-burnin)*nwalker//thin))\n",
- "\n",
- "for k in range(max_prior_length):\n",
- " # time priors\n",
- " time_dists['early'] = np.concatenate([[0.], [1./(k+1)] * (k+1)])\n",
- "\n",
- " # risk for HMM and two different \"T-stages\" (early and late)\n",
- " subset = np.random.permutation(samples_multiT_HMM[k])[::thin]\n",
- " for i, sample in enumerate(subset):\n",
- " systm.spread_probs = sample\n",
- " hmm_risk[k, i] = systm.risk(\n",
- " inv=inv, diagnoses=obs, \n",
- " time_dist=time_dists[\"early\"], mode=\"HMM\"\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "bn_risk = np.zeros(shape=(len(samples_BN) // thin))\n",
- "for i, sample in enumerate(np.random.permutation(samples_BN)[::thin]):\n",
- " systm.spread_probs = sample\n",
- " bn_risk[i] = systm.risk(inv=inv, diagnoses=obs, mode=\"BN\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "fig, ax = plt.subplots(figsize=set_size())\n",
- "\n",
- "ax.hist(100*hmm_risk[0], bins=100, range=[0., 50], density=True, \n",
- " color=usz_red, alpha=0.1, label=f\"HMM ({max_prior_length} different $T$)\")\n",
- "[ax.hist(100*hmm_risk[k], bins=100, range=[0., 50], density=True, \n",
- " color=usz_red, alpha=0.1) for k in np.arange(1,max_prior_length)];\n",
- "ax.hist(100*bn_risk, histtype=\"step\", bins=100, linewidth=2, range=[0., 50], \n",
- " density=True, label=\"BN\");\n",
- "ax.set_xlim([0., 20]);\n",
- "ax.legend();\n",
- "ax.set_xlabel(r\"risk $R$ [%]\");\n",
- "ax.set_ylabel(\"PDF of risk\");\n",
- "\n",
- "if SAVE_FIGURES:\n",
- " plt.savefig(\"./figures/multi_length_risk.png\", dpi=300, \n",
- " bbox_inches=\"tight\")\n",
- " plt.savefig(\"./figures/multi_length_risk.svg\", \n",
- " bbox_inches=\"tight\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The risk for all the HMMs is the same, except for the one that only includes one time step. The reason for this is that when the system is not given the time to spread, it cannot correctly estimate the conditional risks of what would happen if the previous node was involved vs if it was not involved. So the HMM that is \"too short\" overestimates the risk due to the high prevalence of LNL III's involvement."
- ]
- }
- ],
- "metadata": {
- "interpreter": {
- "hash": "178d24c204a30672ea1fdde86877e76d36aba87ee79c03bdea66a58d070c8fdb"
- },
- "kernelspec": {
- "display_name": "Python 3.8.10 64-bit ('.venv': venv)",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.10"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {
- "height": "235px",
- "width": "277px"
- },
- "number_sections": false,
- "sideBar": true,
- "skip_h1_title": true,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {
- "height": "calc(100% - 180px)",
- "left": "10px",
- "top": "150px",
- "width": "332.391px"
- },
- "toc_section_display": true,
- "toc_window_display": true
- },
- "varInspector": {
- "cols": {
- "lenName": 16,
- "lenType": 16,
- "lenVar": 40
- },
- "kernels_config": {
- "python": {
- "delete_cmd_postfix": "",
- "delete_cmd_prefix": "del ",
- "library": "var_list.py",
- "varRefreshCmd": "print(var_dic_list())"
- },
- "r": {
- "delete_cmd_postfix": ") ",
- "delete_cmd_prefix": "rm(",
- "library": "var_list.r",
- "varRefreshCmd": "cat(var_dic_list()) "
- }
- },
- "types_to_exclude": [
- "module",
- "function",
- "builtin_function_or_method",
- "instance",
- "_Feature"
- ],
- "window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/notebook/samples/BN.npy b/notebook/samples/BN.npy
deleted file mode 100644
index c68608a..0000000
Binary files a/notebook/samples/BN.npy and /dev/null differ
diff --git a/notebook/samples/HMM.npy b/notebook/samples/HMM.npy
deleted file mode 100644
index af19386..0000000
Binary files a/notebook/samples/HMM.npy and /dev/null differ
diff --git a/notebook/samples/cross-validation-samples1.npy b/notebook/samples/cross-validation-samples1.npy
deleted file mode 100644
index 77c3e44..0000000
Binary files a/notebook/samples/cross-validation-samples1.npy and /dev/null differ
diff --git a/notebook/samples/cross-validation-samples2.npy b/notebook/samples/cross-validation-samples2.npy
deleted file mode 100644
index 3c193d9..0000000
Binary files a/notebook/samples/cross-validation-samples2.npy and /dev/null differ
diff --git a/notebook/samples/cross-validation-samples3.npy b/notebook/samples/cross-validation-samples3.npy
deleted file mode 100644
index 9643b2c..0000000
Binary files a/notebook/samples/cross-validation-samples3.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_0.npy b/notebook/samples/multiT_HMM_0.npy
deleted file mode 100644
index a8ff7fa..0000000
Binary files a/notebook/samples/multiT_HMM_0.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_1.npy b/notebook/samples/multiT_HMM_1.npy
deleted file mode 100644
index 2e4fdee..0000000
Binary files a/notebook/samples/multiT_HMM_1.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_10.npy b/notebook/samples/multiT_HMM_10.npy
deleted file mode 100644
index 5df95c8..0000000
Binary files a/notebook/samples/multiT_HMM_10.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_11.npy b/notebook/samples/multiT_HMM_11.npy
deleted file mode 100644
index 9d648f8..0000000
Binary files a/notebook/samples/multiT_HMM_11.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_12.npy b/notebook/samples/multiT_HMM_12.npy
deleted file mode 100644
index ba88905..0000000
Binary files a/notebook/samples/multiT_HMM_12.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_13.npy b/notebook/samples/multiT_HMM_13.npy
deleted file mode 100644
index 47ccfc5..0000000
Binary files a/notebook/samples/multiT_HMM_13.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_14.npy b/notebook/samples/multiT_HMM_14.npy
deleted file mode 100644
index 11f438b..0000000
Binary files a/notebook/samples/multiT_HMM_14.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_2.npy b/notebook/samples/multiT_HMM_2.npy
deleted file mode 100644
index a918758..0000000
Binary files a/notebook/samples/multiT_HMM_2.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_3.npy b/notebook/samples/multiT_HMM_3.npy
deleted file mode 100644
index b6b8de9..0000000
Binary files a/notebook/samples/multiT_HMM_3.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_4.npy b/notebook/samples/multiT_HMM_4.npy
deleted file mode 100644
index ce9a51c..0000000
Binary files a/notebook/samples/multiT_HMM_4.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_5.npy b/notebook/samples/multiT_HMM_5.npy
deleted file mode 100644
index b32e767..0000000
Binary files a/notebook/samples/multiT_HMM_5.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_6.npy b/notebook/samples/multiT_HMM_6.npy
deleted file mode 100644
index 08be933..0000000
Binary files a/notebook/samples/multiT_HMM_6.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_7.npy b/notebook/samples/multiT_HMM_7.npy
deleted file mode 100644
index 754fd48..0000000
Binary files a/notebook/samples/multiT_HMM_7.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_8.npy b/notebook/samples/multiT_HMM_8.npy
deleted file mode 100644
index 09dad81..0000000
Binary files a/notebook/samples/multiT_HMM_8.npy and /dev/null differ
diff --git a/notebook/samples/multiT_HMM_9.npy b/notebook/samples/multiT_HMM_9.npy
deleted file mode 100644
index 9fa303d..0000000
Binary files a/notebook/samples/multiT_HMM_9.npy and /dev/null differ
diff --git a/notebook/samples/simultaneous.npy b/notebook/samples/simultaneous.npy
deleted file mode 100644
index 4bbd4ea..0000000
Binary files a/notebook/samples/simultaneous.npy and /dev/null differ
diff --git a/pyproject.toml b/pyproject.toml
index aeb0d98..7a9e2b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,36 +17,31 @@ requires-python = ">=3.10"
keywords = ["cancer", "metastasis", "lymphatic progression", "model"]
license = {text = "MIT"}
classifiers = [
+ "Development Status :: 3 - Alpha",
"License :: OSI Approved :: MIT License",
+ "Natural Language :: English",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: Implementation :: CPython",
- "Programming Language :: Python :: Implementation :: PyPy",
"Operating System :: OS Independent",
+ "Topic :: Scientific/Engineering",
]
dependencies = [
- "numpy",
- "scipy",
- "pandas",
- "emcee",
- "h5py",
- "tables",
- "tqdm",
+ "numpy < 2",
+ "pandas < 3",
+ "cachetools < 6",
]
dynamic = ["version"]
[project.optional-dependencies]
test = [
- "pytest",
- "coverage",
- "hypothesis",
+ "scipy < 2",
+ "coverage < 8",
]
dev = [
"pre-commit",
- "isort",
- "pycln",
"pylint",
+ "git-cliff",
]
docs = [
"sphinx",
@@ -55,6 +50,7 @@ docs = [
"myst-nb",
"ipython",
"matplotlib",
+ "scipy",
]
[project.urls]
@@ -136,7 +132,7 @@ commit_preprocessors = [
commit_parsers = [
{ message = "^feat", group = "Features" },
{ message = "^fix", group = "Bug Fixes" },
- { message = "^doc", group = "Documentation" },
+ { message = "^docs", group = "Documentation" },
{ message = "^perf", group = "Performance" },
{ message = "^refactor", group = "Refactor" },
{ message = "^style", group = "Styling" },
diff --git a/tests/binary_bilateral_test.py b/tests/binary_bilateral_test.py
index 9c7ef44..aea5702 100644
--- a/tests/binary_bilateral_test.py
+++ b/tests/binary_bilateral_test.py
@@ -14,41 +14,60 @@ class BilateralInitTest(fixtures.BilateralModelMixin, unittest.TestCase):
"""Test the delegation of attrs from the unilateral class to the bilateral one."""
def setUp(self):
+ self.model_kwargs = {"is_symmetric": {
+ "tumor_spread": True,
+ "lnl_spread": True,
+ "modalities": True,
+ }}
super().setUp()
self.load_patient_data()
def test_delegation(self):
"""Test that the unilateral model delegates the attributes."""
- self.assertEqual(
- self.model.is_binary, self.model.ipsi.is_binary
- )
- self.assertEqual(
- self.model.is_trinary, self.model.ipsi.is_trinary
- )
- self.assertEqual(
- self.model.max_time, self.model.ipsi.max_time
- )
- self.assertEqual(
- list(self.model.t_stages), list(self.model.ipsi.t_stages)
- )
-
- def test_lnl_edge_sync(self):
- """Check if synced LNL edges update their respective parameters."""
- rng = np.random.default_rng(42)
- for ipsi_edge in self.model.ipsi.graph.lnl_edges.values():
- contra_edge = self.model.contra.graph.lnl_edges[ipsi_edge.name]
- ipsi_edge.set_params(spread=rng.random())
+ self.assertEqual(self.model.is_binary, self.model.ipsi.is_binary)
+ self.assertEqual(self.model.is_trinary, self.model.ipsi.is_trinary)
+ self.assertEqual(self.model.max_time, self.model.ipsi.max_time)
+ self.assertEqual(list(self.model.t_stages), list(self.model.ipsi.t_stages))
+
+ def test_edge_sync(self):
+ """Check if synced edges update their respective parameters."""
+ for ipsi_edge in self.model.ipsi.graph.edges.values():
+ contra_edge = self.model.contra.graph.edges[ipsi_edge.name]
+ ipsi_edge.set_params(spread=self.rng.random())
self.assertEqual(
ipsi_edge.get_params("spread"),
contra_edge.get_params("spread"),
)
+ def test_tensor_sync(self):
+ """Check the transition tensors of the edges get deleted and updated properly."""
+ for ipsi_edge in self.model.ipsi.graph.edges.values():
+ ipsi_edge.set_params(spread=self.rng.random())
+ contra_edge = self.model.contra.graph.edges[ipsi_edge.name]
+ self.assertTrue(np.all(
+ ipsi_edge.transition_tensor == contra_edge.transition_tensor
+ ))
+
+ def test_transition_matrix_sync(self):
+ """Make sure contra transition matrix gets recomputed when ipsi param is set."""
+ ipsi_trans_mat = self.model.ipsi.transition_matrix()
+ contra_trans_mat = self.model.contra.transition_matrix()
+ rand_ipsi_param = self.rng.choice(list(
+ self.model.ipsi.get_params(as_dict=True).keys()
+ ))
+ self.model.assign_params(**{f"ipsi_{rand_ipsi_param}": self.rng.random()})
+ self.assertFalse(np.all(
+ ipsi_trans_mat == self.model.ipsi.transition_matrix()
+ ))
+ self.assertFalse(np.all(
+ contra_trans_mat == self.model.contra.transition_matrix()
+ ))
+
def test_modality_sync(self):
"""Make sure the modalities are synced between the two sides."""
- rng = np.random.default_rng(42)
self.model.ipsi.modalities = {"foo": Clinical(
- specificity=rng.uniform(),
- sensitivity=rng.uniform(),
+ specificity=self.rng.uniform(),
+ sensitivity=self.rng.uniform(),
)}
self.assertEqual(
self.model.ipsi.modalities["foo"].sensitivity,
@@ -280,3 +299,33 @@ def test_risk(self):
)
self.assertLessEqual(risk, 1.)
self.assertGreaterEqual(risk, 0.)
+
+
+class DataGenerationTestCase(fixtures.BilateralModelMixin, unittest.TestCase):
+ """Check the binary model's data generation method."""
+
+ def setUp(self):
+ super().setUp()
+ self.model.modalities = fixtures.MODALITIES
+ self.init_diag_time_dists(early="frozen", late="parametric")
+ self.model.assign_params(**self.create_random_params())
+
+ def test_generate_data(self):
+ """Check bilateral data generation."""
+ dataset = self.model.draw_patients(
+ num=10000,
+ stage_dist=[0.5, 0.5],
+ rng=self.rng,
+ )
+
+ for mod in self.model.modalities.keys():
+ self.assertIn(mod, dataset)
+ for side in ["ipsi", "contra"]:
+ self.assertIn(side, dataset[mod])
+ for lnl in self.model.ipsi.graph.lnls.keys():
+ self.assertIn(lnl, dataset[mod][side])
+
+ self.assertAlmostEqual(
+ (dataset["tumor", "1", "t_stage"] == "early").mean(), 0.5,
+ delta=0.02
+ )
diff --git a/tests/binary_unilateral_test.py b/tests/binary_unilateral_test.py
index cf80dc4..5e939ff 100644
--- a/tests/binary_unilateral_test.py
+++ b/tests/binary_unilateral_test.py
@@ -5,6 +5,7 @@
import numpy as np
from lymph.graph import LymphNodeLevel, Tumor
+from lymph.modalities import Pathological
class InitTestCase(fixtures.BinaryUnilateralModelMixin, unittest.TestCase):
@@ -140,17 +141,13 @@ def test_params_assignment_via_method(self):
)
def test_transition_matrix_deletion(self):
- """Check if the transition matrix gets deleted when a parameter is set.
-
- NOTE: This test is disabled because apparently, the `model` instance is
- changed during the test and the `_transition_matrix` attribute is deleted on
- the wrong instance. I have no clue why, but generally, the method works.
- """
+ """Check if the transition matrix gets deleted when a parameter is set."""
first_lnl_name = list(self.model.graph.lnls.values())[0].name
- _ = self.model.transition_matrix
- self.assertTrue("transition_matrix" in self.model.__dict__)
+ trans_mat = self.model.transition_matrix()
self.model.graph.edges[f"T_to_{first_lnl_name}"].set_spread_prob(0.5)
- self.assertFalse("transition_matrix" in self.model.__dict__)
+ self.assertFalse(np.all(
+ trans_mat == self.model.transition_matrix()
+ ))
class TransitionMatrixTestCase(fixtures.BinaryUnilateralModelMixin, unittest.TestCase):
@@ -164,11 +161,11 @@ def setUp(self):
def test_shape(self):
"""Make sure the transition matrix has the correct shape."""
num_lnls = len({name for kind, name in self.graph_dict if kind == "lnl"})
- self.assertEqual(self.model.transition_matrix.shape, (2**num_lnls, 2**num_lnls))
+ self.assertEqual(self.model.transition_matrix().shape, (2**num_lnls, 2**num_lnls))
def test_is_probabilistic(self):
"""Make sure the rows of the transition matrix sum to one."""
- row_sums = np.sum(self.model.transition_matrix, axis=1)
+ row_sums = np.sum(self.model.transition_matrix(), axis=1)
self.assertTrue(np.allclose(row_sums, 1.))
@staticmethod
@@ -189,7 +186,7 @@ def is_recusively_upper_triangular(mat: np.ndarray) -> bool:
def test_is_recusively_upper_triangular(self) -> None:
"""Make sure the transition matrix is recursively upper triangular."""
- self.assertTrue(self.is_recusively_upper_triangular(self.model.transition_matrix))
+ self.assertTrue(self.is_recusively_upper_triangular(self.model.transition_matrix()))
class ObservationMatrixTestCase(fixtures.BinaryUnilateralModelMixin, unittest.TestCase):
@@ -205,11 +202,11 @@ def test_shape(self):
num_lnls = len(self.model.graph.lnls)
num_modalities = len(self.model.modalities)
expected_shape = (2**num_lnls, 2**(num_lnls * num_modalities))
- self.assertEqual(self.model.observation_matrix.shape, expected_shape)
+ self.assertEqual(self.model.observation_matrix().shape, expected_shape)
def test_is_probabilistic(self):
"""Make sure the rows of the observation matrix sum to one."""
- row_sums = np.sum(self.model.observation_matrix, axis=1)
+ row_sums = np.sum(self.model.observation_matrix(), axis=1)
self.assertTrue(np.allclose(row_sums, 1.))
@@ -247,7 +244,6 @@ def test_t_stages(self):
self.assertIn(t_stage, t_stages_in_data)
self.assertIn(t_stage, t_stages_in_diag_time_dists)
-
def test_data_matrices(self):
"""Make sure the data matrices are generated correctly."""
for t_stage in ["early", "late"]:
@@ -260,14 +256,13 @@ def test_data_matrices(self):
self.assertTrue(t_stage in self.model.data_matrices)
self.assertEqual(
data_matrix.shape[0],
- self.model.observation_matrix.shape[1],
+ self.model.observation_matrix().shape[1],
)
self.assertEqual(
data_matrix.shape[1],
has_t_stage.sum(),
)
-
def test_diagnose_matrices(self):
"""Make sure the diagnose matrices are generated correctly."""
for t_stage in ["early", "late"]:
@@ -280,7 +275,7 @@ def test_diagnose_matrices(self):
self.assertTrue(t_stage in self.model.diagnose_matrices)
self.assertEqual(
diagnose_matrix.shape[0],
- self.model.transition_matrix.shape[1],
+ self.model.transition_matrix().shape[1],
)
self.assertEqual(
diagnose_matrix.shape[1],
@@ -343,7 +338,6 @@ def create_random_diagnoses(self):
return diagnoses
-
def test_comp_diagnose_encoding(self):
"""Check computation of one-hot encoding of diagnoses."""
random_diagnoses = self.create_random_diagnoses()
@@ -381,3 +375,70 @@ def test_risk(self):
self.assertEqual(risk.dtype, float)
self.assertGreaterEqual(risk, 0.)
self.assertLessEqual(risk, 1.)
+
+
+class DataGenerationTestCase(fixtures.BinaryUnilateralModelMixin, unittest.TestCase):
+ """Check the data generation utilities."""
+
+ def setUp(self):
+ """Load params."""
+ super().setUp()
+ self.model.modalities = fixtures.MODALITIES
+ self.init_diag_time_dists(early="frozen", late="parametric")
+ self.model.assign_params(**self.create_random_params())
+
+ def test_generate_early_patients(self):
+ """Check that generating only early T-stage patients works."""
+ early_patients = self.model.draw_patients(
+ num=100,
+ stage_dist=[1., 0.],
+ rng=self.rng,
+ )
+ self.assertEqual(len(early_patients), 100)
+ self.assertEqual(sum(early_patients["tumor", "1", "t_stage"] == "early"), 100)
+ self.assertIn(("CT", "ipsi", "II"), early_patients.columns)
+ self.assertIn(("FNA", "ipsi", "III"), early_patients.columns)
+
+ def test_generate_late_patients(self):
+ """Check that generating only late T-stage patients works."""
+ late_patients = self.model.draw_patients(
+ num=100,
+ stage_dist=[0., 1.],
+ rng=self.rng,
+ )
+ self.assertEqual(len(late_patients), 100)
+ self.assertEqual(sum(late_patients["tumor", "1", "t_stage"] == "late"), 100)
+ self.assertIn(("CT", "ipsi", "II"), late_patients.columns)
+ self.assertIn(("FNA", "ipsi", "III"), late_patients.columns)
+
+ def test_distribution_of_patients(self):
+ """Check that the distribution of LNL involvement is correct."""
+ # set spread params all to 0
+ for lnl_edge in self.model.graph.lnl_edges.values():
+ lnl_edge.set_spread_prob(0.)
+
+ # make all patients diagnosed after exactly one time-step
+ self.model.diag_time_dists["early"] = [0,1,0,0,0,0,0,0,0,0,0]
+
+ # assign only one pathology modality
+ self.model.modalities = {"tmp": Pathological(specificity=1., sensitivity=1.)}
+
+ # extract the tumor spread parameters
+ params = self.model.get_params(as_dict=True)
+ params = {
+ key.replace("T_to_", "").replace("_spread", ""): value
+ for key, value in params.items()
+ if "T_to_" in key
+ }
+
+ # draw large enough amount of patients
+ patients = self.model.draw_patients(
+ num=10000,
+ stage_dist=[1., 0.],
+ rng=self.rng,
+ )
+
+ # check that the distribution of LNL involvement matches tumor spread params
+ for lnl, expected_mean in params.items():
+ actual_mean = patients[("tmp", "ipsi", lnl)].mean()
+ self.assertAlmostEqual(actual_mean, expected_mean, delta=0.02)
diff --git a/tests/doc_test.py b/tests/doc_test.py
new file mode 100644
index 0000000..7265cf3
--- /dev/null
+++ b/tests/doc_test.py
@@ -0,0 +1,21 @@
+"""
+Make doctests in the lymph package discoverable by unittest.
+"""
+import doctest
+import unittest
+
+from lymph import diagnose_times, graph, helper, matrix, modalities
+from lymph.models import bilateral, unilateral
+
+
+def load_tests(loader, tests: unittest.TestSuite, ignore):
+ """Load doctests from the lymph package."""
+ tests.addTests(doctest.DocTestSuite(diagnose_times))
+ tests.addTests(doctest.DocTestSuite(graph))
+ tests.addTests(doctest.DocTestSuite(helper))
+ tests.addTests(doctest.DocTestSuite(matrix))
+ tests.addTests(doctest.DocTestSuite(modalities))
+
+ tests.addTests(doctest.DocTestSuite(unilateral))
+ tests.addTests(doctest.DocTestSuite(bilateral))
+ return tests
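
The `load_tests` function above is unittest's standard hook for customizing test discovery, so these doctests run whenever the suite is discovered, e.g. via `python -m unittest discover tests`. A programmatic equivalent, sketched here (the `tests` directory name is inferred from the paths in this diff):

```python
# Sketch: run discovery programmatically; unittest invokes each module's
# load_tests() hook, so the doctests registered above are included.
import unittest

if __name__ == "__main__":
    unittest.main(module=None, argv=["unittest", "discover", "-s", "tests"])
```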
diff --git a/tests/edge_test.py b/tests/edge_test.py
index ed9262e..623bb28 100644
--- a/tests/edge_test.py
+++ b/tests/edge_test.py
@@ -45,9 +45,6 @@ def test_callback_on_param_change(self) -> None:
"""Test if the callback function is called."""
self.edge.spread_prob = 0.5
self.assertTrue(self.was_called)
- self.assertFalse(hasattr(self.edge, "_transition_tensor"))
- _ = self.edge.transition_tensor
- self.assertTrue(hasattr(self.edge, "_transition_tensor"))
def test_graph_change(self) -> None:
"""Check if the callback also works when parent/child nodes are changed."""
diff --git a/tests/trinary_unilateral_test.py b/tests/trinary_unilateral_test.py
index d6870f4..f09c0dc 100644
--- a/tests/trinary_unilateral_test.py
+++ b/tests/trinary_unilateral_test.py
@@ -40,21 +40,21 @@ def test_edge_transition_tensors(self) -> None:
NOTE: I am using this only in debug mode to look at the tensors. I am not sure
how to test them yet.
"""
- base_edge_tensor = list(self.model.graph.tumor_edges.values())[0].comp_transition_tensor()
+ base_edge_tensor = list(self.model.graph.tumor_edges.values())[0].transition_tensor
row_sums = base_edge_tensor.sum(axis=2)
self.assertTrue(np.allclose(row_sums, 1.0))
- lnl_edge_tensor = list(self.model.graph.lnl_edges.values())[0].comp_transition_tensor()
+ lnl_edge_tensor = list(self.model.graph.lnl_edges.values())[0].transition_tensor
row_sums = lnl_edge_tensor.sum(axis=2)
self.assertTrue(np.allclose(row_sums, 1.0))
- growth_edge_tensor = list(self.model.graph.growth_edges.values())[0].comp_transition_tensor()
+ growth_edge_tensor = list(self.model.graph.growth_edges.values())[0].transition_tensor
row_sums = growth_edge_tensor.sum(axis=2)
self.assertTrue(np.allclose(row_sums, 1.0))
def test_transition_matrix(self) -> None:
"""Test the transition matrix of the model."""
- transition_matrix = self.model.transition_matrix
+ transition_matrix = self.model.transition_matrix()
row_sums = transition_matrix.sum(axis=1)
self.assertTrue(np.allclose(row_sums, 1.0))
@@ -72,7 +72,7 @@ def test_observation_matrix(self) -> None:
"""Test the observation matrix of the model."""
num_lnls = len(self.model.graph.lnls)
num = num_lnls * len(self.model.modalities)
- observation_matrix = self.model.observation_matrix
+ observation_matrix = self.model.observation_matrix()
self.assertEqual(observation_matrix.shape, (3 ** num_lnls, 2 ** num))
row_sums = observation_matrix.sum(axis=1)
@@ -84,12 +84,12 @@ class TrinaryDiagnoseMatricesTestCase(fixtures.TrinaryFixtureMixin, unittest.Tes
def setUp(self):
super().setUp()
- self.model.load_patient_data(self.get_patient_data(), side="ipsi")
- _ = self.model.diagnose_matrices
+ self.model.modalities = fixtures.MODALITIES
+ self.load_patient_data(filename="2021-usz-oropharynx.csv")
def get_patient_data(self) -> pd.DataFrame:
"""Load an example dataset that has both clinical and pathology data."""
- return pd.read_csv("tests/data/2021-clb-oropharynx.csv", header=[0, 1, 2])
+ return pd.read_csv("tests/data/2021-usz-oropharynx.csv", header=[0, 1, 2])
def test_diagnose_matrices_shape(self) -> None:
"""Test the diagnose matrix of the model."""