diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index a372e35328..55b3c12d08 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,3 +1,7 @@ +# nodestats codespell corrections +f13f8d40d6768f14a46315d105b3d80beaa33aaf +284f113cf8416eed43b88b4c438d7552456616e3 + # numpydoc validation - Linting docstrings with pre-commit numpydoc 836e1bf5347eb7a9f97e784e783045c0287b3fe9 2580811edefac867938ee0c4b705649a493e5d4f diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5656df4534..16831aa5a8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -28,7 +28,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v4 - name: Set up Python diff --git a/.gitignore b/.gitignore index a9a224ac47..007ecbb5b3 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ output/config.yaml *.npy *.csv *.spm +*.svg # Notebook checkpoints notebooks/.ipynb_checkpoints/ @@ -34,6 +35,7 @@ notebooks/.ipynb_checkpoints/ .vscode/ .idea/ *~ +pytest-debug.ini # Documentation _build/ @@ -57,5 +59,10 @@ topostats/_version.py # default output directory, often common from testing output/ +# Include all files in tests and all subdirectories except processed and __pycache__ +!tests/** +tests/resources/processed/ +__pycache__/ + # Debugging -pytest-debug.ini \ No newline at end of file +pytest-debug.ini diff --git a/.markdownlint-cli2.yaml b/.markdownlint-cli2.yaml index a245804cdf..7cf6ccbc88 100644 --- a/.markdownlint-cli2.yaml +++ b/.markdownlint-cli2.yaml @@ -8,6 +8,7 @@ config: html: allowed_elements: - div + - br # Globs globs: diff --git a/.pylintrc b/.pylintrc index 269b895464..8c7e0e29a9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -22,7 +22,6 @@ fail-under=10.0 # Files or directories to be skipped. They should be base names, not paths. ignore=CVS, _version.py, - dnatracing.py, dnacurvature.py, conf.py, test_dnacurvature.py, diff --git a/contributing.md b/contributing.md index b2dd0978b4..7a430f21cf 100644 --- a/contributing.md +++ b/contributing.md @@ -22,9 +22,17 @@ not currently considered core elements of `topostats` apart from `Plotting.py`. Currently the `topostats` module consists of: - `default_config.ini` The default config file. -- `dnatracing.py` Applies tracing functions to each molecule. -- `pygwytracing.py` The "main" routine. -- `tracingfuncs.py` Skeletonises and generates backbone traces from masks. +- `run_topostats.py` The main script for running `topostats` and handles multiprocessing. +- `entry_point.py` Handles the entry point commandline commands. +- `processing.py` Handles processing a single AFM file. +- `filters.py` Handles flattening and pre-processing images. +- `grains.py` Handles grain segmentation. +- `grainstats.py` Calculates statistics for each grain. +- `disordered_tracing.py` Initial tracing of molecules. +- `nodestats.py` Handles any crossings / branches in DNA grains. +- `ordered_tracing.py` Proper tracing of molecules. +- `splining.py` Smooths / splines the traces for a more representative trace. +- `plotting.py` Handles plotting of the data. The current working plan is to move to a more modular architecture with new (and existing) functionality being grouped by theme within files. We expect to add such files as: diff --git a/docs/configuration.md b/docs/configuration.md index 2bbd17c18f..877f2331a8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -54,59 +54,73 @@ above: Aside from the comments in YAML file itself the fields are described below. -| Section | Sub-Section | Data Type | Default | Description | -| :-------------- | :-------------------------------- | :------------- | :-------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `base_dir` | | string | `./` | Directory to recursively search for files within.[^1] | -| `output_dir` | | string | `./output` | Directory that output should be saved to.[^1] | -| `log_level` | | string | `info` | Verbosity of logging, options are (in increasing order) `warning`, `error`, `info`, `debug`. | -| `cores` | | integer | `2` | Number of cores to run parallel processes on. | -| `file_ext` | | string | `.spm` | File extensions to search for. | -| `loading` | `channel` | string | `Height` | The channel of data to be processed, what this is will depend on the file-format you are processing and the channel you wish to process. | -| `filter` | `run` | boolean | `true` | Whether to run the filtering stage, without this other stages won't run so leave as `true`. | -| | `threshold_method` | str | `std_dev` | Threshold method for filtering, options are `ostu`, `std_dev` or `absolute`. | -| | `otsu_threshold_multiplier` | float | `1.0` | Factor by which the derived Otsu Threshold should be scaled. | -| | `threshold_std_dev` | dictionary | `10.0, 1.0` | A pair of values that scale the standard deviation, after scaling the standard deviation `below` is subtracted from the image mean to give the below/lower threshold and the `above` is added to the image mean to give the above/upper threshold. These values should _always_ be positive. | -| | `threshold_absolute` | dictionary | `-1.0, 1.0` | Below (first) and above (second) absolute threshold for separating data from the image background. | -| | `gaussian_size` | float | `0.5` | The number of standard deviations to build the Gaussian kernel and thus affects the degree of blurring. See [skimage.filters.gaussian](https://scikit-image.org/docs/dev/api/skimage.filters.html#skimage.filters.gaussian) and `sigma` for more information. | -| | `gaussian_mode` | string | `nearest` | | -| `grains` | `run` | boolean | `true` | Whether to run grain finding. Options `true`, `false` | -| | `row_alignment_quantile` | float | `0.5` | Quantile (0.0 to 1.0) to be used to determine the average background for the image. below values may improve flattening of large features. | -| | `smallest_grain_size_nm2` | int | `100` | The smallest size of grains to be included (in nm^2), anything smaller than this is considered noise and removed. **NB** must be `> 0.0`. | -| | `threshold_method` | float | `std_dev` | Threshold method for grain finding. Options : `otsu`, `std_dev`, `absolute` | -| | `otsu_threshold_multiplier` | | `1.0` | Factor by which the derived Otsu Threshold should be scaled. | -| | `threshold_std_dev` | dictionary | `10.0, 1.0` | A pair of values that scale the standard deviation, after scaling the standard deviation `below` is subtracted from the image mean to give the below/lower threshold and the `above` is added to the image mean to give the above/upper threshold. These values should _always_ be positive. | -| | `threshold_absolute` | dictionary | `-1.0, 1.0` | Below (first), above (second) absolute threshold for separating grains from the image background. | -| | `direction` | | `above` | Defines whether to look for grains above or below thresholds or both. Options: `above`, `below`, `both` | -| | `smallest_grain_size` | int | `50` | Catch-all value for the minimum size of grains. Measured in nanometres squared. All grains with area below than this value are removed. | -| | `absolute_area_threshold` | dictionary | `[300, 3000], [null, null]` | Area thresholds for above the image background (first) and below the image background (second), which grain sizes are permitted, measured in nanometres squared. All grains outside this area range are removed. | -| | `remove_edge_intersecting_grains` | boolean | `true` | Whether to remove grains that intersect the image border. _Do not change this unless you know what you are doing_. This will ruin any statistics relating to grain size, shape and DNA traces. | -| | `unet_config` | dictionary | `null, 0, 5.0, -1.0` | Configuration for loading a specified U-Net model to override traditional segmentation. Supply a path to a tensorflow U-Net model to use, else U-Net segmentation will be skipped. Values are `path_to_model, grain_crop_padding, upper_norm_bound, lower_norm_bound`. | -| `grainstats` | `run` | boolean | `true` | Whether to calculate grain statistics. Options : `true`, `false` | -| | `cropped_size` | float | `40.0` | Force cropping of grains to this length (in nm) of square cropped images (can take `-1` for grain-sized box) | -| | `edge_detection_method` | str | `binary_erosion` | Type of edge detection method to use when determining the edges of grain masks before calculating statistics on them. Options : `binary_erosion`, `canny`. | -| `dnatracing` | `run` | boolean | `true` | Whether to run DNA Tracing. Options : true, false | -| | `min_skeleton_size` | int | `10` | The minimum number of pixels a skeleton should be for statistics to be calculated on it. Anything smaller than this is dropped but grain statistics are retained. | -| | `skeletonisation_method` | str | `topostats` | Skeletonisation method to use, possible options are `zhang`, `lee`, `thin` (from [Scikit-image Morphology module](https://scikit-image.org/docs/stable/api/skimage.morphology.html)) or the original bespoke TopoStas method `topostats`. | -| | `spline_step_size` | float | `7.0e-9` | The sampling rate of the spline in metres. This is the frequency at which points are sampled from fitted traces to act as guide points for the splining process using scipy's splprep. | -| | `spline_linear_smoothing` | float | `5.0` | The amount of smoothing to apply to splines of linear molecule traces. | -| | `spline_circular_smoothing` | float | `0.0` | The amount of smoothing to apply to splines of circular molecule traces. | -| | `pad_width` | int | 10 | Padding for individual grains when tracing. This is sometimes required if the bounding box around grains is too tight and they touch the edge of the image. | -| | `cores` | int | 1 | Number of cores to use for tracing. **NB** Currently this is NOT used and should be left commented in the YAML file. | -| `plotting` | `run` | boolean | `true` | Whether to run plotting. Options : `true`, `false` | -| | `style` | str | `topostats.mplstyle` | The default loads a custom [matplotlibrc param file](https://matplotlib.org/stable/users/explain/customizing.html#the-matplotlibrc-file) that comes with TopoStats. Users can specify the path to their own style file as an alternative. | -| | `save_format` | string | `null` | Format to save images in, `null` defaults to `png` see [matplotlib.pyplot.savefig](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html) | -| | `savefig_dpi` | string / float | `null` | Dots Per Inch (DPI), if `null` then the value `figure` is used, for other values (typically integers) see [#further-customisation] and [Matplotlib](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html). Low DPI's improve processing time but can reduce the plotted trace (but not the actual trace) accuracy. | -| | `pixel_interpolation` | string | `null` | Interpolation method for image plots. Recommended default 'null' prevents banding that occurs in some images. If interpolation is needed, we recommend `gaussian`. See [matplotlib imshow interpolations documentation](https://matplotlib.org/stable/gallery/images_contours_and_fields/interpolation_methods.html) for details. | -| | `image_set` | string | `all` | Which images to plot. Options : `all`, `core` (flattened image, grain mask overlay and trace overlay only). | -| | `zrange` | list | `[0, 3]` | Low (first number) and high (second number) height range for core images (can take [null, null]). **NB** `low <= high` otherwise you will see a `ValueError: minvalue must be less than or equal to maxvalue` error. | -| | `colorbar` | boolean | `true` | Whether to include the colorbar scale in plots. Options `true`, `false` | -| | `axes` | boolean | `true` | Whether to include the axes in the produced plots. | -| | `num_ticks` | null / int | `null` | Number of ticks to have along the x and y axes. Options : `null` (auto) or an integer >1 | -| | `cmap` | string | `null` | Colormap/colourmap to use (defaults to 'nanoscope' if null (defined in `topostats/topostats.mplstyle`). Other options are 'afmhot', 'viridis' etc., see [Matplotlib : Choosing Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html). | -| | `mask_cmap` | string | `blu` | Color used when masking regions. Options `blu`, `jet_r` or any valid Matplotlib colour. | -| | `histogram_log_axis` | boolean | `false` | Whether to plot hisograms using a logarithmic scale or not. Options: `true`, `false`. | -| `summary_stats` | `run` | boolean | `true` | Whether to generate summary statistical plots of the distribution of different metrics grouped by the image that has been processed. | -| | `config` | str | `null` | Path to a summary config YAML file that configures/controls how plotting is done. If one is not specified either the command line argument `--summary_config` value will be used or if that option is not invoked the default `topostats/summary_config.yaml` will be used. | +| Section | Sub-Section | Data Type | Default | Description | +| :------------------- | :-------------------------------- | :------------- | :-------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `base_dir` | | string | `./` | Directory to recursively search for files within.[^1] | +| `output_dir` | | string | `./output` | Directory that output should be saved to.[^1] | +| `log_level` | | string | `info` | Verbosity of logging, options are (in increasing order) `warning`, `error`, `info`, `debug`. | +| `cores` | | integer | `2` | Number of cores to run parallel processes on. | +| `file_ext` | | string | `.spm` | File extensions to search for. | +| `loading` | `channel` | string | `Height` | The channel of data to be processed, what this is will depend on the file-format you are processing and the channel you wish to process. | +| `filter` | `run` | boolean | `true` | Whether to run the filtering stage, without this other stages won't run so leave as `true`. | +| | `threshold_method` | str | `std_dev` | Threshold method for filtering, options are `ostu`, `std_dev` or `absolute`. | +| | `otsu_threshold_multiplier` | float | `1.0` | Factor by which the derived Otsu Threshold should be scaled. | +| | `threshold_std_dev` | dictionary | `10.0, 1.0` | A pair of values that scale the standard deviation, after scaling the standard deviation `below` is subtracted from the image mean to give the below/lower threshold and the `above` is added to the image mean to give the above/upper threshold. These values should _always_ be positive. | +| | `threshold_absolute` | dictionary | `-1.0, 1.0` | Below (first) and above (second) absolute threshold for separating data from the image background. | +| | `gaussian_size` | float | `0.5` | The number of standard deviations to build the Gaussian kernel and thus affects the degree of blurring. See [skimage.filters.gaussian](https://scikit-image.org/docs/dev/api/skimage.filters.html#skimage.filters.gaussian) and `sigma` for more information. | +| | `gaussian_mode` | string | `nearest` | | +| `grains` | `run` | boolean | `true` | Whether to run grain finding. Options `true`, `false` | +| | `row_alignment_quantile` | float | `0.5` | Quantile (0.0 to 1.0) to be used to determine the average background for the image. below values may improve flattening of large features. | +| | `smallest_grain_size_nm2` | int | `100` | The smallest size of grains to be included (in nm^2), anything smaller than this is considered noise and removed. **NB** must be `> 0.0`. | +| | `threshold_method` | float | `std_dev` | Threshold method for grain finding. Options : `otsu`, `std_dev`, `absolute` | +| | `otsu_threshold_multiplier` | | `1.0` | Factor by which the derived Otsu Threshold should be scaled. | +| | `threshold_std_dev` | dictionary | `10.0, 1.0` | A pair of values that scale the standard deviation, after scaling the standard deviation `below` is subtracted from the image mean to give the below/lower threshold and the `above` is added to the image mean to give the above/upper threshold. These values should _always_ be positive. | +| | `threshold_absolute` | dictionary | `-1.0, 1.0` | Below (first), above (second) absolute threshold for separating grains from the image background. | +| | `direction` | | `above` | Defines whether to look for grains above or below thresholds or both. Options: `above`, `below`, `both` | +| | `smallest_grain_size` | int | `50` | Catch-all value for the minimum size of grains. Measured in nanometres squared. All grains with area below than this value are removed. | +| | `absolute_area_threshold` | dictionary | `[300, 3000], [null, null]` | Area thresholds for above the image background (first) and below the image background (second), which grain sizes are permitted, measured in nanometres squared. All grains outside this area range are removed. | +| | `remove_edge_intersecting_grains` | boolean | `true` | Whether to remove grains that intersect the image border. _Do not change this unless you know what you are doing_. This will ruin any statistics relating to grain size, shape and DNA traces. | +| | `unet_config` | dictionary | `null, 0, 5.0, -1.0` | Configuration for loading a specified U-Net model to override traditional segmentation. Supply a path to a tensorflow U-Net model to use, else U-Net segmentation will be skipped. Values are `path_to_model, grain_crop_padding, upper_norm_bound, lower_norm_bound`. | +| `grainstats` | `run` | boolean | `true` | Whether to calculate grain statistics. Options : `true`, `false` | +| | `cropped_size` | float | `40.0` | Force cropping of grains to this length (in nm) of square cropped images (can take `-1` for grain-sized box) | +| | `edge_detection_method` | str | `binary_erosion` | Type of edge detection method to use when determining the edges of grain masks before calculating statistics on them. Options : `binary_erosion`, `canny`. | +| `disordered_tracing` | `run` | boolean | `true` | Whether to run the Disordered Traces pipeline. Options : true, false | +| | `min_skeleton_size` | int | `10` | The minimum number of pixels a skeleton should be for statistics to be calculated on it. Anything smaller than this is dropped but grain statistics are retained. | +| | `pad_width` | str | `1` | Padding for individual grains when tracing. This is sometimes required if the bounding box around grains is too tight and they touch the edge of the image. | +| | `mask_smoothing_params` | dictionary | `gaussian_sigma:2`
`dilation_iterations:2`
`holearea_min_max:[0, null]` | Parameters to smooth the grain mask for producing skeletons more akin to the underlying structure. First, the amount of smoothing by a gaussian kernel, then the number of dilations to perform to smooth - these compete to see which changes the grain mask least, ensuring quality over different scan sizes. Then, as smoothing fill holes in the mask, the last parameter replaces those within a size range (in nm^2). | +| | `skeletonisation_params` | dictionary | `method:topostats`
`height_bias:0.6` | Parameters to skeletonise grain mask. First, the Skeletonisation method to use, possible options are `zhang`, `lee`, `thin` (from [Scikit-image Morphology module](https://scikit-image.org/docs/stable/api/skimage.morphology.html)) or the original bespoke TopoStats method `topostats`. The height biasing percentage for the `topostats` method. | +| | `pruning_params` | dictionary | `method:topostats`
`max_length:-1`
`height_threshold:null`
`method_values:mid`
`method_outliers:mean_abs` | Parameters to prune unwanted branches from the skeleton. First, the pruning method to use, possible options are `topostats`. The length (in nm) below which to prune branches (default is `-1` meaning 15% of the total length). The height threshold (in nm) option allows pruning of branches below a specified height. The method values determines how the branches height is calculated (options: `min`, `median` and `mid` for middle). Alternatively, the branch height outliers can be removed based on the inter-quartile range `iqr`, an absolute value `abs` or the mean of all branches minus an absolute value `mean_abs`. | +| `nodestats` | `run` | boolean | `true` | Whether to quantify the crossings in an image. Required for over/under tracing through crossings. Options : true, false | +| | `node_joining_length` | float | `7.0` | The distance (nm) over which to join nearby crossing points as the skeletonisation will not always force crossing points to connect. | +| | `node_extend_list` | float | `14.0` | The distance (nm) over which to join nearby odd-branched nodes. | +| | `branch_pairing_length` | float | `nodestats` | The length (nm) from the crossing point to pair the emainating branches and trace along to obtain the over/under distinguishing full-width half-maximum (FWHM's) values. | +| | `pair_odd_branches` | boolean | `true` | Whether to try and pair branches at odd-branch crossing regions and leave one hanging branch, or to leave all branches hanging here. Options: `true` or `false`. | +| | `pad_width` | str | `1` | Padding for individual grains when tracing. This is sometimes required if the bounding box around grains is too tight and they touch the edge of the image. | +| `ordered_tracing` | `run` | boolean | `true` | Whether to order the pruned skeletons of Disordered Traces. Options : true, false | +| | `ordering_method` | str | `nodestats` | The method of ordering the disordered traces either using the nodestats output or solely the disordered traces. Options: `nodestats` or `topostats`. | +| | `pad_width` | int | 10 | Padding for individual grains when tracing. This is sometimes required if the bounding box around grains is too tight and they touch the edge of the image. | +| `splining` | `run` | boolean | `true` | Whether to run ordered trace splining to generate smooth traces. Options : true, false | +| | `method` | int | `rolling_window` | The method used to smooth out the ordered traces. Options: `rolling_window` or `spline`. | +| | `rolling_window_size` | int | `20.0e-9` | The length (in meters) of the coordinate averaging window to smooth the ordered trace. | +| | `spline_step_size` | int | `7.0e-9` | The The sampling length of the spline (in meters) to obtain an average of splines. | +| | `spline_linear_smoothing` | int | `5.0` | The amount of smoothing to apply to linear molecule splines. | +| | `spline_circular_smoothing` | int | `5.0` | The amount of smoothing to apply to circular molecule splines. | +| | `spline_degree` | int | `3` | The polynomial degree of the spline. Smaller, odd degrees work best [SciPy - slprep](https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.splrep.html). | +| `plotting` | `run` | boolean | `true` | Whether to run plotting. Options : `true`, `false` | +| | `style` | str | `topostats.mplstyle` | The default loads a custom [matplotlibrc param file](https://matplotlib.org/stable/users/explain/customizing.html#the-matplotlibrc-file) that comes with TopoStats. Users can specify the path to their own style file as an alternative. | +| | `save_format` | string | `null` | Format to save images in, `null` defaults to `png` see [matplotlib.pyplot.savefig](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html) | +| | `savefig_dpi` | string / float | `null` | Dots Per Inch (DPI), if `null` then the value `figure` is used, for other values (typically integers) see [#further-customisation] and [Matplotlib](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html). Low DPI's improve processing time but can reduce the plotted trace (but not the actual trace) accuracy. | +| | `pixel_interpolation` | string | `null` | Interpolation method for image plots. Recommended default 'null' prevents banding that occurs in some images. If interpolation is needed, we recommend `gaussian`. See [matplotlib imshow interpolations documentation](https://matplotlib.org/stable/gallery/images_contours_and_fields/interpolation_methods.html) for details. | +| | `image_set` | string | `all` | Which images to plot. Options : `all`, `core` (flattened image, grain mask overlay and trace overlay only). | +| | `zrange` | list | `[0, 3]` | Low (first number) and high (second number) height range for core images (can take [null, null]). **NB** `low <= high` otherwise you will see a `ValueError: minvalue must be less than or equal to maxvalue` error. | +| | `colorbar` | boolean | `true` | Whether to include the colorbar scale in plots. Options `true`, `false` | +| | `axes` | boolean | `true` | Whether to include the axes in the produced plots. | +| | `num_ticks` | null / int | `null` | Number of ticks to have along the x and y axes. Options : `null` (auto) or an integer >1 | +| | `cmap` | string | `null` | Colormap/colourmap to use. Defaults to 'nanoscope' if null (defined in `topostats/topostats.mplstyle`). Other options are 'afmhot', 'viridis' etc., see [Matplotlib : Choosing Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html). | +| | `mask_cmap` | string | `blu` | Color used when masking regions. Options `blu`, `jet_r` or any valid Matplotlib colour. | +| | `histogram_log_axis` | boolean | `false` | Whether to plot hisograms using a logarithmic scale or not. Options: `true`, `false`. | +| `summary_stats` | `run` | boolean | `true` | Whether to generate summary statistical plots of the distribution of different metrics grouped by the image that has been processed. | +| | `config` | str | `null` | Path to a summary config YAML file that configures/controls how plotting is done. If one is not specified either the command line argument `--summary_config` value will be used or if that option is not invoked the default `topostats/summary_config.yaml` will be used. | ## Summary Configuration diff --git a/notebooks/00-Walkthrough-minicircle.ipynb b/notebooks/OUTDATED-00-Walkthrough-minicircle.ipynb similarity index 99% rename from notebooks/00-Walkthrough-minicircle.ipynb rename to notebooks/OUTDATED-00-Walkthrough-minicircle.ipynb index 34cab85cc6..8c2f1d8521 100644 --- a/notebooks/00-Walkthrough-minicircle.ipynb +++ b/notebooks/OUTDATED-00-Walkthrough-minicircle.ipynb @@ -870,20 +870,8 @@ ], "metadata": { "kernelspec": { - "argv": [ - "python", - "-m", - "ipykernel_launcher", - "-f", - "{connection_file}" - ], - "display_name": "Python 3 (ipykernel)", - "env": null, - "interrupt_mode": "signal", + "display_name": "new", "language": "python", - "metadata": { - "debugger": true - }, "name": "python3" }, "language_info": { diff --git a/notebooks/02-Summary-statistics-and-plots.ipynb b/notebooks/OUTDATED-02-Summary-statistics-and-plots.ipynb similarity index 100% rename from notebooks/02-Summary-statistics-and-plots.ipynb rename to notebooks/OUTDATED-02-Summary-statistics-and-plots.ipynb diff --git a/notebooks/03-Plotting-scans.ipynb b/notebooks/OUTDATED-03-Plotting-scans.ipynb similarity index 100% rename from notebooks/03-Plotting-scans.ipynb rename to notebooks/OUTDATED-03-Plotting-scans.ipynb diff --git a/pyproject.toml b/pyproject.toml index c1f0bacf98..76243dc889 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ keywords = [ "afm", "image processing" ] -requires-python = ">=3.9" +requires-python = ">=3.9, <3.12" dependencies = [ "AFMReader", "h5py", @@ -44,18 +44,20 @@ dependencies = [ "numpy", "numpyencoder", "pandas", - "pySPM", "pyfiglet", + "pySPM", "pyyaml", "ruamel.yaml", "schema", "scikit-image", "scipy", "seaborn", + "skan", "snoop", + "tensorflow", "tifffile", + "topoly==1.0.2", "tqdm", - "tensorflow", ] [project.optional-dependencies] @@ -127,8 +129,10 @@ write_to = "topostats/_version.py" [tool.pytest.ini_options] minversion = "7.0" -addopts = ["--mpl", "-ra", "--strict-config", "--strict-markers"] -log_cli_level = "Info" +addopts = ["--cov", "--mpl", "-ra", "--strict-config", "--strict-markers"] +log_level = "INFO" +log_cli = true +log_cli_level = "INFO" testpaths = [ "tests", ] @@ -185,13 +189,8 @@ exclude = [ "dist", "docs/conf.py", "node_modules", - "pygwytracing.py", "tests/tracing/test_dnacurvature.py", - "tests/tracing/test_dnatracing.py", - "tests/tracing/test_tracing_dna.py", "topostats/plotting.py", - "topostats/tracing/dnatracing.py", - "topostats/tracing/tracing_dna.py", "topostats/tracing/tracingfuncs.py", "venv", ] @@ -258,7 +257,7 @@ convention = "numpy" fixture-parentheses = true [tool.codespell] -skip = '*.spm*,*.mplstyle' +skip = '*.spm*,*.mplstyle,*.svg' count = '' quiet-level = 3 @@ -275,9 +274,9 @@ exclude = [ # don't report on objects that match any of these regex "\\.__repr__$", "^test_", "^conftest", - "^dnatrcing", "^tracingfuncs", "^conf$", + "^theme", ] override_SS05 = [ # override SS05 to allow docstrings starting with these words "^Process ", diff --git a/temp.ipynb b/temp.ipynb new file mode 100644 index 0000000000..06e133e934 --- /dev/null +++ b/temp.ipynb @@ -0,0 +1,103 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from scipy.spatial.distance import cdist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# point arrays\n", + "array1 = np.array([[1, 2], [2, 3], [3, 3]])\n", + "array2 = np.array([[4, 4], [4, 5], [5, 5], [6, 6]])\n", + "\n", + "img = np.zeros((10, 10))\n", + "# set all array1 points to 1\n", + "for point in array1:\n", + " print(point)\n", + " img[point[0], point[1]] = 1\n", + "\n", + "print(\"----\")\n", + "\n", + "# set all array2 points to 2\n", + "for point in array2:\n", + " print(point)\n", + " img[point[0], point[1]] = 2\n", + "\n", + "plt.imshow(img)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Numpy approach\n", + "\n", + "# How to create this matrix from array1 and array2?\n", + "# [1 4] [1 4] [1 5] [1 6]\n", + "# [2 4] [2 4] [2 5] [2 6]\n", + "# [3 4] [3 4] [3 5] [3 6]\n", + "\n", + "grid1, grid2 = np.meshgrid(array1, array2)\n", + "print(grid1)\n", + "print(grid2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test if any points are touching between the two arrays, using vectorisation\n", + "\n", + "diffs = np.subtract.outer(array1, array2)\n", + "print(diffs)\n", + "print(\"---\")\n", + "# linalg the diffs\n", + "dists = np.linalg.norm(diffs, axis=1)\n", + "print(dists)\n", + "\n", + "print(\"---\")\n", + "\n", + "\n", + "cdists = cdist(array1, array2, \"euclidean\")\n", + "print(cdists)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "new", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/_regtest_outputs/test_grainstats_minicircle.test_grainstats_regression.out b/tests/_regtest_outputs/test_grainstats_minicircle.test_grainstats_regression.out index cb0bb469cc..e701c344c9 100644 --- a/tests/_regtest_outputs/test_grainstats_minicircle.test_grainstats_regression.out +++ b/tests/_regtest_outputs/test_grainstats_minicircle.test_grainstats_regression.out @@ -1,5 +1,5 @@ - centre_x centre_y radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image -molecule_number -0 7.500895e-08 4.775362e-08 4.044948e-09 2.542964e-08 1.600048e-08 1.647888e-08 1.006306e-09 2.696640e-09 1.597337e-09 1.596062e-09 1.084371e-24 6.794043e-16 1.319762e-15 2.053897e-08 5.037861e-08 1.034725e-15 0.407692 above 5.037861e-08 2.053897e-08 None -1 8.028649e-08 7.895237e-08 6.658610e-09 2.623645e-08 1.609777e-08 1.573693e-08 1.009513e-09 2.480685e-09 1.668932e-09 1.683197e-09 1.031843e-24 6.130257e-16 1.526707e-15 2.017441e-08 4.942135e-08 9.970467e-16 0.408213 above 4.987228e-08 2.017441e-08 None -2 4.001424e-08 7.585689e-08 9.745734e-09 2.369314e-08 1.746942e-08 1.818722e-08 1.006681e-09 2.139365e-09 1.646541e-09 1.598842e-09 1.123717e-24 7.028320e-16 1.546230e-15 3.359220e-08 4.149625e-08 1.393950e-15 0.809524 above 4.440534e-08 3.176819e-08 None + centre_x centre_y radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image +grain_number +0 7.500895e-08 4.775362e-08 4.044948e-09 2.542964e-08 1.600048e-08 1.647888e-08 1.006306e-09 2.696640e-09 1.597337e-09 1.596062e-09 1.084371e-24 6.794043e-16 1.319762e-15 2.053897e-08 5.037861e-08 1.034725e-15 0.407692 above 5.037861e-08 2.053897e-08 None +1 8.028649e-08 7.895237e-08 6.658610e-09 2.623645e-08 1.609777e-08 1.573693e-08 1.009513e-09 2.480685e-09 1.668932e-09 1.683197e-09 1.031843e-24 6.130257e-16 1.526707e-15 2.017441e-08 4.942135e-08 9.970467e-16 0.408213 above 4.987228e-08 2.017441e-08 None +2 4.001424e-08 7.585689e-08 9.745734e-09 2.369314e-08 1.746942e-08 1.818722e-08 1.006681e-09 2.139365e-09 1.646541e-09 1.598842e-09 1.123717e-24 7.028320e-16 1.546230e-15 3.359220e-08 4.149625e-08 1.393950e-15 0.809524 above 4.440534e-08 3.176819e-08 None diff --git a/tests/_regtest_outputs/test_processing.test_process_scan_above.out b/tests/_regtest_outputs/test_processing.test_process_scan_above.out index a476a64997..2bf26e804d 100644 --- a/tests/_regtest_outputs/test_processing.test_process_scan_above.out +++ b/tests/_regtest_outputs/test_processing.test_process_scan_above.out @@ -1,8 +1,7 @@ image_size_x_m image_size_y_m image_area_m2 image_size_x_px image_size_y_px image_area_px2 grains_number_above grains_per_m2_above grains_number_below grains_per_m2_below rms_roughness image minicircle_small 1.2646e-07 1.2646e-07 1.5993e-14 64 64 4096 3 1.8758e+14 0 0.0000e+00 6.8208e-10 - centre_x centre_y radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image contour_length circular end_to_end_distance -molecule_number -0 7.5100e-08 4.7559e-08 3.9431e-09 2.5631e-08 1.6016e-08 1.6680e-08 9.1991e-10 2.6422e-09 1.5338e-09 1.5341e-09 1.0543e-24 6.8721e-16 1.3198e-15 2.0539e-08 5.0379e-08 1.0347e-15 4.0769e-01 above 5.0379e-08 2.0539e-08 minicircle_small 6.0226e-08 False 8.6738e-09 -1 8.0241e-08 7.8677e-08 6.8951e-09 2.7188e-08 1.6272e-08 1.6263e-08 9.0630e-10 2.4586e-09 1.6144e-09 1.6264e-09 1.0352e-24 6.3645e-16 1.5931e-15 2.0174e-08 5.1212e-08 1.0332e-15 3.9394e-01 above 5.1262e-08 2.0174e-08 minicircle_small 6.6355e-08 True 0.0000e+00 -2 4.0012e-08 7.5644e-08 9.9461e-09 2.3654e-08 1.7561e-08 1.8364e-08 9.0641e-10 2.1066e-09 1.5939e-09 1.5493e-09 1.1192e-24 7.2236e-16 1.5462e-15 3.3592e-08 4.1496e-08 1.3940e-15 8.0952e-01 above 4.4405e-08 3.2528e-08 minicircle_small 9.6106e-08 True 0.0000e+00 + centre_x centre_y grain_number radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image grain_endpoints grain_junctions total_branch_lengths num_crossings avg_crossing_confidence min_crossing_confidence num_mols writhe_string total_contour_length average_end_to_end_distance +0 7.5100e-08 4.7559e-08 0 3.9431e-09 2.5631e-08 1.6016e-08 1.6680e-08 9.1991e-10 2.6422e-09 1.5338e-09 1.5341e-09 1.0543e-24 6.8721e-16 1.3198e-15 2.0539e-08 5.0379e-08 1.0347e-15 4.0769e-01 above 5.0379e-08 2.0539e-08 minicircle_small 1 1 8.4571e-08 1 None None 2 6.5881e-08 8.8370e-09 +1 8.0241e-08 7.8677e-08 1 6.8951e-09 2.7188e-08 1.6272e-08 1.6263e-08 9.0630e-10 2.4586e-09 1.6144e-09 1.6264e-09 1.0352e-24 6.3645e-16 1.5931e-15 2.0174e-08 5.1212e-08 1.0332e-15 3.9394e-01 above 5.1262e-08 2.0174e-08 minicircle_small 0 0 7.3054e-08 0 None None 1 NaN 5.8272e-08 0.0000e+00 +2 4.0012e-08 7.5644e-08 2 9.9461e-09 2.3654e-08 1.7561e-08 1.8364e-08 9.0641e-10 2.1066e-09 1.5939e-09 1.5493e-09 1.1192e-24 7.2236e-16 1.5462e-15 3.3592e-08 4.1496e-08 1.3940e-15 8.0952e-01 above 4.4405e-08 3.2528e-08 minicircle_small 0 0 1.0447e-07 0 None None 1 NaN 8.7183e-08 0.0000e+00 diff --git a/tests/_regtest_outputs/test_processing.test_process_scan_below.out b/tests/_regtest_outputs/test_processing.test_process_scan_below.out index 7addff3d3b..66cc02ac81 100644 --- a/tests/_regtest_outputs/test_processing.test_process_scan_below.out +++ b/tests/_regtest_outputs/test_processing.test_process_scan_below.out @@ -1,6 +1,5 @@ image_size_x_m image_size_y_m image_area_m2 image_size_x_px image_size_y_px image_area_px2 grains_number_above grains_per_m2_above grains_number_below grains_per_m2_below rms_roughness image minicircle_small 1.2646e-07 1.2646e-07 1.5993e-14 64 64 4096 0 0.0000e+00 1 6.2526e+13 6.8208e-10 - centre_x centre_y radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image contour_length circular end_to_end_distance -molecule_number -0 3.2366e-08 1.4036e-08 7.7690e-10 1.2272e-08 6.4301e-09 6.4170e-09 -3.7937e-10 -2.1207e-10 -2.4477e-10 -2.6816e-10 -3.0364e-26 1.1323e-16 3.0066e-16 7.0841e-09 2.1505e-08 1.5234e-16 3.2941e-01 below 2.2092e-08 7.0841e-09 minicircle_small NaN NaN NaN + centre_x centre_y grain_number radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image grain_endpoints grain_junctions total_branch_lengths num_crossings avg_crossing_confidence min_crossing_confidence num_mols total_contour_length average_end_to_end_distance +0 3.2366e-08 1.4036e-08 0 7.7690e-10 1.2272e-08 6.4301e-09 6.4170e-09 -3.7937e-10 -2.1207e-10 -2.4477e-10 -2.6816e-10 -3.0364e-26 1.1323e-16 3.0066e-16 7.0841e-09 2.1505e-08 1.5234e-16 3.2941e-01 below 2.2092e-08 7.0841e-09 minicircle_small 2 0 1.3493e-08 0 None None 1 1.0799e-08 1.0076e-08 diff --git a/tests/_regtest_outputs/test_processing.test_process_scan_both.out b/tests/_regtest_outputs/test_processing.test_process_scan_both.out index 6b929fd2ad..65fe7a368e 100644 --- a/tests/_regtest_outputs/test_processing.test_process_scan_both.out +++ b/tests/_regtest_outputs/test_processing.test_process_scan_both.out @@ -1,9 +1,8 @@ image_size_x_m image_size_y_m image_area_m2 image_size_x_px image_size_y_px image_area_px2 grains_number_above grains_per_m2_above grains_number_below grains_per_m2_below rms_roughness image minicircle_small 1.2646e-07 1.2646e-07 1.5993e-14 64 64 4096 3 1.8758e+14 1 6.2526e+13 6.8208e-10 - centre_x centre_y radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image contour_length circular end_to_end_distance -molecule_number -0 3.2366e-08 1.4036e-08 7.7690e-10 1.2272e-08 6.4301e-09 6.4170e-09 -3.7937e-10 -2.1207e-10 -2.4477e-10 -2.6816e-10 -3.0364e-26 1.1323e-16 3.0066e-16 7.0841e-09 2.1505e-08 1.5234e-16 3.2941e-01 below 2.2092e-08 7.0841e-09 minicircle_small NaN NaN NaN -0 7.5100e-08 4.7559e-08 3.9431e-09 2.5631e-08 1.6016e-08 1.6680e-08 9.1991e-10 2.6422e-09 1.5338e-09 1.5341e-09 1.0543e-24 6.8721e-16 1.3198e-15 2.0539e-08 5.0379e-08 1.0347e-15 4.0769e-01 above 5.0379e-08 2.0539e-08 minicircle_small 6.0226e-08 0.0000e+00 8.6738e-09 -1 8.0241e-08 7.8677e-08 6.8951e-09 2.7188e-08 1.6272e-08 1.6263e-08 9.0630e-10 2.4586e-09 1.6144e-09 1.6264e-09 1.0352e-24 6.3645e-16 1.5931e-15 2.0174e-08 5.1212e-08 1.0332e-15 3.9394e-01 above 5.1262e-08 2.0174e-08 minicircle_small 6.6355e-08 1.0000e+00 0.0000e+00 -2 4.0012e-08 7.5644e-08 9.9461e-09 2.3654e-08 1.7561e-08 1.8364e-08 9.0641e-10 2.1066e-09 1.5939e-09 1.5493e-09 1.1192e-24 7.2236e-16 1.5462e-15 3.3592e-08 4.1496e-08 1.3940e-15 8.0952e-01 above 4.4405e-08 3.2528e-08 minicircle_small 9.6106e-08 1.0000e+00 0.0000e+00 + centre_x centre_y grain_number radius_min radius_max radius_mean radius_median height_min height_max height_median height_mean volume area area_cartesian_bbox smallest_bounding_width smallest_bounding_length smallest_bounding_area aspect_ratio threshold max_feret min_feret image grain_endpoints grain_junctions total_branch_lengths num_crossings avg_crossing_confidence min_crossing_confidence num_mols writhe_string total_contour_length average_end_to_end_distance +0 3.2366e-08 1.4036e-08 0 7.7690e-10 1.2272e-08 6.4301e-09 6.4170e-09 -3.7937e-10 -2.1207e-10 -2.4477e-10 -2.6816e-10 -3.0364e-26 1.1323e-16 3.0066e-16 7.0841e-09 2.1505e-08 1.5234e-16 3.2941e-01 below 2.2092e-08 7.0841e-09 minicircle_small 2 0 1.3493e-08 0 None None 1 NaN 1.0799e-08 1.0076e-08 +1 7.5100e-08 4.7559e-08 0 3.9431e-09 2.5631e-08 1.6016e-08 1.6680e-08 9.1991e-10 2.6422e-09 1.5338e-09 1.5341e-09 1.0543e-24 6.8721e-16 1.3198e-15 2.0539e-08 5.0379e-08 1.0347e-15 4.0769e-01 above 5.0379e-08 2.0539e-08 minicircle_small 1 1 8.4571e-08 1 None None 2 6.5881e-08 8.8370e-09 +2 8.0241e-08 7.8677e-08 1 6.8951e-09 2.7188e-08 1.6272e-08 1.6263e-08 9.0630e-10 2.4586e-09 1.6144e-09 1.6264e-09 1.0352e-24 6.3645e-16 1.5931e-15 2.0174e-08 5.1212e-08 1.0332e-15 3.9394e-01 above 5.1262e-08 2.0174e-08 minicircle_small 0 0 7.3054e-08 0 None None 1 NaN 5.8272e-08 0.0000e+00 +3 4.0012e-08 7.5644e-08 2 9.9461e-09 2.3654e-08 1.7561e-08 1.8364e-08 9.0641e-10 2.1066e-09 1.5939e-09 1.5493e-09 1.1192e-24 7.2236e-16 1.5462e-15 3.3592e-08 4.1496e-08 1.3940e-15 8.0952e-01 above 4.4405e-08 3.2528e-08 minicircle_small 0 0 1.0447e-07 0 None None 1 NaN 8.7183e-08 0.0000e+00 diff --git a/tests/conftest.py b/tests/conftest.py index a05604d582..30742b2afb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,6 @@ from topostats.grainstats import GrainStats from topostats.io import LoadScans, read_yaml from topostats.plotting import TopoSum -from topostats.tracing.dnatracing import dnaTrace from topostats.utils import _get_mask, get_mask, get_thresholds # This is required because of the inheritance used throughout @@ -141,14 +140,8 @@ def grainstats_config(default_config: dict) -> dict: """Configurations for grainstats.""" config = default_config["grainstats"] config["direction"] = "above" - config.pop("run") - return config - - -@pytest.fixture() -def dnatracing_config(default_config: dict) -> dict: - """Configurations for dnatracing.""" - config = default_config["dnatracing"] + # Set cropped image size to 40nm + config["cropped_size"] = 40.0 config.pop("run") return config @@ -732,46 +725,37 @@ def minicircle_grainstats( ) -# Derive fixtures for DNA Tracing -GRAINS = np.array( - [ - [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2], - [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 2], - [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 2], - [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], - [0, 0, 3, 3, 3, 3, 3, 0, 0, 0, 2], - [0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 2], - [0, 0, 3, 3, 3, 3, 3, 0, 0, 0, 2], - [0, 0, 4, 4, 4, 4, 4, 0, 0, 0, 2], - ] -) -FULL_IMAGE = RNG.random((GRAINS.shape[0], GRAINS.shape[1])) +# Random shapes +# Generate a random skeletons, first is a skeleton with a closed loop with side branches +kwargs = { + "image_shape": (60, 32), + "max_shapes": 10, + "channel_axis": None, + "shape": None, + "allow_overlap": True, + "min_size": 20, +} @pytest.fixture() -def test_dnatracing() -> dnaTrace: - """Instantiate a dnaTrace object.""" - return dnaTrace(image=FULL_IMAGE, grain=GRAINS, filename="Test", pixel_to_nm_scaling=1.0) +def utils_skeleton_linear1() -> npt.NDArray: + """Linear skeleton.""" + random_images, _ = draw.random_shapes(rng=1, **kwargs) + return skeletonize(random_images != 255) @pytest.fixture() -def minicircle_dnatracing( - minicircle_grain_gaussian_filter: Filters, - minicircle_grain_coloured: Grains, - dnatracing_config: dict, -) -> dnaTrace: - """DnaTrace object instantiated with minicircle data.""" # noqa: D403 - dnatracing_config.pop("pad_width") - dna_traces = dnaTrace( - image=minicircle_grain_coloured.image.T, - grain=minicircle_grain_coloured.directions["above"]["labelled_regions_02"], - filename=minicircle_grain_gaussian_filter.filename, - pixel_to_nm_scaling=minicircle_grain_gaussian_filter.pixel_to_nm_scaling, - **dnatracing_config, - ) - dna_traces.trace_dna() - return dna_traces +def utils_skeleton_linear2() -> npt.NDArray: + """Linear skeleton T-junction and side-branch.""" + random_images, _ = draw.random_shapes(rng=165103, **kwargs) + return skeletonize(random_images != 255) + + +@pytest.fixture() +def utils_skeleton_linear3() -> npt.NDArray: + """Linear skeleton with several branches.""" + random_images, _ = draw.random_shapes(rng=7334281, **kwargs) + return skeletonize(random_images != 255) # DNA Tracing Fixtures @@ -900,11 +884,13 @@ def _generate_random_skeleton(**extra_kwargs): "shape": None, "allow_overlap": True, } + # kwargs.update heights = {"scale": 100, "sigma": 5.0, "cval": 20.0} - random_image, _ = draw.random_shapes(**kwargs, **extra_kwargs) + kwargs = {**kwargs, **extra_kwargs} + random_image, _ = draw.random_shapes(**kwargs) mask = random_image != 255 skeleton = skeletonize(mask) - return {"img": _generate_heights(skeleton, **heights), "skeleton": skeleton} + return {"original": mask, "img": _generate_heights(skeleton, **heights), "skeleton": skeleton} @pytest.fixture() @@ -937,13 +923,21 @@ def skeleton_linear3() -> dict: return _generate_random_skeleton(rng=894632511, min_size=20) -# Helper functions for visualising skeletons and heights -# +@pytest.fixture() +def pruning_skeleton() -> dict: + """Smaller skeleton for testing parameters of prune_all_skeletons(). Has a T-junction.""" + return _generate_random_skeleton(rng=69432138, min_size=15, image_shape=(30, 30)) + + +## Helper function visualising for generating skeletons and heights + + +# import matplotlib.pyplot as plt # def pruned_plot(gen_shape: dict) -> None: # """Plot the original skeleton, its derived height and the pruned skeleton.""" -# img_skeleton = gen_shape() +# img_skeleton = gen_shape # pruned = topostatsPrune( -# img_skeleton["heights"], +# img_skeleton["img"], # img_skeleton["skeleton"], # max_length=-1, # height_threshold=90, @@ -952,17 +946,23 @@ def skeleton_linear3() -> dict: # ) # pruned_skeleton = pruned._prune_by_length(pruned.skeleton, pruned.max_length) # fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) -# ax1.imshow(img_skeleton["skeleton"]) -# ax2.imshow(img_skeleton["heights"]) -# ax3.imshow(pruned_skeleton) +# ax1.imshow(img_skeleton["original"]) +# ax1.set_title("Original mask") +# ax2.imshow(img_skeleton["skeleton"]) +# ax2.set_title("Skeleton") +# ax3.imshow(img_skeleton["img"]) +# ax3.set_title("Gaussian Blurring") +# ax4.imshow(pruned_skeleton) +# ax4.set_title("Pruned Skeleton") # plt.show() -# pruned_plot(skeleton_loop1) -# pruned_plot(skeleton_loop2) -# pruned_plot(skeleton_linear1) -# pruned_plot(skeleton_linear2) -# pruned_plot(skeleton_linear3) +# pruned_plot(pruning_skeleton_loop1()) +# pruned_plot(pruning_skeleton_loop2()) +# pruned_plot(pruning_skeleton_linear1()) +# pruned_plot(pruning_skeleton_linear2()) +# pruned_plot(pruning_skeleton_linear3()) +# pruned_plot(pruning_skeleton()) # U-Net fixtures diff --git a/tests/measure/conftest.py b/tests/measure/conftest.py index 2f812788c5..3dc9191e2a 100644 --- a/tests/measure/conftest.py +++ b/tests/measure/conftest.py @@ -1,15 +1,42 @@ -"""Fixtures for testing sub-modules of the measure module.""" +"""Fixtures for testing the measure module.""" from __future__ import annotations +import networkx as nx import numpy as np import numpy.typing as npt import pytest from skimage import draw +from skimage.morphology import label + +from topostats.tracing.nodestats import nodeStats # pylint: disable=redefined-outer-name +@pytest.fixture() +def network_array_representation_figure_8() -> npt.NDArray[np.int32]: + """Fixture for the network array representation of the figure 8 test molecule.""" + return np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 3, 1, 1, 1, 1, 1, 1, 1, 3, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + + @pytest.fixture() def tiny_circle() -> npt.NDArray: """Tiny circle.""" @@ -45,6 +72,61 @@ def tiny_quadrilateral() -> npt.NDArray: ) +@pytest.fixture() +def labelled_nodes_figure_8(network_array_representation_figure_8) -> npt.NDArray[np.int32]: + """Fixture for the labelled nodes of the figure 8 test molecule.""" + # Grab the nodes from the network array representation. + just_nodes = np.where(network_array_representation_figure_8 == 3, 1, 0) + # Adding the 1 and 0 in the np.where call above makes the nodes 1 and the rest 0. + return label(just_nodes) + + +@pytest.fixture() +def labelled_branches_figure_8(network_array_representation_figure_8) -> npt.NDArray[np.int32]: + """Fixture for the labelled branches of the figure 8 test molecule.""" + # Grab the branches from the network array representation. + just_branches = np.where(network_array_representation_figure_8 == 1, 1, 0) + # Adding the 1 and 0 in the np.where call above makes the branches 1 and the rest 0. + return label(just_branches) + + +# fixture for just the skeleton +@pytest.fixture() +def skeleton_figure_8(network_array_representation_figure_8) -> npt.NDArray[np.bool_]: + """Fixture for the skeleton of the figure 8 test molecule.""" + return network_array_representation_figure_8.astype(bool) + + +@pytest.fixture() +def whole_skeleton_graph_figure_8(skeleton_figure_8) -> nx.classes.graph.Graph: + """Fixture for the whole skeleton graph of the figure 8 test molecule.""" + # Create graph of just skeleton from the skeleton (just 1s) + return nodeStats.skeleton_image_to_graph(skeleton_figure_8) + + +@pytest.fixture() +def expected_network_array_representation_figure_8() -> npt.NDArray[np.int32]: + """Fixture for the expected network array representation of the figure 8 test molecule.""" + return np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + + @pytest.fixture() def tiny_square() -> npt.NDArray: """Tiny square.""" diff --git a/tests/measure/test_geometry.py b/tests/measure/test_geometry.py new file mode 100644 index 0000000000..6c880ea512 --- /dev/null +++ b/tests/measure/test_geometry.py @@ -0,0 +1,213 @@ +"""Tests for the geometry module.""" + +import networkx +import numpy as np +import numpy.typing as npt +import pytest + +# pylint: disable=too-many-arguments +from topostats.measure.geometry import ( + bounding_box_cartesian_points_float, + bounding_box_cartesian_points_integer, + calculate_shortest_branch_distances, + connect_best_matches, + do_points_in_arrays_touch, +) + + +def test_bounding_box_cartesian_points_float_raises_value_error() -> None: + """Test the bounding_box_cartesian_points function raises a ValueError.""" + with pytest.raises(ValueError, match="Input array must be Nx2."): + bounding_box_cartesian_points_float(np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])) + + +def test_bounding_box_cartesian_points_integer_raises_value_error() -> None: + """Test the bounding_box_cartesian_points function raises a ValueError.""" + with pytest.raises(ValueError, match="Input array must be Nx2."): + bounding_box_cartesian_points_float(np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])) + + +@pytest.mark.parametrize( + ("points", "expected_bbox"), + [ + pytest.param(np.array([[0, 0], [1, 1], [2, 2]]), (0, 0, 2, 2), id="diagonal line"), + pytest.param(np.array([[0, 0], [1, 1], [2, 0]]), (0, 0, 2, 1), id="triangle"), + pytest.param(np.array([[-5, -5], [-10, -10], [-3, -3]]), (-10, -10, -3, -3), id="negative values"), + pytest.param(np.array([[-1, -1], [1, 1], [2, 0], [-2, 0]]), (-2, -1, 2, 1), id="negative and positive values"), + pytest.param(np.array([[0.1, 0.1], [1.1, 1.1], [2.1, 2.1]]), (0.1, 0.1, 2.1, 2.1), id="diagonal line, floats"), + ], +) +def test_bounding_box_cartesian_points_float( + points: npt.NDArray[np.number], expected_bbox: tuple[np.float64, np.float64, np.float64, np.float64] +) -> None: + """Test the bounding_box_cartesian_points function.""" + assert bounding_box_cartesian_points_float(points) == expected_bbox + + +@pytest.mark.parametrize( + ("points", "expected_bbox"), + [ + pytest.param(np.array([[0, 0], [1, 1], [2, 2]]), (0, 0, 2, 2), id="diagonal line"), + pytest.param(np.array([[0, 0], [1, 1], [2, 0]]), (0, 0, 2, 1), id="triangle"), + pytest.param(np.array([[-5, -5], [-10, -10], [-3, -3]]), (-10, -10, -3, -3), id="negative values"), + pytest.param(np.array([[-1, -1], [1, 1], [2, 0], [-2, 0]]), (-2, -1, 2, 1), id="negative and positive values"), + pytest.param(np.array([[0.1, 0.1], [1.1, 1.1], [2.1, 2.1]]), (0, 0, 2, 2), id="diagonal line, floats"), + ], +) +def test_bounding_box_cartesian_points_integer( + points: npt.NDArray[np.number], expected_bbox: tuple[np.float64, np.float64, np.float64, np.float64] +) -> None: + """Test the bounding_box_cartesian_points function.""" + assert bounding_box_cartesian_points_integer(points) == expected_bbox + + +def test_do_points_in_arrays_touch_raises_value_error() -> None: + """Test the do_points_in_arrays_touch function raises a ValueError.""" + with pytest.raises(ValueError, match="Input arrays must be Nx2 and Mx2."): + do_points_in_arrays_touch(np.array([[0, 0], [1, 1], [2, 2]]), np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]])) + + +@pytest.mark.parametrize( + ("array_1", "array_2", "expected_result_touching", "expected_point_1_touching", "expected_point_2_touching"), + [ + pytest.param( + np.array([[0, 0], [1, 1], [2, 2]]), + np.array([[4, 4], [5, 5], [6, 6]]), + False, + None, + None, + id="no touching points", + ), + pytest.param( + np.array([[0, 0], [1, 1], [2, 2]]), + np.array([[2, 3], [3, 4], [5, 5]]), + True, + np.array([2, 2]), + np.array([2, 3]), + id="touching points, non_diag", + ), + pytest.param( + np.array([[0, 0], [1, 1], [2, 2]]), + np.array([[3, 3], [4, 4], [5, 5]]), + True, + np.array([2, 2]), + np.array([3, 3]), + id="touching points, diag", + ), + ], +) +def test_do_points_in_arrays_touch( + array_1: npt.NDArray[np.number], + array_2: npt.NDArray[np.number], + expected_result_touching: bool, + expected_point_1_touching: npt.NDArray[np.number], + expected_point_2_touching: npt.NDArray[np.number], +) -> None: + """Test the do_points_in_arrays_touch function.""" + result_touching, point_touching_1, point_touching_2 = do_points_in_arrays_touch(array_1, array_2) + + assert result_touching == expected_result_touching + np.testing.assert_array_equal(point_touching_1, expected_point_1_touching) + np.testing.assert_array_equal(point_touching_2, expected_point_2_touching) + + +@pytest.mark.parametrize( + ( + "network_array_representation", + "whole_skeleton_graph", + "match_indexes", + "shortest_distances_between_nodes", + "shortest_distances_branch_indexes", + "emanating_branch_starts_by_node", + "extend_distance", + "expected_network_array_representation", + ), + [ + pytest.param( + "network_array_representation_figure_8", + "whole_skeleton_graph_figure_8", + np.array([[1, 0]]), + np.array([[0.0, 6.0], [6.0, 0.0]]), + np.array([[[0, 0], [1, 1]], [[1, 1], [0, 0]]]), + { + 0: [np.array([6, 1]), np.array([7, 3]), np.array([8, 1])], + 1: [np.array([6, 11]), np.array([7, 9]), np.array([8, 11])], + }, + 8.0, + "expected_network_array_representation_figure_8", + id="figure 8", + ) + ], +) +def test_connect_best_matches( + network_array_representation: npt.NDArray[np.int32], + whole_skeleton_graph: networkx.classes.graph.Graph, + match_indexes: npt.NDArray[np.int32], + shortest_distances_between_nodes: npt.NDArray[np.number], + shortest_distances_branch_indexes: npt.NDArray[np.int32], + emanating_branch_starts_by_node: npt.NDArray[np.int32], + extend_distance: float, + expected_network_array_representation: npt.NDArray[np.int32], + request, +) -> None: + """Test the connect_best_matches function.""" + # Load fixtures + network_array_representation = request.getfixturevalue(network_array_representation) + whole_skeleton_graph = request.getfixturevalue(whole_skeleton_graph) + expected_network_array_representation = request.getfixturevalue(expected_network_array_representation) + + result = connect_best_matches( + network_array_representation, + whole_skeleton_graph, + match_indexes, + shortest_distances_between_nodes, + shortest_distances_branch_indexes, + emanating_branch_starts_by_node, + extend_distance, + ) + + np.testing.assert_array_equal(result, expected_network_array_representation) + + +@pytest.mark.parametrize( + ( + "nodes_with_branches_starting_coords", + "whole_skeleton_graph", + "expected_shortest_node_distances", + "expected_shortest_distances_branch_indexes", + "expected_shortest_distances_branch_coordinates", + ), + [ + pytest.param( + { + 0: [np.array([6, 1]), np.array([7, 3]), np.array([8, 1])], + 1: [np.array([6, 11]), np.array([7, 9]), np.array([8, 11])], + }, + "whole_skeleton_graph_figure_8", + np.array([[0, 6.0], [6.0, 0.0]]), + np.array([[[0, 0], [1, 1]], [[1, 1], [0, 0]]]), + np.array([[[[0, 0], [0, 0]], [[7, 3], [7, 9]]], [[[7, 9], [7, 3]], [[0, 0], [0, 0]]]]), + id="figure 8", + ) + ], +) +def test_calculate_shortest_branch_distances( + nodes_with_branches_starting_coords: dict[int, npt.NDArray[np.number]], + whole_skeleton_graph: networkx.classes.graph.Graph, + expected_shortest_node_distances: dict[int, float], + expected_shortest_distances_branch_indexes: dict[int, npt.NDArray[np.int32]], + expected_shortest_distances_branch_coordinates: dict[int, npt.NDArray[np.number]], + request, +) -> None: + """Test the calculate_shortest_branch_distances function.""" + # Load fixtures + whole_skeleton_graph = request.getfixturevalue(whole_skeleton_graph) + + shortest_node_distances, shortest_distances_branch_indexes, shortest_distances_branch_coordinates = ( + calculate_shortest_branch_distances(nodes_with_branches_starting_coords, whole_skeleton_graph) + ) + + np.testing.assert_array_equal(shortest_node_distances, expected_shortest_node_distances) + np.testing.assert_array_equal(shortest_distances_branch_indexes, expected_shortest_distances_branch_indexes) + print(shortest_distances_branch_coordinates) + np.testing.assert_array_equal(shortest_distances_branch_coordinates, expected_shortest_distances_branch_coordinates) diff --git a/tests/resources/catenane_connected_nodes.npy b/tests/resources/catenane_connected_nodes.npy new file mode 100644 index 0000000000..6226bcb8e5 Binary files /dev/null and b/tests/resources/catenane_connected_nodes.npy differ diff --git a/tests/resources/catenane_image.npy b/tests/resources/catenane_image.npy new file mode 100644 index 0000000000..72b32b9dbc Binary files /dev/null and b/tests/resources/catenane_image.npy differ diff --git a/tests/resources/catenane_node_0_avg_image.npy b/tests/resources/catenane_node_0_avg_image.npy new file mode 100644 index 0000000000..a0cfc1e80d Binary files /dev/null and b/tests/resources/catenane_node_0_avg_image.npy differ diff --git a/tests/resources/catenane_node_0_branch_image.npy b/tests/resources/catenane_node_0_branch_image.npy new file mode 100644 index 0000000000..4c5951098f Binary files /dev/null and b/tests/resources/catenane_node_0_branch_image.npy differ diff --git a/tests/resources/catenane_node_0_masked_image.pkl b/tests/resources/catenane_node_0_masked_image.pkl new file mode 100644 index 0000000000..e266fb7ad7 Binary files /dev/null and b/tests/resources/catenane_node_0_masked_image.pkl differ diff --git a/tests/resources/catenane_node_0_matched_branches_analyse_node_branches.pkl b/tests/resources/catenane_node_0_matched_branches_analyse_node_branches.pkl new file mode 100644 index 0000000000..abb3790d14 Binary files /dev/null and b/tests/resources/catenane_node_0_matched_branches_analyse_node_branches.pkl differ diff --git a/tests/resources/catenane_node_0_matched_branches_join_matching_branches_through_node.pkl b/tests/resources/catenane_node_0_matched_branches_join_matching_branches_through_node.pkl new file mode 100644 index 0000000000..150bb395c2 Binary files /dev/null and b/tests/resources/catenane_node_0_matched_branches_join_matching_branches_through_node.pkl differ diff --git a/tests/resources/catenane_node_0_ordered_branches.pkl b/tests/resources/catenane_node_0_ordered_branches.pkl new file mode 100644 index 0000000000..dfa7eba5e7 Binary files /dev/null and b/tests/resources/catenane_node_0_ordered_branches.pkl differ diff --git a/tests/resources/catenane_node_0_reduced_node_area.npy b/tests/resources/catenane_node_0_reduced_node_area.npy new file mode 100644 index 0000000000..f1b0db3c80 Binary files /dev/null and b/tests/resources/catenane_node_0_reduced_node_area.npy differ diff --git a/tests/resources/catenane_node_0_reduced_skeleton_graph.pkl b/tests/resources/catenane_node_0_reduced_skeleton_graph.pkl new file mode 100644 index 0000000000..02a60657b2 Binary files /dev/null and b/tests/resources/catenane_node_0_reduced_skeleton_graph.pkl differ diff --git a/tests/resources/catenane_node_1_masked_image.pkl b/tests/resources/catenane_node_1_masked_image.pkl new file mode 100644 index 0000000000..e4eb3dd957 Binary files /dev/null and b/tests/resources/catenane_node_1_masked_image.pkl differ diff --git a/tests/resources/catenane_node_1_matched_branches_analyse_node_branches.pkl b/tests/resources/catenane_node_1_matched_branches_analyse_node_branches.pkl new file mode 100644 index 0000000000..d867e3917d Binary files /dev/null and b/tests/resources/catenane_node_1_matched_branches_analyse_node_branches.pkl differ diff --git a/tests/resources/catenane_node_1_matched_branches_join_matching_branches_through_node.pkl b/tests/resources/catenane_node_1_matched_branches_join_matching_branches_through_node.pkl new file mode 100644 index 0000000000..13424683a0 Binary files /dev/null and b/tests/resources/catenane_node_1_matched_branches_join_matching_branches_through_node.pkl differ diff --git a/tests/resources/catenane_node_1_ordered_branches.pkl b/tests/resources/catenane_node_1_ordered_branches.pkl new file mode 100644 index 0000000000..e1edaf0ab1 Binary files /dev/null and b/tests/resources/catenane_node_1_ordered_branches.pkl differ diff --git a/tests/resources/catenane_node_1_reduced_skeleton_graph.pkl b/tests/resources/catenane_node_1_reduced_skeleton_graph.pkl new file mode 100644 index 0000000000..876a8e87f2 Binary files /dev/null and b/tests/resources/catenane_node_1_reduced_skeleton_graph.pkl differ diff --git a/tests/resources/catenane_node_2_masked_image.pkl b/tests/resources/catenane_node_2_masked_image.pkl new file mode 100644 index 0000000000..888d2d9db4 Binary files /dev/null and b/tests/resources/catenane_node_2_masked_image.pkl differ diff --git a/tests/resources/catenane_node_2_matched_branches_analyse_node_branches.pkl b/tests/resources/catenane_node_2_matched_branches_analyse_node_branches.pkl new file mode 100644 index 0000000000..23fd7716ed Binary files /dev/null and b/tests/resources/catenane_node_2_matched_branches_analyse_node_branches.pkl differ diff --git a/tests/resources/catenane_node_2_matched_branches_join_matching_branches_through_node.pkl b/tests/resources/catenane_node_2_matched_branches_join_matching_branches_through_node.pkl new file mode 100644 index 0000000000..c49bc350fc Binary files /dev/null and b/tests/resources/catenane_node_2_matched_branches_join_matching_branches_through_node.pkl differ diff --git a/tests/resources/catenane_node_2_ordered_branches.pkl b/tests/resources/catenane_node_2_ordered_branches.pkl new file mode 100644 index 0000000000..8c389f011b Binary files /dev/null and b/tests/resources/catenane_node_2_ordered_branches.pkl differ diff --git a/tests/resources/catenane_node_2_reduced_skeleton_graph.pkl b/tests/resources/catenane_node_2_reduced_skeleton_graph.pkl new file mode 100644 index 0000000000..a95eb04552 Binary files /dev/null and b/tests/resources/catenane_node_2_reduced_skeleton_graph.pkl differ diff --git a/tests/resources/catenane_node_3_masked_image.pkl b/tests/resources/catenane_node_3_masked_image.pkl new file mode 100644 index 0000000000..5e2852a4a9 Binary files /dev/null and b/tests/resources/catenane_node_3_masked_image.pkl differ diff --git a/tests/resources/catenane_node_3_matched_branches_analyse_node_branches.pkl b/tests/resources/catenane_node_3_matched_branches_analyse_node_branches.pkl new file mode 100644 index 0000000000..d63c387f4b Binary files /dev/null and b/tests/resources/catenane_node_3_matched_branches_analyse_node_branches.pkl differ diff --git a/tests/resources/catenane_node_3_matched_branches_join_matching_branches_through_node.pkl b/tests/resources/catenane_node_3_matched_branches_join_matching_branches_through_node.pkl new file mode 100644 index 0000000000..d673f847b8 Binary files /dev/null and b/tests/resources/catenane_node_3_matched_branches_join_matching_branches_through_node.pkl differ diff --git a/tests/resources/catenane_node_3_ordered_branches.pkl b/tests/resources/catenane_node_3_ordered_branches.pkl new file mode 100644 index 0000000000..f5f642fcef Binary files /dev/null and b/tests/resources/catenane_node_3_ordered_branches.pkl differ diff --git a/tests/resources/catenane_node_3_reduced_skeleton_graph.pkl b/tests/resources/catenane_node_3_reduced_skeleton_graph.pkl new file mode 100644 index 0000000000..2c8b41b740 Binary files /dev/null and b/tests/resources/catenane_node_3_reduced_skeleton_graph.pkl differ diff --git a/tests/resources/catenane_node_4_masked_image.pkl b/tests/resources/catenane_node_4_masked_image.pkl new file mode 100644 index 0000000000..27661148b6 Binary files /dev/null and b/tests/resources/catenane_node_4_masked_image.pkl differ diff --git a/tests/resources/catenane_node_4_matched_branches_analyse_node_branches.pkl b/tests/resources/catenane_node_4_matched_branches_analyse_node_branches.pkl new file mode 100644 index 0000000000..d926dc0111 Binary files /dev/null and b/tests/resources/catenane_node_4_matched_branches_analyse_node_branches.pkl differ diff --git a/tests/resources/catenane_node_4_matched_branches_join_matching_branches_through_node.pkl b/tests/resources/catenane_node_4_matched_branches_join_matching_branches_through_node.pkl new file mode 100644 index 0000000000..49e7e13c04 Binary files /dev/null and b/tests/resources/catenane_node_4_matched_branches_join_matching_branches_through_node.pkl differ diff --git a/tests/resources/catenane_node_4_ordered_branches.pkl b/tests/resources/catenane_node_4_ordered_branches.pkl new file mode 100644 index 0000000000..7333cdc76d Binary files /dev/null and b/tests/resources/catenane_node_4_ordered_branches.pkl differ diff --git a/tests/resources/catenane_node_4_reduced_skeleton_graph.pkl b/tests/resources/catenane_node_4_reduced_skeleton_graph.pkl new file mode 100644 index 0000000000..c8ca6eab18 Binary files /dev/null and b/tests/resources/catenane_node_4_reduced_skeleton_graph.pkl differ diff --git a/tests/resources/catenane_node_centre_mask.npy b/tests/resources/catenane_node_centre_mask.npy new file mode 100644 index 0000000000..981b3b7a07 Binary files /dev/null and b/tests/resources/catenane_node_centre_mask.npy differ diff --git a/tests/resources/catenane_skeleton.npy b/tests/resources/catenane_skeleton.npy new file mode 100644 index 0000000000..45f0a5b41d Binary files /dev/null and b/tests/resources/catenane_skeleton.npy differ diff --git a/tests/resources/catenane_smoothed_mask.npy b/tests/resources/catenane_smoothed_mask.npy new file mode 100644 index 0000000000..328acc821a Binary files /dev/null and b/tests/resources/catenane_smoothed_mask.npy differ diff --git a/tests/resources/example_catenanes.npy b/tests/resources/example_catenanes.npy new file mode 100644 index 0000000000..fe48e84ccc Binary files /dev/null and b/tests/resources/example_catenanes.npy differ diff --git a/tests/resources/example_catenanes_labelled_grain_mask_thresholded.npy b/tests/resources/example_catenanes_labelled_grain_mask_thresholded.npy new file mode 100644 index 0000000000..7570f4ae8d Binary files /dev/null and b/tests/resources/example_catenanes_labelled_grain_mask_thresholded.npy differ diff --git a/tests/resources/example_rep_int.npy b/tests/resources/example_rep_int.npy new file mode 100644 index 0000000000..7df4fd5f76 Binary files /dev/null and b/tests/resources/example_rep_int.npy differ diff --git a/tests/resources/example_rep_int_labelled_grain_mask_thresholded.npy b/tests/resources/example_rep_int_labelled_grain_mask_thresholded.npy new file mode 100644 index 0000000000..182a91e46a Binary files /dev/null and b/tests/resources/example_rep_int_labelled_grain_mask_thresholded.npy differ diff --git a/tests/resources/minicircle_default_all_statistics.csv b/tests/resources/minicircle_default_all_statistics.csv index fe8065d0cc..971d2b01c0 100644 --- a/tests/resources/minicircle_default_all_statistics.csv +++ b/tests/resources/minicircle_default_all_statistics.csv @@ -1,4 +1,4 @@ -molecule_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_median,height_min,height_max,height_median,height_mean,volume,area,area_cartesian_bbox,smallest_bounding_width,smallest_bounding_length,smallest_bounding_area,aspect_ratio,threshold,max_feret,min_feret,contour_length,circular,end_to_end_distance,image,basename +grain_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_median,height_min,height_max,height_median,height_mean,volume,area,area_cartesian_bbox,smallest_bounding_width,smallest_bounding_length,smallest_bounding_area,aspect_ratio,threshold,max_feret,min_feret,contour_length,circular,end_to_end_distance,image,basename 0,18.72534751358894,96.56689212959863,5.918489103781348,20.866157242818016,15.491706069705224,16.218426875142722,0.7452716544633669,3.8599067676448637,2.0391347095418917,1.9481343856971967,1220.4045823987904,626.4478422837515,1137.465287450162,24.292634523076288,40.8496256505361,992.3450263329557,1.6815642458100555,upper,41.33420255276943,24.126052806347744,113.16428314861125,False,12.929293977622509,minicircle.spm,tests/resources 1,400.2920881110063,93.12806953798491,0.7643280440003184,21.96855256891761,13.999840272706289,14.792429582135073,0.7454383438304222,3.179159623594896,2.131409254934984,1.984183560332176,1263.808882457969,636.9415147489642,1186.0290274635886,27.45053797699605,39.09250720831134,1073.1103537377423,1.4241071428571428,upper,39.14797897350707,27.45053797699605,115.66050824497863,False,21.47066537827239,minicircle.spm,tests/resources 2,218.26129518646468,172.6865624592939,8.312817001643236,28.4840908864891,17.761201297555598,16.374812395420516,0.7458769816768294,3.0837823469355783,1.965905358311789,1.8574106218419255,1152.692145851135,620.5909088147954,1618.9540263772524,24.48642216522931,54.1892216956022,1326.9001592437185,2.2130314232902033,upper,54.20767776942792,24.48642216522931,90.27296454772564,True,0.0,minicircle.spm,tests/resources diff --git a/tests/resources/nodestats_analyse_nodes_catenane_all_connected_nodes.npy b/tests/resources/nodestats_analyse_nodes_catenane_all_connected_nodes.npy new file mode 100644 index 0000000000..e57b034e14 Binary files /dev/null and b/tests/resources/nodestats_analyse_nodes_catenane_all_connected_nodes.npy differ diff --git a/tests/resources/nodestats_analyse_nodes_catenane_image_dict.pkl b/tests/resources/nodestats_analyse_nodes_catenane_image_dict.pkl new file mode 100644 index 0000000000..5f60e204ea Binary files /dev/null and b/tests/resources/nodestats_analyse_nodes_catenane_image_dict.pkl differ diff --git a/tests/resources/nodestats_analyse_nodes_catenane_node_dict.pkl b/tests/resources/nodestats_analyse_nodes_catenane_node_dict.pkl new file mode 100644 index 0000000000..33017d01f8 Binary files /dev/null and b/tests/resources/nodestats_analyse_nodes_catenane_node_dict.pkl differ diff --git a/tests/resources/process_scan_expected_below_height_profiles.pickle b/tests/resources/process_scan_expected_below_height_profiles.pickle index 966bd97f4f..a1fc85f388 100644 Binary files a/tests/resources/process_scan_expected_below_height_profiles.pickle and b/tests/resources/process_scan_expected_below_height_profiles.pickle differ diff --git a/tests/resources/process_scan_topostats_file_regtest.topostats b/tests/resources/process_scan_topostats_file_regtest.topostats index 40e14b6a1f..4012b7a1f3 100644 Binary files a/tests/resources/process_scan_topostats_file_regtest.topostats and b/tests/resources/process_scan_topostats_file_regtest.topostats differ diff --git a/tests/resources/toposum_all_statistics_multiple_directories.csv b/tests/resources/toposum_all_statistics_multiple_directories.csv index d62a21afb2..b9b6a7c6e4 100644 --- a/tests/resources/toposum_all_statistics_multiple_directories.csv +++ b/tests/resources/toposum_all_statistics_multiple_directories.csv @@ -1,4 +1,4 @@ -molecule_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_median,height_min,height_max,height_median,height_mean,volume,area,area_cartesian_bbox,smallest_bounding_width,smallest_bounding_length,smallest_bounding_area,aspect_ratio,threshold,max_feret,min_feret,contour_lengths,circular,end_to_end_distance,image,basename +grain_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_median,height_min,height_max,height_median,height_mean,volume,area,area_cartesian_bbox,smallest_bounding_width,smallest_bounding_length,smallest_bounding_area,aspect_ratio,threshold,max_feret,min_feret,contour_lengths,circular,end_to_end_distance,image,basename 0,18.72534751358894,96.56689212959863,5.918489103781348,20.866157242818016,15.491706069705224,16.218426875142722,0.7452716544633669,3.8599067676448637,2.0391347095418917,1.9481343856971967,1220.4045823987904,626.4478422837515,1137.465287450162,24.292634523076288,40.8496256505361,992.3450263329557,1.6815642458100555,upper,41.33420255276943,24.126052806347744,113.16428314861125,False,12.929293977622509,minicircle.spm,parent_dir/child_dir_1 1,400.2920881110063,93.12806953798491,0.7643280440003184,21.96855256891761,13.999840272706289,14.792429582135073,0.7454383438304222,3.179159623594896,2.131409254934984,1.984183560332176,1263.808882457969,636.9415147489642,1186.0290274635886,27.45053797699605,39.09250720831134,1073.1103537377423,1.4241071428571428,upper,39.14797897350707,27.45053797699605,115.66050824497863,False,21.47066537827239,minicircle.spm,parent_dir/child_dir_1 2,218.26129518646468,172.6865624592939,8.312817001643236,28.4840908864891,17.761201297555598,16.374812395420516,0.7458769816768294,3.0837823469355783,1.965905358311789,1.8574106218419255,1152.692145851135,620.5909088147954,1618.9540263772524,24.48642216522931,54.1892216956022,1326.9001592437185,2.2130314232902033,upper,54.20767776942792,24.48642216522931,90.27296454772564,True,0.0,minicircle.spm,parent_dir/child_dir_1 @@ -19,4 +19,4 @@ molecule_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_media 4,460.81528892859257,458.61529179058647,8.335567737084954,28.018485456887074,16.951505396530333,16.480914927347325,0.746481994571252,3.1528255105319674,2.161743197106236,1.9909999302468957,1282.72695409673,644.2626815851593,1616.513637431854,21.77853967542204,52.651771534465524,1146.678695344614,2.417598806860552,upper,53.354603407833075,21.77853967542204,69.81455411588722,True,0.0,minicircle.spm,parent_dir/child_dir_3/child_dir_4 5,419.7119810116665,456.33581467682035,10.647572346135227,24.867679326046783,18.230207187178145,18.820514568860247,0.7460540784349388,3.0467196263335836,2.133687174437284,1.9766412306722034,1370.9164026533533,693.5585382822055,1625.2990376352882,35.0742080078125,42.9782548828125,1507.4282515725488,1.2253521126760563,upper,45.96616961337843,33.74603230760295,101.81981258840132,True,0.0,minicircle.spm,parent_dir/child_dir_3/child_dir_4 6,90.92468996474322,469.0147359066234,4.419162608876172,25.241879915828346,16.11511146882466,16.417253423698206,0.7465371311191019,2.9364197117159456,2.109230218842712,1.9584847703176893,1300.0143666342835,663.7857931483459,1600.1630314976853,21.15084957523595,49.01725079602158,1036.7564981782668,2.317507418397626,upper,49.08311284848621,21.150849575235945,104.65035017927731,False,14.128829241433397,minicircle.spm,parent_dir/child_dir_3/child_dir_4 -7,19.82883778582696,501.3507324051849,2.0456794027695278,25.469796537814755,16.062271878609476,16.64728658159786,0.7453685468075075,2.9898919464281057,2.1262758767778243,1.9713545047879897,1210.4153300219657,614.0018586622199,1366.617809423065,20.435629724999576,50.04784121900791,1022.7591516872169,2.4490481522956324,upper,50.15772087935355,20.435629724999572,102.53757918370543,False,20.559070147820183,minicircle.spm,parent_dir/child_dir_3/child_dir_4 \ No newline at end of file +7,19.82883778582696,501.3507324051849,2.0456794027695278,25.469796537814755,16.062271878609476,16.64728658159786,0.7453685468075075,2.9898919464281057,2.1262758767778243,1.9713545047879897,1210.4153300219657,614.0018586622199,1366.617809423065,20.435629724999576,50.04784121900791,1022.7591516872169,2.4490481522956324,upper,50.15772087935355,20.435629724999572,102.53757918370543,False,20.559070147820183,minicircle.spm,parent_dir/child_dir_3/child_dir_4 diff --git a/tests/resources/toposum_all_statistics_single_directory.csv b/tests/resources/toposum_all_statistics_single_directory.csv index 065c6d4fb5..04c9cb5a94 100644 --- a/tests/resources/toposum_all_statistics_single_directory.csv +++ b/tests/resources/toposum_all_statistics_single_directory.csv @@ -1,4 +1,4 @@ -molecule_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_median,height_min,height_max,height_median,height_mean,volume,area,area_cartesian_bbox,smallest_bounding_width,smallest_bounding_length,smallest_bounding_area,aspect_ratio,threshold,max_feret,min_feret,contour_lengths,circular,end_to_end_distance,image,basename +grain_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_median,height_min,height_max,height_median,height_mean,volume,area,area_cartesian_bbox,smallest_bounding_width,smallest_bounding_length,smallest_bounding_area,aspect_ratio,threshold,max_feret,min_feret,contour_lengths,circular,end_to_end_distance,image,basename 0,18.72534751358894,96.56689212959863,5.918489103781348,20.866157242818016,15.491706069705224,16.218426875142722,0.7452716544633669,3.8599067676448637,2.0391347095418917,1.9481343856971967,1220.4045823987904,626.4478422837515,1137.465287450162,24.292634523076288,40.8496256505361,992.3450263329557,1.6815642458100555,upper,41.33420255276943,24.126052806347744,113.16428314861125,False,12.929293977622509,minicircle.spm,parent_dir/child_dir_1 1,400.2920881110063,93.12806953798491,0.7643280440003184,21.96855256891761,13.999840272706289,14.792429582135073,0.7454383438304222,3.179159623594896,2.131409254934984,1.984183560332176,1263.808882457969,636.9415147489642,1186.0290274635886,27.45053797699605,39.09250720831134,1073.1103537377423,1.4241071428571428,upper,39.14797897350707,27.45053797699605,115.66050824497863,False,21.47066537827239,minicircle.spm,parent_dir/child_dir_1 2,218.26129518646468,172.6865624592939,8.312817001643236,28.4840908864891,17.761201297555598,16.374812395420516,0.7458769816768294,3.0837823469355783,1.965905358311789,1.8574106218419255,1152.692145851135,620.5909088147954,1618.9540263772524,24.48642216522931,54.1892216956022,1326.9001592437185,2.2130314232902033,upper,54.20767776942792,24.48642216522931,90.27296454772564,True,0.0,minicircle.spm,parent_dir/child_dir_1 @@ -6,4 +6,4 @@ molecule_number,centre_x,centre_y,radius_min,radius_max,radius_mean,radius_media 4,301.3416206102413,211.61756844725574,10.256489210615582,25.29455404248287,18.031900182444964,18.452561215234933,0.7452246436395157,2.997930729110457,1.8882842984021164,1.828628948414764,1324.0432939099048,724.0634000996846,1739.50924027993,33.10304661739775,44.076669063566506,1459.0720307508554,1.3314988669472319,upper,44.12416912122865,33.103046617397744,95.99042452161821,True,0.0,minicircle.spm,parent_dir/child_dir_1 5,432.9621809446219,215.2885822162342,2.880805118744918,23.596602791673394,15.320545978528381,15.933506166242601,0.7464484840406989,3.239939630220636,2.1660416821287702,2.0180826496579956,1247.971329684413,618.394558763937,1273.8830294979284,20.832593986797836,45.91657123529629,956.5612858108079,2.2040736388562476,upper,46.20710214165134,20.832593986797836,92.98492331644334,False,22.535415228865716,minicircle.spm,parent_dir/child_dir_1 6,194.45265142078358,231.24213759821308,9.00198665937126,23.654796974035794,17.933819108767963,18.10868586793719,0.7477364070983863,2.8897474559397125,2.0249157366744233,1.9043852670799726,1304.0718754493323,684.7731380787716,1559.408536109533,29.68436723824748,45.552010816510666,1352.182617517924,1.5345454545454542,upper,46.14368163530131,29.68436723824747,97.49636171317444,True,0.0,minicircle.spm,parent_dir/child_dir_1 -7,390.2341825538701,273.9228450843562,2.083432106712806,25.56974100134752,16.749517580670158,18.505450010697373,0.7453047159088426,3.0346042893430605,2.2166123872313888,2.035879305318322,1448.2703377768753,711.3733775836133,1792.9537581841535,22.849351267189622,49.22624872637181,1124.787848714915,2.1543827722171667,upper,49.462004789149276,22.849351267189622,113.36677817432933,False,22.627275592624265,minicircle.spm,parent_dir/child_dir_1 \ No newline at end of file +7,390.2341825538701,273.9228450843562,2.083432106712806,25.56974100134752,16.749517580670158,18.505450010697373,0.7453047159088426,3.0346042893430605,2.2166123872313888,2.035879305318322,1448.2703377768753,711.3733775836133,1792.9537581841535,22.849351267189622,49.22624872637181,1124.787848714915,2.1543827722171667,upper,49.462004789149276,22.849351267189622,113.36677817432933,False,22.627275592624265,minicircle.spm,parent_dir/child_dir_1 diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_all_images.pkl b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_all_images.pkl new file mode 100644 index 0000000000..986125e09e Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_all_images.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_crop_data.pkl b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_crop_data.pkl new file mode 100644 index 0000000000..54aa6f9c60 Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_crop_data.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_grainstats.csv b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_grainstats.csv new file mode 100644 index 0000000000..dbeceb405d --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,grain_endpoints,grain_junctions,total_branch_lengths +0,test_image,0,0,14,5.755249825836855e-07 +1,test_image,1,0,12,5.747857998943139e-07 diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_stats.csv b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_stats.csv new file mode 100644 index 0000000000..4e0198f6f8 --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_stats.csv @@ -0,0 +1,25 @@ +,image,grain_number,branch_distance,branch_type,connected_segments,mean_pixel_value,stdev_pixel_value,min_value,median_value,middle_value +0,test_image,0,29.857677897827887,2,"[1, 2, 6]",2.6683043052890536,0.3576668425125195,1.993486285613394,2.6464969869174926,2.6743964509318303 +1,test_image,0,28.644860587199464,2,"[0, 2, 6]",2.728953130422273,0.21545195325432523,2.5506513793095915,2.6838064728094184,2.6579579145357006 +2,test_image,0,1.8682724368761408,2,"[0, 1, 3, 4]",4.067541277649083,0.16905835685823303,3.7890382980796193,4.138536286707056,4.13958834095104 +3,test_image,0,146.60562184380709,2,"[2, 4, 9]",2.6693913722865377,0.2947637179915079,1.2200007753773205,2.7080175948240655,2.9115862857598005 +4,test_image,0,26.170179495009112,2,"[2, 3, 9]",2.834959559756504,0.3436409841063358,2.535374801513602,2.7607877845412387,2.7651337225310604 +5,test_image,0,224.58506419260596,2,"[6, 7, 10, 11]",2.700995491724554,0.2653435468736685,1.787397162651758,2.7330374920770115,2.2573671026451536 +6,test_image,0,6.9504086553142095,2,"[0, 1, 5, 7]",3.626373814193468,0.6406710041895474,2.588846104614013,3.84195778773634,3.84195778773634 +7,test_image,0,31.203269242513674,2,"[5, 6, 8, 11]",2.843560036090906,0.5572924823368127,2.1307434183593315,2.705164120987648,2.75295411300335 +8,test_image,0,31.286996805637532,2,"[7, 9, 10, 11]",2.7443198761048193,0.3822328907364185,2.4149214282007074,2.653495727717319,2.4706619619066386 +9,test_image,0,5.570136218438069,2,"[3, 4, 8, 10]",2.7308714736136963,0.049010037214155706,2.6525934702850984,2.7378802570467666,2.749457125278597 +10,test_image,0,38.18835899001824,2,"[5, 8, 9, 11]",2.7864423529784395,0.15142352544214144,2.445066671627662,2.7710339081422832,2.6967078592376783 +11,test_image,0,4.59413621843807,2,"[5, 7, 8, 10]",4.126603901043527,0.37133617585304246,3.351114087306773,4.238726491399509,4.238726491399509 +0,test_image,1,37.70035899001824,2,"[1, 2, 3, 6]",2.7732793700561382,0.13491385165238218,2.445066671627662,2.7690167332032654,2.7141994425994813 +1,test_image,1,223.89492797416787,2,"[0, 2, 4, 9]",2.69655183205219,0.26110756906553556,1.8203021919659572,2.7314109425842963,2.2743843560760055 +2,test_image,1,5.08213621843807,2,"[0, 1, 3, 4]",4.029037066510506,0.4638075153095662,3.0770475331177405,4.18502108746829,4.18502108746829 +3,test_image,1,31.001133024075603,2,"[0, 2, 4, 6]",2.746190739871742,0.4020468170580353,2.443926567140926,2.6515637708741715,2.443926567140926 +4,test_image,1,31.405405460951744,2,"[1, 2, 3, 9]",2.8422583900088516,0.5622599874985189,2.1307434183593315,2.694237316611716,2.75295411300335 +5,test_image,1,147.49789428068328,2,"[6, 7, 11]",2.6728103603722104,0.2940846131940484,1.2200007753773205,2.712208521357603,2.9467266502403127 +6,test_image,1,5.77227243687614,2,"[0, 3, 5, 7]",2.7225654392347494,0.051361792194696554,2.6525934702850984,2.7127873939436142,2.7180535973290754 +7,test_image,1,25.682179495009112,2,"[5, 6, 11]",2.85129816853438,0.36309882085275597,2.535374801513602,2.7617733887534093,2.770465264733052 +8,test_image,1,27.870996805637535,2,"[9, 10, 11]",2.7169147220384864,0.1798385572282301,2.5506513793095915,2.6863743944825202,2.6579579145357006 +9,test_image,1,6.9504086553142095,2,"[1, 4, 8, 10]",3.679401125994582,0.5950886464826225,2.6485713386465375,3.84195778773634,3.84195778773634 +10,test_image,1,29.369677897827888,2,"[8, 9, 11]",2.640305654101202,0.3106025836163093,1.993486285613394,2.648283332587627,2.682776277975024 +11,test_image,1,2.558408655314211,2,"[5, 7, 8, 10]",3.974846970762266,0.23923528279875325,3.6040697432150033,4.075122442799481,4.075122442799481 diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_all_images.pkl b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_all_images.pkl new file mode 100644 index 0000000000..22bd808c7e Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_all_images.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_crop_data.pkl b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_crop_data.pkl new file mode 100644 index 0000000000..0cfa26e107 Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_crop_data.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_grainstats.csv b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_grainstats.csv new file mode 100644 index 0000000000..f8c476af20 --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_grainstats.csv @@ -0,0 +1,2 @@ +,image,grain_number,grain_endpoints,grain_junctions,total_branch_lengths +0,test_image,0,0,13,9.685225788725929e-07 diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_stats.csv b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_stats.csv new file mode 100644 index 0000000000..e503c214bb --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_stats.csv @@ -0,0 +1,13 @@ +,image,grain_number,branch_distance,branch_type,connected_segments,mean_pixel_value,stdev_pixel_value,min_value,median_value,middle_value +0,test_image,0,172.69207377569276,2,"[1, 2, 3, 8]",2.147391874570661,0.18580255309380492,1.4117732114082249,2.147881683634531,2.193944711921703 +1,test_image,0,338.03541679389866,2,"[0, 2, 7, 11]",2.1261598928571943,0.1841696602430932,1.127883594233125,2.1329128300921694,2.236342991231205 +2,test_image,0,0.6901362184380704,2,"[0, 1, 3, 4]",3.5129004230159246,0.2441665496774352,3.2687338733384856,3.5129004230159246,3.5129004230159246 +3,test_image,0,75.41508335877963,2,"[0, 2, 4, 8]",2.1781050277619003,0.22952699769811186,1.4117732114082249,2.178401374885392,2.2734312397952836 +4,test_image,0,31.893405460951744,2,"[2, 3, 5, 6]",2.1378544765518392,0.338975206263122,1.7555674668849865,2.094437231758162,2.2880045741720902 +5,test_image,0,51.51744873752279,2,"[4, 6, 7]",2.233921716584725,0.23610837345626215,1.7265179765777217,2.206985179983061,2.0220519901178635 +6,test_image,0,1.1781362184380704,2,"[4, 5, 7]",3.4335547526985297,0.29158462263860985,3.0268131230204918,3.578160584306969,3.6956905507681284 +7,test_image,0,153.8275289019402,2,"[1, 5, 6, 11]",2.1708118960130536,0.23399433389250326,1.32021443851052,2.1773699140297205,2.1908233129216366 +8,test_image,0,58.753721174398926,2,"[0, 3, 10, 11]",2.0020290183250613,0.3370066648426065,0.8859157914742134,2.08256912431822,2.247454035474802 +9,test_image,0,38.962222771580166,3,[10],2.2735993718579453,0.17804635219989257,2.0015543670262534,2.2595377460435833,2.2032782158757636 +10,test_image,0,0.488,2,"[8, 9, 11]",2.875087133879444,0.009881348191510873,2.865205785688025,2.875087133879444,2.875087133879444 +11,test_image,0,45.069405460951735,2,"[1, 7, 8, 10]",2.1854926308595277,0.22637773949873832,1.289281969607514,2.227074000531621,2.227074000531621 diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_all_images.pkl b/tests/resources/tracing/nodestats/catenanes_nodestats_all_images.pkl new file mode 100644 index 0000000000..a9b06ffe05 Binary files /dev/null and b/tests/resources/tracing/nodestats/catenanes_nodestats_all_images.pkl differ diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_branch_images.pkl b/tests/resources/tracing/nodestats/catenanes_nodestats_branch_images.pkl new file mode 100644 index 0000000000..676f09ceb3 Binary files /dev/null and b/tests/resources/tracing/nodestats/catenanes_nodestats_branch_images.pkl differ diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_data.pkl b/tests/resources/tracing/nodestats/catenanes_nodestats_data.pkl new file mode 100644 index 0000000000..2f78a8b297 Binary files /dev/null and b/tests/resources/tracing/nodestats/catenanes_nodestats_data.pkl differ diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_grainstats.csv b/tests/resources/tracing/nodestats/catenanes_nodestats_grainstats.csv new file mode 100644 index 0000000000..3e0c534882 --- /dev/null +++ b/tests/resources/tracing/nodestats/catenanes_nodestats_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,num_crossings,avg_crossing_confidence,min_crossing_confidence +grain_0,test_image,0,4,0.4013589828832889,0.2129989376767838 +grain_1,test_image,1,4,0.3441057054647598,0.17063184531586506 diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_no_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_no_pair_odd_branches.pkl new file mode 100644 index 0000000000..6d5cbc25f3 Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_no_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_pair_odd_branches.pkl new file mode 100644 index 0000000000..6d5cbc25f3 Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_no_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_no_pair_odd_branches.pkl new file mode 100644 index 0000000000..a16811b5bf Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_no_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_pair_odd_branches.pkl new file mode 100644 index 0000000000..3d12688a43 Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_data_no_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_data_no_pair_odd_branches.pkl new file mode 100644 index 0000000000..b85a1749aa Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_data_no_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_data_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_data_pair_odd_branches.pkl new file mode 100644 index 0000000000..d00659696d Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_data_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_no_pair_odd_branches.csv b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_no_pair_odd_branches.csv new file mode 100644 index 0000000000..c08b2d241c --- /dev/null +++ b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_no_pair_odd_branches.csv @@ -0,0 +1,2 @@ +,image,grain_number,num_crossings,avg_crossing_confidence,min_crossing_confidence +grain_0,test_image,0,5,0.07082753253520613,0.01059564637975774 diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_pair_odd_branches.csv b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_pair_odd_branches.csv new file mode 100644 index 0000000000..c08b2d241c --- /dev/null +++ b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_pair_odd_branches.csv @@ -0,0 +1,2 @@ +,image,grain_number,num_crossings,avg_crossing_confidence,min_crossing_confidence +grain_0,test_image,0,5,0.07082753253520613,0.01059564637975774 diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_data.pkl b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_data.pkl new file mode 100644 index 0000000000..29949ab944 Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_data.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_full_images.pkl b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_full_images.pkl new file mode 100644 index 0000000000..a41553a5e1 Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_full_images.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_grainstats.csv b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_grainstats.csv new file mode 100644 index 0000000000..722a3a0b22 --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,num_mols,writhe_string +grain_0,catenane,0,2,++++ +grain_1,catenane,1,2,+-++ diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_molstats.csv b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_molstats.csv new file mode 100644 index 0000000000..55ecf1c326 --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_molstats.csv @@ -0,0 +1,5 @@ +,image,grain_number,molecule_number,circular,topology,topology_flip,processing +0,catenane,0,0,True,4^2_1,2^2_1,nodestats +1,catenane,0,1,True,4^2_1,2^2_1,nodestats +2,catenane,1,0,True,2^2_1,0_1U0_1,nodestats +3,catenane,1,1,True,2^2_1,0_1U0_1,nodestats diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_data.pkl b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_data.pkl new file mode 100644 index 0000000000..bcfd1062da Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_data.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_full_images.pkl b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_full_images.pkl new file mode 100644 index 0000000000..4e4b9c550b Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_full_images.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_grainstats.csv b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_grainstats.csv new file mode 100644 index 0000000000..d350d7923e --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_grainstats.csv @@ -0,0 +1,2 @@ +,image,grain_number,num_mols,writhe_string +grain_0,replication_intermediate,0,3,--- diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_molstats.csv b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_molstats.csv new file mode 100644 index 0000000000..734d5cc3eb --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_molstats.csv @@ -0,0 +1,4 @@ +,image,grain_number,molecule_number,circular,topology,topology_flip,processing +0,replication_intermediate,0,0,False,linear,linear,nodestats +1,replication_intermediate,0,1,False,linear,linear,nodestats +2,replication_intermediate,0,2,False,linear,linear,nodestats diff --git a/tests/resources/tracing/splining/catenanes_splining_data.pkl b/tests/resources/tracing/splining/catenanes_splining_data.pkl new file mode 100644 index 0000000000..fce50d9faf Binary files /dev/null and b/tests/resources/tracing/splining/catenanes_splining_data.pkl differ diff --git a/tests/resources/tracing/splining/catenanes_splining_grainstats.csv b/tests/resources/tracing/splining/catenanes_splining_grainstats.csv new file mode 100644 index 0000000000..1981a8a281 --- /dev/null +++ b/tests/resources/tracing/splining/catenanes_splining_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,total_contour_length,average_end_to_end_distance +grain_0,catenane,0,1.1133149766322024e-06,0.0 +grain_1,catenane,1,1.1134528311181875e-06,0.0 diff --git a/tests/resources/tracing/splining/catenanes_splining_molstats.csv b/tests/resources/tracing/splining/catenanes_splining_molstats.csv new file mode 100644 index 0000000000..705a1cf40a --- /dev/null +++ b/tests/resources/tracing/splining/catenanes_splining_molstats.csv @@ -0,0 +1,5 @@ +,image,grain_number,molecule_number,contour_length,end_to_end_distance +0,catenane,0,0,8.466004494307275e-07,0.0 +1,catenane,0,1,2.667145272014749e-07,0.0 +2,catenane,1,0,8.465829590713624e-07,0.0 +3,catenane,1,1,2.66869872046825e-07,0.0 diff --git a/tests/resources/tracing/splining/rep_int_splining_data.pkl b/tests/resources/tracing/splining/rep_int_splining_data.pkl new file mode 100644 index 0000000000..69e845be22 Binary files /dev/null and b/tests/resources/tracing/splining/rep_int_splining_data.pkl differ diff --git a/tests/resources/tracing/splining/rep_int_splining_grainstats.csv b/tests/resources/tracing/splining/rep_int_splining_grainstats.csv new file mode 100644 index 0000000000..1577bc87ab --- /dev/null +++ b/tests/resources/tracing/splining/rep_int_splining_grainstats.csv @@ -0,0 +1,2 @@ +,image,grain_number,total_contour_length,average_end_to_end_distance +grain_0,replication_intermediate,0,1.7737493902268596e-06,1.6597964434071086e-07 diff --git a/tests/resources/tracing/splining/rep_int_splining_molstats.csv b/tests/resources/tracing/splining/rep_int_splining_molstats.csv new file mode 100644 index 0000000000..286b4241ca --- /dev/null +++ b/tests/resources/tracing/splining/rep_int_splining_molstats.csv @@ -0,0 +1,4 @@ +,image,grain_number,molecule_number,contour_length,end_to_end_distance +0,replication_intermediate,0,0,7.482001882005334e-07,1.6788686666919485e-07 +1,replication_intermediate,0,1,7.667436650901011e-07,1.670029939851379e-07 +2,replication_intermediate,0,2,2.5880553693622497e-07,1.6304907236779975e-07 diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 890679d466..f4bcfe40d4 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -44,10 +44,6 @@ def test_entry_point_help(capsys, option) -> None: ("grains", "--help"), ("grainstats", "-h"), ("grainstats", "--help"), - ("dnatracing", "-h"), - ("dnatracing", "--help"), - ("tracingstats", "-h"), - ("tracingstats", "--help"), ("create-config", "-h"), ("create-config", "--help"), ], diff --git a/tests/test_grains.py b/tests/test_grains.py index 6a43107696..4bfdb00335 100644 --- a/tests/test_grains.py +++ b/tests/test_grains.py @@ -118,6 +118,64 @@ def test_remove_small_objects(): np.testing.assert_array_equal(result, expected) +@pytest.mark.parametrize( + ("binary_image", "minimum_size_px", "minimum_bbox_size_px", "expected_image"), + [ + pytest.param( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + 8, + 4, + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + ) + ], +) +def test_remove_objects_too_small_to_process( + binary_image: npt.NDArray, minimum_size_px: int, minimum_bbox_size_px: int, expected_image: npt.NDArray +) -> None: + """Test the remove_objects_too_small_to_process method of the Grains class.""" + grains_object = Grains( + image=np.array([[0, 0], [0, 0]]), + filename="", + pixel_to_nm_scaling=1.0, + ) + + result = grains_object.remove_objects_too_small_to_process( + image=binary_image, minimum_size_px=minimum_size_px, minimum_bbox_size_px=minimum_bbox_size_px + ) + + np.testing.assert_array_equal(result, expected_image) + + @pytest.mark.parametrize( ("test_labelled_image", "area_thresholds", "expected"), [ @@ -383,6 +441,10 @@ def test_find_grains( remove_edge_intersecting_grains=remove_edge_intersecting_grains, ) + # Override grains' minimum grain size just for this test to allow for small grains in the test image + grains_object.minimum_grain_size_px = 1 + grains_object.minimum_bbox_size_px = 1 + grains_object.find_grains() result_removed_small_objects = grains_object.directions[direction]["removed_small_objects"] @@ -543,6 +605,10 @@ def test_find_grains_unet( remove_edge_intersecting_grains=True, ) + # Override grains' minimum grain size just for this test to allow for small grains in the test image + grains_object.minimum_grain_size_px = 1 + grains_object.minimum_bbox_size_px = 1 + grains_object.find_grains() result_removed_small_objects = grains_object.directions["above"]["removed_small_objects"] diff --git a/tests/test_io.py b/tests/test_io.py index 0384deeee0..a01e94fe4b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -16,6 +16,7 @@ from topostats.io import ( LoadScans, convert_basename_to_relative_paths, + dict_almost_equal, dict_to_hdf5, dict_to_json, find_files, @@ -61,48 +62,6 @@ # pylint: disable=too-many-lines -def dict_almost_equal(dict1, dict2, abs_tol=1e-9): - """Recursively check if two dictionaries are almost equal with a given absolute tolerance. - - Parameters - ---------- - dict1: dict - First dictionary to compare. - dict2: dict - Second dictionary to compare. - abs_tol: float - Absolute tolerance to check for equality. - - Returns - ------- - bool - True if the dictionaries are almost equal, False otherwise. - """ - if dict1.keys() != dict2.keys(): - return False - - LOGGER.info("Comparing dictionaries") - - for key in dict1: - LOGGER.info(f"Comparing key {key}") - if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): - if not dict_almost_equal(dict1[key], dict2[key], abs_tol=abs_tol): - return False - elif isinstance(dict1[key], np.ndarray) and isinstance(dict2[key], np.ndarray): - if not np.allclose(dict1[key], dict2[key], atol=abs_tol): - LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") - return False - elif isinstance(dict1[key], float) and isinstance(dict2[key], float): - if not np.isclose(dict1[key], dict2[key], atol=abs_tol): - LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") - return False - elif dict1[key] != dict2[key]: - LOGGER.info(f"Key {key} not equal: {dict1[key]} != {dict2[key]}") - return False - - return True - - def test_get_date_time() -> None: """Test the fetching of a formatted date and time string.""" assert datetime.strptime(get_date_time(), "%Y-%m-%d %H:%M:%S") @@ -283,6 +242,13 @@ def test_load_array() -> None: False, id="float not equal", ), + pytest.param( + {"a": np.nan}, + {"a": np.nan}, + 0.0001, + True, + id="nan equal", + ), ], ) def test_dict_almost_equal(dict1: dict, dict2: dict, tolerance: float, expected: bool) -> None: diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 2e12dd16cf..3a96c4c40d 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -27,7 +27,7 @@ def test_melt_data(): df_to_melt = { "Image": ["im1", "im1", "im1", "im2", "im2", "im3", "im3"], "threshold": ["above", "above", "above", "below", "below", "above", "above"], - "molecule_number": [0, 1, 2, 0, 1, 0, 1], + "grain_number": [0, 1, 2, 0, 1, 0, 1], "basename": ["super/sub1", "super/sub1", "super/sub1", "super/sub1", "super/sub1", "super/sub2", "super/sub2"], "area": [10, 20, 30, 40, 50, 60, 70], } @@ -37,7 +37,7 @@ def test_melt_data(): melted_data = TopoSum.melt_data(df=df_to_melt, stat_to_summarize="area", var_to_label={"area": "AREA"}) expected = { - "molecule_number": [0, 1, 2, 0, 1, 0, 1], + "grain_number": [0, 1, 2, 0, 1, 0, 1], "basename": ["super/sub1", "super/sub1", "super/sub1", "super/sub1", "super/sub1", "super/sub2", "super/sub2"], "variable": ["AREA", "AREA", "AREA", "AREA", "AREA", "AREA", "AREA"], "value": [10, 20, 30, 40, 50, 60, 70], diff --git a/tests/test_plottingfuncs.py b/tests/test_plottingfuncs.py index 24fd1ad605..b8e3d5d82b 100644 --- a/tests/test_plottingfuncs.py +++ b/tests/test_plottingfuncs.py @@ -71,7 +71,6 @@ def test_load_mplstyle(style: str, axes_titlesize: int, font_size: float, image_ def test_dilate_binary_image(binary_image: np.ndarray, dilation_iterations: int, expected: np.ndarray) -> None: """Test the dilate binary images function of plottingfuncs.py.""" result = dilate_binary_image(binary_image=binary_image, dilation_iterations=dilation_iterations) - np.testing.assert_array_equal(result, expected) @@ -253,7 +252,7 @@ def test_plot_and_save_non_square_bounding_box( @pytest.mark.mpl_image_compare(baseline_dir="resources/img/") def test_mask_cmap(plotting_config: dict, tmp_path: Path) -> None: """Test the plotting of a mask with a different colourmap (blu).""" - plotting_config["mask_cmap"] = "blu" + plotting_config["mask_cmap"] = "blue" fig, _ = Images( data=ARRAY, output_dir=tmp_path, diff --git a/tests/test_processing.py b/tests/test_processing.py index 10e0ae8d6d..598bea9791 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -14,7 +14,6 @@ from topostats.processing import ( check_run_steps, process_scan, - run_dnatracing, run_filters, run_grains, run_grainstats, @@ -24,6 +23,8 @@ BASE_DIR = Path.cwd() RESOURCES = BASE_DIR / "tests/resources" +# pylint: disable=too-many-positional-arguments + # Can't see a way of parameterising with pytest-regtest as it writes to a file based on the file/function # so instead we run three regression tests. @@ -33,16 +34,20 @@ def test_process_scan_below(regtest, tmp_path, process_scan_config: dict, load_s process_scan_config["grains"]["threshold_std_dev"]["below"] = 0.8 process_scan_config["grains"]["smallest_grain_size_nm2"] = 10 process_scan_config["grains"]["absolute_area_threshold"]["below"] = [1, 1000000000] - process_scan_config["grains"]["direction"] = "below" + # Make sure the pruning won't remove our only grain + process_scan_config["disordered_tracing"]["pruning_params"]["max_length"] = None img_dic = load_scan_data.img_dict - _, results, img_stats, _ = process_scan( + _, results, _, img_stats, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -61,13 +66,16 @@ def test_process_scan_below_height_profiles(tmp_path, process_scan_config: dict, process_scan_config["grains"]["direction"] = "below" img_dic = load_scan_data.img_dict - _, _, _, height_profiles = process_scan( + _, _, height_profiles, _, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -92,13 +100,16 @@ def test_process_scan_above(regtest, tmp_path, process_scan_config: dict, load_s process_scan_config["grains"]["absolute_area_threshold"]["below"] = [1, 1000000000] img_dic = load_scan_data.img_dict - _, results, img_stats, _ = process_scan( + _, results, _, img_stats, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -115,13 +126,16 @@ def test_process_scan_above_height_profiles(tmp_path, process_scan_config: dict, process_scan_config["grains"]["absolute_area_threshold"]["below"] = [1, 1000000000] img_dic = load_scan_data.img_dict - _, _, _, height_profiles = process_scan( + _, _, height_profiles, _, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -148,13 +162,16 @@ def test_process_scan_both(regtest, tmp_path, process_scan_config: dict, load_sc process_scan_config["grains"]["direction"] = "both" img_dic = load_scan_data.img_dict - _, results, img_stats, _ = process_scan( + _, results, _, img_stats, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -199,13 +216,16 @@ def test_save_cropped_grains( process_scan_config["plotting"]["savefig_dpi"] = 50 img_dic = load_scan_data.img_dict - _, _, _, _ = process_scan( + _, _, _, _, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -244,13 +264,16 @@ def test_save_format(process_scan_config: dict, load_scan_data: LoadScans, tmp_p process_scan_config["plotting"] = update_plotting_config(process_scan_config["plotting"]) img_dic = load_scan_data.img_dict - _, _, _, _ = process_scan( + _, _, _, _, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -263,78 +286,241 @@ def test_save_format(process_scan_config: dict, load_scan_data: LoadScans, tmp_p assert guess.extension == extension +# noqa: PLR0913 +# pylint: disable=too-many-arguments @pytest.mark.parametrize( - ("filter_run", "grains_run", "grainstats_run", "dnatracing_run", "log_msg"), + ( + "filter_run", + "grains_run", + "grainstats_run", + "disordered_tracing_run", + "nodestats_run", + "ordered_tracing_run", + "splining_run", + "log_msg", + ), [ - ( + pytest.param( + True, + True, + True, + True, + False, + True, + True, + "Splining enabled but NodeStats disabled. Tracing will use the 'old' method.", + id="Splining, Ordered Tracing, Disordered Tracing, Grainstats, Grains and Filters no Nodestats", + ), + pytest.param( + True, + True, + False, + True, + False, + True, + True, + "Splining enabled but Grainstats disabled. Please check your configuration file.", + id="Splining, Ordered Tracing, Disordered Tracing, Grains and Filters enabled but no Grainstats or " + + "NodeStats", + ), + pytest.param( + True, + False, False, False, False, True, - "DNA tracing enabled but Grainstats disabled. Please check your configuration file.", + True, + "Splining enabled but Disordered Tracing disabled. Please check your configuration file.", + id="Splining, Ordered Tracing and Filters enabled but no NodeStats, Disordered Tracing or Grainstats", ), - ( + pytest.param( + False, False, + True, + True, False, True, True, - "DNA tracing enabled but Grains disabled. Please check your configuration file.", + "Splining enabled but Grains disabled. Please check your configuration file.", + id="Splining, Ordered Tracing, Disordered Tracing, and Grainstats enabled but no NodeStats, Grains or " + + "Filters", ), - ( + pytest.param( + False, + False, + False, False, True, + False, + False, + "NodeStats enabled but Disordered Tracing disabled. Please check your configuration file.", + id="Nodestats enabled but no Ordered Tracing, Disordered Tracing, Grainstats, Grains or Filters", + ), + pytest.param( + False, + False, + False, True, True, - "DNA tracing enabled but Filters disabled. Please check your configuration file.", + False, + False, + "NodeStats enabled but Grainstats disabled. Please check your configuration file.", + id="Nodestats, Disordered Tracing enabled but no Grainstats, Grains or Filters", ), - ( + pytest.param( False, False, True, + True, + True, + False, + False, + "NodeStats enabled but Grains disabled. Please check your configuration file.", + id="Nodestats, Disordered Tracing, Grainstats enabled but no Grains or Filters", + ), + pytest.param( + False, + True, + True, + True, + True, + False, + False, + "NodeStats enabled but Filters disabled. Please check your configuration file.", + id="Nodestats, Disordered Tracing, Grainstats Grains enabled but no Filters", + ), + pytest.param( + False, + False, + False, + True, + False, + False, + False, + "Disordered Tracing enabled but Grainstats disabled. Please check your configuration file.", + id="Disordered Tracing enabled but no Grainstats, Grains or Filters", + ), + pytest.param( + False, + False, + True, + True, + False, + False, + False, + "Disordered Tracing enabled but Grains disabled. Please check your configuration file.", + id="Disordered tracing and Grainstats enabled but no Grains or Filters", + ), + pytest.param( + False, + True, + True, + True, + False, + False, + False, + "Disordered Tracing enabled but Filters disabled. Please check your configuration file.", + id="Disordered tracing, Grains and Grainstats enabled but no Filters", + ), + pytest.param( + False, + False, + True, + False, + False, + False, False, "Grainstats enabled but Grains disabled. Please check your configuration file.", + id="Grainstats enabled but no Grains or Filters", ), - ( + pytest.param( False, True, True, False, + False, + False, + False, "Grainstats enabled but Filters disabled. Please check your configuration file.", + id="Grains enabled and Grainstats but no Filters", ), - ( + pytest.param( False, True, False, False, + False, + False, + False, "Grains enabled but Filters disabled. Please check your configuration file.", + id="Grains enabled but not Filters", + ), + pytest.param( + True, + False, + False, + False, + False, + False, + False, + "Configuration run options are consistent, processing can proceed.", + id="Consistent configuration upto Filters", + ), + pytest.param( + True, + True, + False, + False, + False, + False, + False, + "Configuration run options are consistent, processing can proceed.", + id="Consistent configuration upto Grains", ), - ( + pytest.param( + True, + True, True, False, False, False, + False, "Configuration run options are consistent, processing can proceed.", + id="Consistent configuration upto Grainstats", ), - ( + pytest.param( + True, True, True, + True, + False, False, False, "Configuration run options are consistent, processing can proceed.", + id="Consistent configuration upto DNA tracing", ), - ( + pytest.param( True, True, True, + True, + True, + False, False, "Configuration run options are consistent, processing can proceed.", + id="Consistent configuration upto Nodestats", ), - ( + pytest.param( + True, + True, + True, True, True, True, True, "Configuration run options are consistent, processing can proceed.", + id="Consistent configuration upto Splining", ), ], ) @@ -342,21 +528,38 @@ def test_check_run_steps( filter_run: bool, grains_run: bool, grainstats_run: bool, - dnatracing_run: bool, + disordered_tracing_run: bool, + nodestats_run: bool, + ordered_tracing_run: bool, + splining_run: bool, log_msg: str, caplog, ) -> None: """Test the logic which checks whether enabled processing options are consistent.""" - check_run_steps(filter_run, grains_run, grainstats_run, dnatracing_run) + check_run_steps( + filter_run, grains_run, grainstats_run, disordered_tracing_run, nodestats_run, ordered_tracing_run, splining_run + ) assert log_msg in caplog.text -# noqa: disable=too-many-arguments # pylint: disable=too-many-arguments @pytest.mark.parametrize( - ("filter_run", "grains_run", "grainstats_run", "dnatracing_run", "log_msg1", "log_msg2"), + ( + "filter_run", + "grains_run", + "grainstats_run", + "disordered_tracing_run", + "nodestats_run", + "ordered_tracing_run", + "splining_run", + "log_msg1", + "log_msg2", + ), [ pytest.param( + False, + False, + False, False, False, False, @@ -370,8 +573,11 @@ def test_check_run_steps( False, False, False, - "Detection of grains disabled, returning empty data frame.", - "minicircle_small.png", + False, + False, + False, + "Detection of grains disabled, GrainStats will not be run.", + "", id="Only filtering enabled", ), pytest.param( @@ -379,8 +585,11 @@ def test_check_run_steps( True, False, False, + False, + False, + False, "Calculation of grainstats disabled, returning empty dataframe and empty height_profiles.", - "minicircle_small_above_masked.png", + "", id="Filtering and Grain enabled", ), pytest.param( @@ -388,19 +597,23 @@ def test_check_run_steps( True, True, False, + False, + False, + False, "Processing grain", - "Calculation of DNA Tracing disabled, returning grainstats data frame.", + "Calculation of Disordered Tracing disabled, returning empty dictionary.", id="Filtering, Grain and GrainStats enabled", ), - pytest.param( - True, - True, - True, - True, - "Traced grain 3 of 3", - "Combining ['above'] grain statistics and dnatracing statistics", - id="Filtering, Grain, GrainStats and DNA Tracing enabled", - ), + # @ns-rse 2024-09-13 : Parameters need updating so test is performed. + # pytest.param( + # True, + # True, + # True, + # True, + # "Traced grain 3 of 3", + # "Combining ['above'] grain statistics and dnatracing statistics", + # id="Filtering, Grain, GrainStats and DNA Tracing enabled", + # ), ], ) def test_process_stages( @@ -410,7 +623,10 @@ def test_process_stages( filter_run: bool, grains_run: bool, grainstats_run: bool, - dnatracing_run: bool, + disordered_tracing_run: bool, + nodestats_run: bool, + ordered_tracing_run: bool, + splining_run: bool, log_msg1: str, log_msg2: str, caplog, @@ -425,14 +641,20 @@ def test_process_stages( process_scan_config["filter"]["run"] = filter_run process_scan_config["grains"]["run"] = grains_run process_scan_config["grainstats"]["run"] = grainstats_run - process_scan_config["dnatracing"]["run"] = dnatracing_run - _, _, _, _ = process_scan( + process_scan_config["disordered_tracing"]["run"] = disordered_tracing_run + process_scan_config["nodestats"]["run"] = nodestats_run + process_scan_config["ordered_tracing"]["run"] = ordered_tracing_run + process_scan_config["splining"]["run"] = splining_run + _, _, _, _, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -446,13 +668,16 @@ def test_process_scan_no_grains(process_scan_config: dict, load_scan_data: LoadS img_dic = load_scan_data.img_dict process_scan_config["grains"]["threshold_std_dev"]["above"] = 1000 process_scan_config["filter"]["remove_scars"]["run"] = False - _, _, _, _ = process_scan( + _, _, _, _, _, _ = process_scan( topostats_object=img_dic["minicircle_small"], base_dir=BASE_DIR, filter_config=process_scan_config["filter"], grains_config=process_scan_config["grains"], grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], + disordered_tracing_config=process_scan_config["disordered_tracing"], + nodestats_config=process_scan_config["nodestats"], + ordered_tracing_config=process_scan_config["ordered_tracing"], + splining_config=process_scan_config["splining"], plotting_config=process_scan_config["plotting"], output_dir=tmp_path, ) @@ -460,38 +685,6 @@ def test_process_scan_no_grains(process_scan_config: dict, load_scan_data: LoadS assert "No grains exist for the above direction. Skipping grainstats for above." in caplog.text -def test_process_scan_align_grainstats_dnatracing( - process_scan_config: dict, load_scan_data: LoadScans, tmp_path: Path -) -> None: - """Ensure molecule numbers from dnatracing align with those from grainstats. - - Sometimes grains are removed from tracing due to small size, however we need to ensure that tracing statistics for - those molecules that remain align with grain statistics. - - By setting processing parameters as below two molecules are purged for being too small after skeletonisation and so - do not have DNA tracing statistics (but they do have Grain Statistics). - """ - img_dic = load_scan_data.img_dict - process_scan_config["filter"]["remove_scars"]["run"] = False - process_scan_config["grains"]["absolute_area_threshold"]["above"] = [150, 3000] - process_scan_config["dnatracing"]["min_skeleton_size"] = 50 - _, results, _, _ = process_scan( - topostats_object=img_dic["minicircle_small"], - base_dir=BASE_DIR, - filter_config=process_scan_config["filter"], - grains_config=process_scan_config["grains"], - grainstats_config=process_scan_config["grainstats"], - dnatracing_config=process_scan_config["dnatracing"], - plotting_config=process_scan_config["plotting"], - output_dir=tmp_path, - ) - tracing_to_check = ["contour_length", "circular", "end_to_end_distance"] - - assert results.shape == (3, 25) - assert np.isnan(results.loc[2, "contour_length"]) - assert np.isnan(sum(results.loc[2, tracing_to_check])) - - def test_run_filters(process_scan_config: dict, load_scan_data: LoadScans, tmp_path: Path) -> None: """Test the filter_wrapper function of processing.py.""" img_dict = load_scan_data.img_dict @@ -590,6 +783,7 @@ def test_run_grainstats(process_scan_config: dict, tmp_path: Path) -> None: pixel_to_nm_scaling=0.4940029296875, grain_masks=grain_masks, filename="dummy filename", + basename=RESOURCES, grainstats_config=process_scan_config["grainstats"], plotting_config=process_scan_config["plotting"], grain_out_path=tmp_path, @@ -597,40 +791,33 @@ def test_run_grainstats(process_scan_config: dict, tmp_path: Path) -> None: assert isinstance(grainstats_df, pd.DataFrame) assert grainstats_df.shape[0] == 13 - assert len(grainstats_df.columns) == 21 - - -def test_run_dnatracing(process_scan_config: dict, tmp_path: Path) -> None: - """Test the dnatracing_wrapper function of processing.py.""" - # Load flattened image - flattened_image = np.load("./tests/resources/minicircle_cropped_flattened.npy") - # Load background and foreground of above and below directions for the mask tensors - mask_above_dna = np.load("./tests/resources/minicircle_cropped_masks_above.npy") - mask_above_background = np.load("./tests/resources/minicircle_cropped_masks_above_background.npy") - mask_below_dna = np.load("./tests/resources/minicircle_cropped_masks_below.npy") - mask_below_background = np.load("./tests/resources/minicircle_cropped_masks_below_background.npy") - - # Construct the full image tensor (background class and class 1 (dna)) - tensor_above = np.stack([mask_above_background, mask_above_dna], axis=-1) - tensor_below = np.stack([mask_below_background, mask_below_dna], axis=-1) - - grain_masks = {"above": tensor_above, "below": tensor_below} - - dnatracing_df, grain_trace_data = run_dnatracing( - image=flattened_image, - grain_masks=grain_masks, - pixel_to_nm_scaling=0.4940029296875, - image_path=tmp_path, - filename="dummy filename", - core_out_path=tmp_path, - grain_out_path=tmp_path, - dnatracing_config=process_scan_config["dnatracing"], - plotting_config=process_scan_config["plotting"], - results_df=pd.read_csv("./tests/resources/minicircle_cropped_grainstats.csv"), - ) - - assert isinstance(grain_trace_data, dict) - assert list(grain_trace_data.keys()) == ["above", "below"] - assert isinstance(dnatracing_df, pd.DataFrame) - assert dnatracing_df.shape[0] == 13 - assert len(dnatracing_df.columns) == 26 + assert len(grainstats_df.columns) == 22 + + +# ns-rse 2024-09-11 : Test disabled as run_dnatracing() has been removed in refactoring, needs updating/replacing to +# reflect the revised workflow/functions. +# def test_run_dnatracing(process_scan_config: dict, tmp_path: Path) -> None: +# """Test the dnatracing_wrapper function of processing.py.""" +# flattened_image = np.load("./tests/resources/minicircle_cropped_flattened.npy") +# mask_above = np.load("./tests/resources/minicircle_cropped_masks_above.npy") +# mask_below = np.load("./tests/resources/minicircle_cropped_masks_below.npy") +# grain_masks = {"above": mask_above, "below": mask_below} + +# dnatracing_df, grain_trace_data = run_dnatracing( +# image=flattened_image, +# grain_masks=grain_masks, +# pixel_to_nm_scaling=0.4940029296875, +# image_path=tmp_path, +# filename="dummy filename", +# core_out_path=tmp_path, +# grain_out_path=tmp_path, +# dnatracing_config=process_scan_config["dnatracing"], +# plotting_config=process_scan_config["plotting"], +# results_df=pd.read_csv("./tests/resources/minicircle_cropped_grainstats.csv"), +# ) + +# assert isinstance(grain_trace_data, dict) +# assert list(grain_trace_data.keys()) == ["above", "below"] +# assert isinstance(dnatracing_df, pd.DataFrame) +# assert dnatracing_df.shape[0] == 13 +# assert len(dnatracing_df.columns) == 26 diff --git a/tests/test_utils.py b/tests/test_utils.py index f373fa3ea3..a49ec74bf4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,6 +10,7 @@ ALL_STATISTICS_COLUMNS, bound_padded_coordinates_to_image, convert_path, + convolve_skeleton, create_empty_dataframe, get_thresholds, update_config, @@ -83,6 +84,7 @@ def test_update_plotting_config( "extracted_channel": { "filename": "00-raw_heightmap", "image_type": "non-binary", + "core_set": False, "savefig_dpi": 100, } }, @@ -96,6 +98,7 @@ def test_update_plotting_config( "filename": "00-raw_heightmap", "image_type": "non-binary", "savefig_dpi": 100, + "core_set": False, "pixel_interpolation": None, } }, @@ -112,6 +115,7 @@ def test_update_plotting_config( "filename": "00-raw_heightmap", "image_type": "non-binary", "savefig_dpi": 100, + "core_set": False, } }, }, @@ -124,6 +128,7 @@ def test_update_plotting_config( "filename": "00-raw_heightmap", "image_type": "non-binary", "savefig_dpi": 600, + "core_set": False, "pixel_interpolation": None, } }, @@ -181,8 +186,8 @@ def test_create_empty_dataframe() -> None: """Test the empty dataframe is created correctly.""" empty_df = create_empty_dataframe(ALL_STATISTICS_COLUMNS) - assert empty_df.index.name == "molecule_number" - assert "molecule_number" not in empty_df.columns + assert empty_df.index.name == "grain_number" + assert "grain_number" not in empty_df.columns assert empty_df.shape == (0, 26) assert {"image", "basename", "area"}.intersection(empty_df.columns) @@ -324,3 +329,353 @@ def test_bound_padded_coordinates_to_image(image: npt.NDArray, padding: int, exp for coordinate in coordinates: padded_coords = bound_padded_coordinates_to_image(coordinate, padding, image.shape) assert padded_coords == expected + + +@pytest.mark.parametrize( + ("skeleton", "target"), + [ + pytest.param( + np.asarray( + [ + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + ], + ), + id="Simple skeleton, no points to convolute", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0], + [0, 1, 1], + [0, 1, 1], + ] + ), + np.asarray( + [ + [0, 0, 0], + [0, 3, 3], + [0, 3, 3], + ], + ), + id="Corner cluster", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 3, 0, 0], + [0, 3, 3, 3, 0], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + ], + ), + id="Small Cross", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 3, 0, 0, 0], + [0, 2, 3, 3, 3, 2, 0], + [0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ], + ), + id="Bigger Cross", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 3, 0, 0, 0], + [0, 2, 3, 3, 3, 2, 0], + [0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ], + ), + id="Large Cross", + ), + ], +) +def test_convolve_skeleton(skeleton: npt.NDArray, target: npt.NDArray) -> None: + """Test convolve_skeleton() function.""" + skeleton_convolved = convolve_skeleton(skeleton) + print(f"{skeleton_convolved=}") + np.testing.assert_array_equal(skeleton_convolved, target) + + +@pytest.mark.parametrize( + ("skeleton", "target"), + [ + pytest.param( + "utils_skeleton_linear1", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Skeleton 1 - Linear Skeleton", + ), + pytest.param( + "utils_skeleton_linear2", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 3, 3, 3, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Skeleton 2 - T-junction and side-branch", + ), + pytest.param( + "utils_skeleton_linear3", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Skeleton 3 - Linear Skeleton with several branches", + ), + ], +) +def test_convolve_skeleton_random(skeleton: npt.NDArray, target: npt.NDArray, request) -> None: + """Test convolve_skeleton() function with random skeletons from conftest.py.""" + skeleton_convolved = convolve_skeleton(request.getfixturevalue(skeleton)) + np.testing.assert_array_equal(skeleton_convolved, target) + + +@pytest.mark.skip(reason="awaiting test development") +# @pytest.mark.parametrize() +def test_coords2_img() -> None: + """Test coords2_img() function.""" diff --git a/tests/tracing/conftest.py b/tests/tracing/conftest.py new file mode 100644 index 0000000000..aaf8668469 --- /dev/null +++ b/tests/tracing/conftest.py @@ -0,0 +1,190 @@ +"""Fixtures for the tracing tests.""" + +from pathlib import Path + +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest + +from topostats.tracing.nodestats import nodeStats +from topostats.tracing.skeletonize import getSkeleton, topostatsSkeletonize + +# This is required because of the inheritance used throughout +# pylint: disable=redefined-outer-name + +BASE_DIR = Path.cwd() +RESOURCES = BASE_DIR / "tests" / "resources" + +RNG = np.random.default_rng(seed=1000) + +# Derive fixtures for DNA Tracing +GRAINS = np.array( + [ + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 2], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 2], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2], + [0, 0, 3, 3, 3, 3, 3, 0, 0, 0, 2], + [0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 2], + [0, 0, 3, 3, 3, 3, 3, 0, 0, 0, 2], + [0, 0, 4, 4, 4, 4, 4, 0, 0, 0, 2], + ] +) +FULL_IMAGE = RNG.random((GRAINS.shape[0], GRAINS.shape[1])) + + +# DNA Tracing Fixtures +@pytest.fixture() +def minicircle_all_statistics() -> pd.DataFrame: + """Statistics for minicricle.""" + return pd.read_csv(RESOURCES / "minicircle_default_all_statistics.csv", header=0) + + +# Skeletonizing Fixtures +@pytest.fixture() +def skeletonize_get_skeleton() -> getSkeleton: + """Instantiate a getSkeleton object.""" + return getSkeleton(image=None, mask=None) + + +@pytest.fixture() +def skeletonize_circular() -> np.ndarray: + """Circular molecule for testing skeletonizing.""" + return np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 2, 2, 2, 2, 2, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 1, 1, 1, 1, 1, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 1, 0, 0, 0, 1, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 1, 0, 0, 0, 1, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 1, 0, 0, 0, 1, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 1, 1, 1, 1, 1, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 2, 2, 2, 2, 2, 2, 2, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0], + [0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + + +@pytest.fixture() +def skeletonize_circular_bool_int(skeletonize_circular: np.ndarray) -> np.ndarray: + """Circular molecule for testing skeletonizing as a boolean integer array.""" + return np.array(skeletonize_circular, dtype="bool").astype(int) + + +@pytest.fixture() +def skeletonize_linear() -> np.ndarray: + """Linear molecule for testing skeletonizing.""" + return np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 2, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 4, 3, 2, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3, 3, 4, 4, 3, 2, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 4, 4, 3, 3, 2, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 3, 3, 2, 2, 1, 0, 0], + [0, 0, 0, 0, 1, 2, 2, 2, 3, 3, 3, 4, 4, 3, 2, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 2, 3, 3, 3, 4, 4, 4, 3, 3, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 3, 3, 3, 2, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 3, 4, 4, 3, 3, 2, 2, 2, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 3, 4, 3, 3, 2, 2, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 3, 4, 3, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 2, 3, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 2, 2, 3, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 3, 3, 3, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 3, 4, 4, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 3, 3, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + + +@pytest.fixture() +def skeletonize_linear_bool_int(skeletonize_linear) -> np.ndarray: + """Linear molecule for testing skeletonizing as a boolean integer array.""" + return np.array(skeletonize_linear, dtype="bool").astype(int) + + +@pytest.fixture() +def topostats_skeletonise(skeletonize_circular, skeletonize_circular_bool_int): + """TopostatsSkeletonise for testing individual functions.""" + return topostatsSkeletonize(skeletonize_circular, skeletonize_circular_bool_int, 0.6) + + +@pytest.fixture() +def catenane_image() -> npt.NDArray[np.number]: + """Image of a catenane molecule.""" + return np.load(RESOURCES / "catenane_image.npy") + + +@pytest.fixture() +def catenane_node_centre_mask() -> npt.NDArray[np.int32]: + """ + Catenane node centre mask. + + Effectively just the skeleton, but + with the nodes set to 2 while the skeleton is 1 and background is 0. + """ + return np.load(RESOURCES / "catenane_node_centre_mask.npy") + + +@pytest.fixture() +def catenane_connected_nodes() -> npt.NDArray[np.int32]: + """ + Return connected nodes of the catenane test image. + + Effectively just the skeleton, but with the extended nodes + set to 2 while the skeleton is 1 and background is 0. + """ + return np.load(RESOURCES / "catenane_connected_nodes.npy") + + +@pytest.fixture() +def nodestats_catenane( + catenane_image: npt.NDArray[np.number], +) -> nodeStats: + """Fixture for the nodeStats object for a catenated molecule, to be used in analyse_nodes.""" + catenane_smoothed_mask: npt.NDArray[np.bool_] = np.load(RESOURCES / "catenane_smoothed_mask.npy") + catenane_skeleton: npt.NDArray[np.bool_] = np.load(RESOURCES / "catenane_skeleton.npy") + catenane_node_centre_mask = np.load(RESOURCES / "catenane_node_centre_mask.npy") + catenane_connected_nodes = np.load(RESOURCES / "catenane_connected_nodes.npy") + + # Create a nodestats object + nodestats = nodeStats( + filename="test_catenane", + image=catenane_image, + mask=catenane_smoothed_mask, + smoothed_mask=catenane_smoothed_mask, + skeleton=catenane_skeleton, + pixel_to_nm_scaling=np.float64(0.18124609375), + n_grain=1, + node_joining_length=7, + node_extend_dist=14.0, + branch_pairing_length=20.0, + pair_odd_branches=True, + ) + + nodestats.node_centre_mask = catenane_node_centre_mask + nodestats.connected_nodes = catenane_connected_nodes + nodestats.skeleton = catenane_skeleton + + return nodestats diff --git a/tests/tracing/test_disordered_tracing.py b/tests/tracing/test_disordered_tracing.py new file mode 100644 index 0000000000..dd3cb30a6f --- /dev/null +++ b/tests/tracing/test_disordered_tracing.py @@ -0,0 +1,1446 @@ +# Disable ruff 301 - pickle loading is unsafe but we don't care for tests +# ruff: noqa: S301 +"""Test the disordered tracing module.""" + +import pickle as pkl +from pathlib import Path + +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest + +from topostats.io import dict_almost_equal # pylint: disable=no-name-in-module import-error +from topostats.tracing.disordered_tracing import crop_array, disordered_trace_grain, trace_image_disordered + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +# pylint: disable=too-many-lines +# pylint: disable=unspecified-encoding +# pylint: disable=too-many-positional-arguments + +BASE_DIR = Path.cwd() +DISORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "disordered_tracing" +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" + +TEST_LABELLED = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 0, 0, 2, 2, 2, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0], + [0, 3, 3, 3, 3, 3, 3, 0, 0, 2, 0, 0, 0, 2, 0], + [0, 0, 0, 0, 0, 0, 3, 0, 0, 2, 0, 0, 0, 2, 0], + [0, 0, 0, 0, 0, 0, 3, 0, 0, 2, 2, 2, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 3, 0, 4, 4, 4, 4, 4, 4, 0], + [0, 0, 0, 0, 0, 0, 3, 0, 4, 4, 4, 4, 4, 4, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 5, 5, 5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], + [0, 5, 5, 0, 0, 6, 0, 0, 6, 0, 0, 6, 0, 0, 0], + [0, 5, 5, 5, 5, 6, 0, 0, 6, 6, 6, 6, 6, 6, 0], + [0, 0, 0, 5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], + [0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) + + +@pytest.mark.parametrize( + ("bounding_box", "target", "pad_width"), + [ + pytest.param( + (1, 1, 2, 7), + np.asarray( + [ + [1, 1, 1, 1, 1, 1], + ] + ), + 0, + id="Zero padding", + ), + pytest.param( + (1, 1, 2, 7), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + 1, + id="Single pixel padding", + ), + pytest.param( + (1, 1, 2, 7), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 3, 3, 3, 3, 3, 3, 0, 0], + ] + ), + 2, + id="Two pixel padding", + ), + pytest.param( + (1, 9, 6, 14), + np.asarray( + [ + [2, 2, 2, 2, 2], + [2, 0, 0, 0, 2], + [2, 0, 0, 0, 2], + [2, 0, 0, 0, 2], + [2, 2, 2, 2, 2], + ] + ), + 0, + id="Ring with zero padding", + ), + pytest.param( + (3, 1, 9, 7), + np.asarray( + [ + [3, 3, 3, 3, 3, 3], + [0, 0, 0, 0, 0, 3], + [0, 0, 0, 0, 0, 3], + [0, 0, 0, 0, 0, 3], + [0, 0, 0, 0, 0, 3], + [0, 0, 0, 0, 0, 3], + ] + ), + 0, + id="L with zero padding", + ), + pytest.param( + (7, 8, 9, 14), + np.asarray( + [ + [4, 4, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 4], + ] + ), + 0, + id="solid with zero padding", + ), + pytest.param( + (10, 1, 15, 5), + np.asarray( + [ + [5, 5, 5, 5], + [5, 5, 0, 0], + [5, 5, 5, 5], + [0, 0, 5, 5], + [5, 5, 5, 5], + ] + ), + 0, + id="small area with zero padding", + ), + pytest.param( + (10, 5, 14, 14), + np.asarray( + [ + [6, 6, 6, 6, 0, 0, 6, 0, 0], + [6, 0, 0, 6, 0, 0, 6, 0, 0], + [6, 0, 0, 6, 6, 6, 6, 6, 6], + [6, 6, 6, 6, 0, 0, 6, 0, 0], + ] + ), + 0, + id="larger area with zero pixel padding", + ), + pytest.param( + (10, 5, 14, 14), + np.asarray( + [ + [0, 0, 0, 3, 0, 4, 4, 4, 4, 4, 4, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], + [0, 0, 6, 0, 0, 6, 0, 0, 6, 0, 0, 0], + [5, 5, 6, 0, 0, 6, 6, 6, 6, 6, 6, 0], + [5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], + [5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + 2, + id="larger area with two pixel padding", + ), + ], +) +def test_crop_array(bounding_box: tuple, target: np.array, pad_width: int) -> None: + """Test the cropping of images.""" + cropped = crop_array(TEST_LABELLED, bounding_box, pad_width) + np.testing.assert_array_equal(cropped, target) + + +@pytest.mark.parametrize( + ( + "cropped_image", + "cropped_mask", + "pixel_to_nm_scaling", + "mask_smoothing_params", + "skeletonisation_params", + "pruning_params", + "filename", + "min_skeleton_size", + "expected_smoothed_grain", + "expected_skeleton", + "expected_pruned_skeleton", + "expected_branch_types", + ), + [ + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 1, 2, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": None, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="simple slightly curved line", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.2, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": None, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.float32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + id="test height bias: thick curve height weighting outer, strong height bias", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.8, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": None, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.float32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + id="test height bias: thick curve height weighting outer, weak height bias", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 2, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + id="test pruning: thick curve with tail, no pruning", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 10, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + id="test pruning: thick curve with tail, prune small branch", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [3, 5], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": None, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 0, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 3, 3, 3, 3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + id="test re-add holes, 3 holes, one right size for re-adding", + ), + ], +) +def test_disordered_trace_grain( + cropped_image: npt.NDArray[np.float32], + cropped_mask: npt.NDArray[np.bool_], + pixel_to_nm_scaling: float, + mask_smoothing_params: dict, + skeletonisation_params: dict, + pruning_params: dict, + filename: str, + min_skeleton_size: int, + expected_smoothed_grain: npt.NDArray[np.bool_], + expected_skeleton: npt.NDArray[np.bool_], + expected_pruned_skeleton: npt.NDArray[np.bool_], + expected_branch_types: npt.NDArray[np.int32], +) -> None: + """Test the disorderedTrace() method.""" + result_dict = disordered_trace_grain( + cropped_image=cropped_image, + cropped_mask=cropped_mask, + pixel_to_nm_scaling=pixel_to_nm_scaling, + mask_smoothing_params=mask_smoothing_params, + skeletonisation_params=skeletonisation_params, + pruning_params=pruning_params, + filename=filename, + min_skeleton_size=min_skeleton_size, + ) + + result_smoothed_grain = result_dict["smoothed_grain"] + result_skeleton = result_dict["skeleton"] + result_pruned_skeleton = result_dict["pruned_skeleton"] + result_branch_types = result_dict["branch_types"] + + np.testing.assert_array_equal(result_smoothed_grain, expected_smoothed_grain) + np.testing.assert_array_equal(result_skeleton, expected_skeleton) + np.testing.assert_array_equal(result_pruned_skeleton, expected_pruned_skeleton) + np.testing.assert_array_equal(result_branch_types, expected_branch_types) + + +@pytest.mark.parametrize( + ( + "image_filename", + "mask_filename", + "pixel_to_nm_scaling", + "min_skeleton_size", + "mask_smoothing_params", + "skeletonisation_params", + "pruning_params", + "expected_disordered_crop_data_filename", + "expected_disordered_tracing_grainstats_filename", + "expected_all_images_filename", + "expected_disordered_tracing_stats_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + "example_catenanes_labelled_grain_mask_thresholded.npy", + # Pixel to nm scaling + 0.488, + # Min skeleton size + 10, + # Mask smoothing parameters + { + "gaussian_sigma": 2, + "dilation_iterations": 2, + "holearea_min_max": [10, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 7.0, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + "catenanes_disordered_tracing_crop_data.pkl", + "catenanes_disordered_tracing_grainstats.csv", + "catenanes_disordered_tracing_all_images.pkl", + "catenanes_disordered_tracing_stats.csv", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + "example_rep_int_labelled_grain_mask_thresholded.npy", + # Pixel to nm scaling + 0.488, + # Min skeleton size + 10, + # Mask smoothing parameters + { + "gaussian_sigma": 2, + "dilation_iterations": 2, + "holearea_min_max": [10, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 20.0, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + "rep_int_disordered_tracing_crop_data.pkl", + "rep_int_disordered_tracing_grainstats.csv", + "rep_int_disordered_tracing_all_images.pkl", + "rep_int_disordered_tracing_stats.csv", + id="replication intermediate", + ), + ], +) +def test_trace_image_disordered( + image_filename: str, + mask_filename: str, + pixel_to_nm_scaling: float, + min_skeleton_size: int, + mask_smoothing_params: dict, + skeletonisation_params: dict, + pruning_params: dict, + expected_disordered_crop_data_filename: str, + expected_disordered_tracing_grainstats_filename: str, + expected_all_images_filename: str, + expected_disordered_tracing_stats_filename: str, +) -> None: + """Test the trace image disordered method.""" + # Load the image + image = np.load(GENERAL_RESOURCES / image_filename) + mask = np.load(GENERAL_RESOURCES / mask_filename) + + ( + result_disordered_crop_data, + result_disordered_tracing_grainstats, + result_all_images, + result_disordered_tracing_stats, + ) = trace_image_disordered( + image=image, + grains_mask=mask, + filename="test_image", + pixel_to_nm_scaling=pixel_to_nm_scaling, + min_skeleton_size=min_skeleton_size, + mask_smoothing_params=mask_smoothing_params, + skeletonisation_params=skeletonisation_params, + pruning_params=pruning_params, + pad_width=1, + ) + + # DEBUGGING CODE + # Turning sub-structures into variables to be able to be inspected + # variable_smoothed_grain = result_all_images["smoothed_grain"] + # variable_skeleton = result_all_images["skeleton"] + # variable_pruned_skeleton = result_all_images["pruned_skeleton"] + # variable_branch_types = result_all_images["branch_types"] + + # Update expected values - CHECK RESULTS WITH EXPERT BEFORE UPDATING + # Pickle result_disordered_crop_data + # with open(DISORDERED_TRACING_RESOURCES / expected_disordered_crop_data_filename, "wb") as f: + # pkl.dump(result_disordered_crop_data, f) + + # # Save result_disordered_tracing_grainstats as a csv + # result_disordered_tracing_grainstats.to_csv( + # DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_grainstats_filename + # ) + + # # Save result_all_images as a pickle + # with open(DISORDERED_TRACING_RESOURCES / expected_all_images_filename, "wb") as f: + # pkl.dump(result_all_images, f) + + # # Save result_disordered_tracing_stats dataframe as a csv + # result_disordered_tracing_stats.to_csv(DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_stats_filename) + + # Load expected values + with Path.open(DISORDERED_TRACING_RESOURCES / expected_disordered_crop_data_filename, "rb") as f: + expected_disordered_crop_data = pkl.load(f) + + expected_disordered_tracing_grainstats = pd.read_csv( + DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_grainstats_filename, index_col=0 + ) + + with Path.open(DISORDERED_TRACING_RESOURCES / expected_all_images_filename, "rb") as f: + expected_all_images = pkl.load(f) + + expected_disordered_tracing_stats = pd.read_csv( + DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_stats_filename, index_col=0 + ) + + assert dict_almost_equal(result_disordered_crop_data, expected_disordered_crop_data, abs_tol=1e-11) + pd.testing.assert_frame_equal(result_disordered_tracing_grainstats, expected_disordered_tracing_grainstats) + assert dict_almost_equal(result_all_images, expected_all_images, abs_tol=1e-11) + pd.testing.assert_frame_equal(result_disordered_tracing_stats, expected_disordered_tracing_stats) diff --git a/tests/tracing/test_dnatracing_methods.py b/tests/tracing/test_dnatracing_methods.py deleted file mode 100644 index 06f4addfc7..0000000000 --- a/tests/tracing/test_dnatracing_methods.py +++ /dev/null @@ -1,785 +0,0 @@ -"""Additional tests of dnaTracing methods.""" - -from pathlib import Path - -import numpy as np -import pytest - -from topostats.tracing.dnatracing import crop_array, dnaTrace, round_splined_traces -from topostats.tracing.tracingfuncs import reorderTrace - -# This is required because of the inheritance used throughout -# pylint: disable=redefined-outer-name -BASE_DIR = Path.cwd() -RESOURCES = BASE_DIR / "tests" / "resources" -PIXEL_SIZE = 0.4940029296875 - -LINEAR_IMAGE = np.load(RESOURCES / "dnatracing_image_linear.npy") -LINEAR_MASK = np.load(RESOURCES / "dnatracing_mask_linear.npy") -CIRCULAR_IMAGE = np.load(RESOURCES / "dnatracing_image_circular.npy") -CIRCULAR_MASK = np.load(RESOURCES / "dnatracing_mask_circular.npy") -MIN_SKELETON_SIZE = 10 - - -@pytest.fixture() -def dnatrace() -> dnaTrace: - """dnaTrace object for use in tests.""" # noqa: D403 - return dnaTrace( - image=np.asarray([[1]]), - grain=None, - filename="test.spm", - pixel_to_nm_scaling=PIXEL_SIZE, - min_skeleton_size=MIN_SKELETON_SIZE, - ) - - -@pytest.fixture() -def dnatrace_spline(dnatrace: dnaTrace) -> dnaTrace: - """Instantiate a dnaTrace object for splining tests.""" - dnatrace.pixel_to_nm_scaling = 1.0 - dnatrace.n_grain = 1 - return dnatrace - - -GRAINS = {} -GRAINS["vertical"] = np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - ] -) -GRAINS["horizontal"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["diagonal1"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["diagonal2"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["diagonal3"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 0, 0], - [0, 0, 0, 1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["circle"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0], - [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["blob"] = np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 0, 0, 0, 0], - ] -) -GRAINS["cross"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["single_L"] = np.asarray( - [ - [0, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0], - ] -) -GRAINS["double_L"] = np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 1, 1, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 1, 1, 0, 0], - [0, 0, 0, 0, 0], - ] -) -GRAINS["diagonal_end_single_L"] = np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 1, 1, 0, 0], - [0, 0, 0, 0, 0], - ] -) -GRAINS["diagonal_end_straight"] = np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - ] -) -GRAINS["figure8"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 1, 1, 1, 1, 1, 1, 1, 0], - [0, 1, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["three_ends"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) -GRAINS["six_ends"] = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) - - -@pytest.mark.parametrize( - ("grain", "mol_is_circular"), - [ - (GRAINS["vertical"], False), - (GRAINS["horizontal"], False), - (GRAINS["diagonal1"], True), # This is wrong, this IS a linear molecule - (GRAINS["diagonal2"], False), - (GRAINS["diagonal3"], False), - (GRAINS["circle"], True), - (GRAINS["blob"], True), - (GRAINS["cross"], False), - (GRAINS["single_L"], False), - (GRAINS["double_L"], True), # This is wrong, this IS a linear molecule - (GRAINS["diagonal_end_single_L"], False), - (GRAINS["diagonal_end_straight"], False), - (GRAINS["figure8"], True), - (GRAINS["three_ends"], False), - (GRAINS["six_ends"], False), - ], -) -def test_linear_or_circular(dnatrace, grain: np.ndarray, mol_is_circular: bool) -> None: - """Test the linear_or_circular method with a range of different structures.""" - linear_coordinates = np.argwhere(grain == 1) - dnatrace.linear_or_circular(linear_coordinates) - assert dnatrace.mol_is_circular == mol_is_circular - - -TEST_LABELLED = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 1, 0, 0, 2, 2, 2, 2, 2, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 2, 0], - [0, 3, 3, 3, 3, 3, 3, 0, 0, 2, 0, 0, 0, 2, 0], - [0, 0, 0, 0, 0, 0, 3, 0, 0, 2, 0, 0, 0, 2, 0], - [0, 0, 0, 0, 0, 0, 3, 0, 0, 2, 2, 2, 2, 2, 0], - [0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 3, 0, 4, 4, 4, 4, 4, 4, 0], - [0, 0, 0, 0, 0, 0, 3, 0, 4, 4, 4, 4, 4, 4, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 5, 5, 5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], - [0, 5, 5, 0, 0, 6, 0, 0, 6, 0, 0, 6, 0, 0, 0], - [0, 5, 5, 5, 5, 6, 0, 0, 6, 6, 6, 6, 6, 6, 0], - [0, 0, 0, 5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], - [0, 5, 5, 5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) - - -@pytest.mark.parametrize( - ("bounding_box", "target", "pad_width"), - [ - ( - (1, 1, 2, 7), - np.asarray( - [ - [1, 1, 1, 1, 1, 1], - ] - ), - 0, - ), - ( - (1, 1, 2, 7), - np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 1, - ), - ( - (1, 1, 2, 7), - np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 3, 3, 3, 3, 3, 3, 0, 0], - ] - ), - 2, - ), - ( - (1, 9, 6, 14), - np.asarray( - [ - [2, 2, 2, 2, 2], - [2, 0, 0, 0, 2], - [2, 0, 0, 0, 2], - [2, 0, 0, 0, 2], - [2, 2, 2, 2, 2], - ] - ), - 0, - ), - ( - (3, 1, 9, 7), - np.asarray( - [ - [3, 3, 3, 3, 3, 3], - [0, 0, 0, 0, 0, 3], - [0, 0, 0, 0, 0, 3], - [0, 0, 0, 0, 0, 3], - [0, 0, 0, 0, 0, 3], - [0, 0, 0, 0, 0, 3], - ] - ), - 0, - ), - ( - (7, 8, 9, 14), - np.asarray( - [ - [4, 4, 4, 4, 4, 4], - [4, 4, 4, 4, 4, 4], - ] - ), - 0, - ), - ( - (10, 1, 15, 5), - np.asarray( - [ - [5, 5, 5, 5], - [5, 5, 0, 0], - [5, 5, 5, 5], - [0, 0, 5, 5], - [5, 5, 5, 5], - ] - ), - 0, - ), - ( - (10, 5, 14, 14), - np.asarray( - [ - [6, 6, 6, 6, 0, 0, 6, 0, 0], - [6, 0, 0, 6, 0, 0, 6, 0, 0], - [6, 0, 0, 6, 6, 6, 6, 6, 6], - [6, 6, 6, 6, 0, 0, 6, 0, 0], - ] - ), - 0, - ), - ( - (10, 5, 14, 14), - np.asarray( - [ - [0, 0, 0, 3, 0, 4, 4, 4, 4, 4, 4, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], - [0, 0, 6, 0, 0, 6, 0, 0, 6, 0, 0, 0], - [5, 5, 6, 0, 0, 6, 6, 6, 6, 6, 6, 0], - [5, 5, 6, 6, 6, 6, 0, 0, 6, 0, 0, 0], - [5, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 2, - ), - ], -) -def test_crop_array(bounding_box: tuple, target: np.array, pad_width: int) -> None: - """Test the cropping of images.""" - cropped = crop_array(TEST_LABELLED, bounding_box, pad_width) - print(f"cropped :\n{cropped}") - np.testing.assert_array_equal(cropped, target) - - -@pytest.mark.parametrize( - ("trace_image", "step_size_m", "mol_is_circular", "smoothness", "expected_spline_image"), - # Note that some images here have gaps in them. This is NOT an issue with the splining, but - # a natural result of rounding the floating point coordinates to integers. We only do this - # to provide a visualisation for easy test comparison and does not affect results. - [ - # Test circular with no smoothing. The shape remains unchanged. The gap is just a result - # of rounding the floating point coordinates to integers for this neat visualisation - # and can be ignored. - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 1, 1, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 1, 1, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 1.0, - True, - 0.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 1, 0, 1, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 1, 1, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - # Test circular with smoothing. The shape is smoothed out. - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 1, 1, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 1, 1, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 1.0, - True, - 10.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 1, 1, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 1, 1, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - # Another circular with no smoothing. Relatively unchanged - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 1, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 10.0, - True, - 0.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 1, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - # Another circular with smoothing. The shape is smoothed out. - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 1, 0, 1, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 10.0, - True, - 5.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 0], - [0, 0, 0, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - # Simple line with smoothing unchanged because too simple. - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 5.0, - False, - 5.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - # Another line with smoothing unchanged because too simple. - ( - np.array( - [ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - ] - ), - 4.0, - False, - 5.0, - np.array( - [ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - ] - ), - ), - # Complex line without smoothing, unchanged. - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 5.0, - False, - 0.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - # Complex line with smoothing, smoothed out. - ( - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 1, 0, 0], - [0, 0, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - 5.0, - False, - 5.0, - np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 1, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - ], -) -def test_get_splined_traces( - dnatrace_spline: dnaTrace, - trace_image: np.ndarray, - step_size_m: float, - mol_is_circular: bool, - smoothness: float, - expected_spline_image: np.ndarray, -) -> None: - """Test the get_splined_traces() function of dnatracing.py.""" - # For development visualisations - keep this in for future use - # plt.imsave('./fitted_trace.png', trace_image) - - # Obtain the coords of the pixels from our test image - trace_coords = np.argwhere(trace_image == 1) - - # Get an ordered trace from our test trace images - if mol_is_circular: - fitted_trace, _whether_trace_completed = reorderTrace.circularTrace(trace_coords) - else: - fitted_trace = reorderTrace.linearTrace(trace_coords) - - # Set dnatrace smoothness accordingly - if mol_is_circular: - dnatrace_spline.spline_circular_smoothing = smoothness - else: - dnatrace_spline.spline_linear_smoothing = smoothness - - # Generate splined trace - dnatrace_spline.fitted_trace = fitted_trace - dnatrace_spline.step_size_m = step_size_m - # Fixed pixel to nm scaling since changing this is redundant due to the internal effect being linked to - # the step_size_m divided by this value, so changing both doesn't make sense. - dnatrace_spline.mol_is_circular = mol_is_circular - - # Spline the traces - dnatrace_spline.get_splined_traces() - - # Extract the splined trace - splined_trace = dnatrace_spline.splined_trace - - # This is just for easier human-readable tests. Turn the splined coords into a visualisation. - splined_image = np.zeros_like(trace_image) - splined_image[np.round(splined_trace[:, 0]).astype(int), np.round(splined_trace[:, 1]).astype(int)] = 1 - - # For development visualisations - keep this in for future use - # plt.imsave(f'./test_splined_image_{splined_image.shape}.png', splined_image) - - np.testing.assert_array_equal(splined_image, expected_spline_image) - - -def test_round_splined_traces(): - """Test the round splined traces function of dnatracing.py.""" - splined_traces = {"0": np.array([[1.2, 2.3], [3.4, 4.5]]), "1": None, "2": np.array([[5.6, 6.7], [7.8, 8.9]])} - expected_result = {"0": np.array([[1, 2], [3, 4]]), "1": None, "2": np.array([[6, 7], [8, 9]])} - result = round_splined_traces(splined_traces) - np.testing.assert_equal(result, expected_result) - - -@pytest.mark.parametrize( - ("tuple_list", "expected_result"), - [ - ( - [(1, 2, 3), (1, 2, 3), (1, 2, 3), (1, 2, 3)], - [(1, 2, 3)], - ), - ( - [(1, 2, 3), (1, 2, 3), (4, 5, 6), (4, 5, 6), (7, 8, 9), (10, 11, 12), (10, 11, 12)], - [(1, 2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12)], - ), - ([np.array((1, 2, 3)), np.array((1, 2, 3)), np.array((1, 2, 3)), np.array((1, 2, 3))], [(1, 2, 3)]), - ], -) -def test_remove_duplicate_consecutive_tuples(tuple_list: list[tuple], expected_result: list[tuple]) -> None: - """Test the remove_duplicate_consecutive_tuples function of dnatracing.py.""" - result = dnaTrace.remove_duplicate_consecutive_tuples(tuple_list) - - np.testing.assert_array_equal(result, expected_result) diff --git a/tests/tracing/test_dnatracing_multigrain.py b/tests/tracing/test_dnatracing_multigrain.py deleted file mode 100644 index 50436bb72f..0000000000 --- a/tests/tracing/test_dnatracing_multigrain.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Tests for tracing images with multiple (2) grains.""" - -from pathlib import Path - -import numpy as np -import pandas as pd -import pytest - -from topostats.tracing.dnatracing import prep_arrays, trace_image, trace_mask - -# This is required because of the inheritance used throughout -# pylint: disable=redefined-outer-name -BASE_DIR = Path.cwd() -RESOURCES = BASE_DIR / "tests" / "resources" -PIXEL_SIZE = 0.4940029296875 -MIN_SKELETON_SIZE = 10 -PAD_WIDTH = 1 - -LINEAR_IMAGE = np.load(RESOURCES / "dnatracing_image_linear.npy") -LINEAR_MASK = np.load(RESOURCES / "dnatracing_mask_linear.npy") -CIRCULAR_IMAGE = np.load(RESOURCES / "dnatracing_image_circular.npy") -CIRCULAR_MASK = np.load(RESOURCES / "dnatracing_mask_circular.npy") -PADDED_LINEAR_IMAGE = np.pad(LINEAR_IMAGE, ((7, 6), (4, 4))) -PADDED_LINEAR_MASK = np.pad(LINEAR_MASK, ((7, 6), (4, 4))) -CIRCULAR_MASK = np.where(CIRCULAR_MASK == 1, 2, CIRCULAR_MASK) -MULTIGRAIN_IMAGE = np.concatenate((PADDED_LINEAR_IMAGE, CIRCULAR_IMAGE), axis=0) -MULTIGRAIN_MASK = np.concatenate((PADDED_LINEAR_MASK, CIRCULAR_MASK), axis=0) - - -RNG = np.random.default_rng(seed=1000) -SMALL_ARRAY_SIZE = (10, 10) -SMALL_ARRAY = np.asarray(RNG.random(SMALL_ARRAY_SIZE)) -SMALL_MASK = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 2, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 2, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], - [0, 0, 0, 3, 0, 0, 0, 0, 0, 0], - [0, 3, 3, 3, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) - - -@pytest.mark.parametrize( - ("pad_width", "target_image", "target_mask"), - [ - ( - 0, - [ - np.asarray( - [ - [0.24837266, 0.18985741, 0.98399558], - [0.15393237, 0.69908928, 0.44724145], - [0.34320907, 0.8119946, 0.148494], - ] - ), - np.asarray([[0.321028], [0.808018], [0.856378], [0.835525]]), - np.asarray([[0.658189, 0.516731, 0.823857], [0.431365, 0.076602, 0.098614]]), - ], - [ - np.asarray([[1, 1, 1], [1, 0, 0], [1, 0, 0]]), - np.asarray([[1], [1], [1], [1]]), - np.asarray([[0, 0, 1], [1, 1, 1]]), - ], - ), - ( - 1, - [ - np.asarray( - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.52138574, 0.60384185, 0.4709418, 0.20324794, 0.52875903, 0.0], - [0.0, 0.80537222, 0.24837266, 0.18985741, 0.98399558, 0.66999717, 0.0], - [0.0, 0.97476378, 0.15393237, 0.69908928, 0.44724145, 0.01751321, 0.0], - [0.0, 0.13645032, 0.34320907, 0.8119946, 0.148494, 0.05932569, 0.0], - [0.0, 0.55868699, 0.00288863, 0.29775757, 0.05379911, 0.56766875, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] - ), - np.asarray( - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.20391323, 0.62506469, 0.65260432, 0.0], - [0.0, 0.38123661, 0.32102791, 0.94254467, 0.0], - [0.0, 0.42015645, 0.80801771, 0.00950759, 0.0], - [0.0, 0.72427372, 0.85637809, 0.54431565, 0.0], - [0.0, 0.29894051, 0.83552541, 0.18450649, 0.0], - [0.0, 0.39284504, 0.45345328, 0.27428462, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ] - ), - np.asarray( - [ - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.60442473, 0.73460399, 0.98840812, 0.89224091, 0.51964178, 0.0], - [0.0, 0.23502814, 0.65818858, 0.51673102, 0.82385723, 0.18965801, 0.0], - [0.0, 0.62741705, 0.43136525, 0.07660225, 0.09861362, 0.06647744, 0.0], - [0.0, 0.69874299, 0.88569365, 0.93542321, 0.19316749, 0.95909555, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ] - ), - ], - [ - np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - ] - ), - np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - ] - ), - np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - ] - ), - ], - ), - ], -) -def test_prep_arrays(pad_width: int, target_image: np.ndarray, target_mask: np.ndarray) -> None: - """Tests the image and masks are correctly prepared to lists.""" - images, masks = prep_arrays(image=SMALL_ARRAY, labelled_grains_mask=SMALL_MASK, pad_width=pad_width) - for (grain, image), (grain, mask) in zip(images.items(), masks.items()): - np.testing.assert_array_almost_equal(image, target_image[grain]) - np.testing.assert_array_equal(mask, target_mask[grain]) - - -def test_image_trace_unequal_arrays() -> None: - """Test arrays that are unequal throw a ValueError.""" - irregular_mask = np.zeros((MULTIGRAIN_IMAGE.shape[0] + 4, MULTIGRAIN_IMAGE.shape[1] + 5)) - with pytest.raises(ValueError): # noqa: PT011 - trace_image( - image=MULTIGRAIN_IMAGE, - grains_mask=irregular_mask, - filename="dummy", - pixel_to_nm_scaling=PIXEL_SIZE, - min_skeleton_size=MIN_SKELETON_SIZE, - skeletonisation_method="topostats", - pad_width=PAD_WIDTH, - cores=1, - ) - - -TARGET_ARRAY = np.asarray( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] -) - - -@pytest.mark.parametrize( - ("grain_anchors", "ordered_traces", "image_shape", "expected", "pad_width"), - [ - # Ensure that grains whose traces are outside of the image bounds are skipped - ( - [ - [0, 0], - [7, 7], - ], - { - "0": np.asarray([[1, 1], [1, 2], [1, 3]]), # Grain 0's points plus anchor 0 is inside image bounds - "1": np.asarray([[1, 1], [1, 2], [1, 3]]), # Grain 1's points plus anchor 1 are outside image bounds - }, - (5, 5), - np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 1, 1, 1, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - ] - ), - 0, - ), - # pad_width = 0 - ( - [[0, 0], [0, 9], [7, 0], [5, 4], [10, 7]], - { - "0": np.asarray([[1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]), # Horizontal grain - "1": np.asarray([[1, 1], [2, 1], [3, 1], [4, 1]]), # Vertical grain - "2": np.asarray([[1, 1], [2, 2], [3, 3], [4, 4]]), # Diagonal grain - "3": np.asarray([[1, 1], [1, 2], [1, 3], [1, 4], [2, 4], [3, 4]]), # L-shaped grain grain - "4": np.asarray( # Small square - [ - [1, 1], - [1, 2], - [1, 3], - [1, 4], - [2, 1], - [2, 4], - [3, 1], - [3, 4], - [4, 1], - [4, 2], - [4, 3], - [4, 4], - ] - ), - }, - (16, 13), - TARGET_ARRAY, - 0, - ), - # pad_width = 1 - ( - [[0, 0], [0, 9], [7, 0], [4, 3], [9, 6]], - { - "0": np.asarray([[2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7]]), # Horizontal grain - "1": np.asarray([[2, 2], [3, 2], [4, 2], [5, 2]]), # Vertical grain - "2": np.asarray([[2, 2], [3, 3], [4, 4], [5, 5]]), # Diagonal grain - "3": np.asarray([[3, 3], [3, 4], [3, 5], [3, 6], [4, 6], [5, 6]]), # L-shaped grain grain - "4": np.asarray( # Small square - [ - [3, 3], - [3, 4], - [3, 5], - [3, 6], - [4, 3], - [4, 6], - [5, 3], - [5, 6], - [6, 3], - [6, 4], - [6, 5], - [6, 6], - ] - ), - }, - (16, 13), - TARGET_ARRAY, - 1, - ), - # pad_width = 2 - ( - [[0, 0], [0, 9], [7, 0], [4, 3], [9, 6]], - { - "0": np.asarray([[3, 3], [3, 4], [3, 5], [3, 6], [3, 7], [3, 8]]), # Horizontal grain - "1": np.asarray([[3, 3], [4, 3], [5, 3], [6, 3]]), # Vertical grain - "2": np.asarray([[3, 3], [4, 4], [5, 5], [6, 6]]), # Diagonal grain - "3": np.asarray([[4, 4], [4, 5], [4, 6], [4, 7], [5, 7], [6, 7]]), # L-shaped grain grain - "4": np.asarray( # Small square - [ - [4, 4], - [4, 5], - [4, 6], - [4, 7], - [5, 4], - [5, 7], - [6, 4], - [6, 7], - [7, 4], - [7, 5], - [7, 6], - [7, 7], - ] - ), - }, - (16, 13), - TARGET_ARRAY, - 2, - ), - ], -) -def test_trace_mask( - grain_anchors: list, ordered_traces: list, image_shape: tuple, pad_width: int, expected: np.ndarray -) -> None: - """Test the trace_mask.""" - image = trace_mask(grain_anchors, ordered_traces, image_shape, pad_width) - np.testing.assert_array_equal(image, expected) - - -@pytest.mark.parametrize( - ("image", "skeletonisation_method", "cores", "statistics", "ordered_trace_start", "ordered_trace_end"), - [ - ( - "multigrain_topostats", - "topostats", - 1, - pd.DataFrame( - { - "molecule_number": [0, 1], - "image": ["multigrain_topostats", "multigrain_topostats"], - "contour_length": [5.684734982126663e-08, 7.574136072208753e-08], - "circular": [False, True], - "end_to_end_distance": [3.120049919984285e-08, 0.000000e00], - } - ), - [np.asarray([6, 25]), np.asarray([31, 32])], - [np.asarray([65, 47]), np.asarray([31, 31])], - ), - ( - "multigrain_zhang", - "zhang", - 1, - pd.DataFrame( - { - "molecule_number": [0, 1], - "image": ["multigrain_zhang", "multigrain_zhang"], - "contour_length": [6.194694383968301e-08, 8.187508931608563e-08], - "circular": [False, False], - "end_to_end_distance": [2.257869018994927e-08, 1.2389530445725336e-08], - } - ), - [np.asarray([5, 28]), np.asarray([16, 66])], - [np.asarray([61, 32]), np.asarray([33, 54])], - ), - ( - "multigrain_lee", - "lee", - 1, - pd.DataFrame( - { - "molecule_number": [0, 1], - "image": ["multigrain_lee", "multigrain_lee"], - "contour_length": [5.6550320018177204e-08, 8.062559919860786e-08], - "circular": [False, False], - "end_to_end_distance": [3.13837693459974e-08, 6.7191662793734405e-09], - } - ), - [np.asarray([4, 23]), np.asarray([18, 65])], - [np.asarray([65, 47]), np.asarray([34, 54])], - ), - ( - "multigrain_thin", - "thin", - 1, - pd.DataFrame( - { - "molecule_number": [0, 1], - "image": ["multigrain_thin", "multigrain_thin"], - "contour_length": [5.4926652806911664e-08, 3.6512544238919696e-08], - "circular": [False, False], - "end_to_end_distance": [4.367667613976452e-08, 3.440332307376993e-08], - } - ), - [np.asarray([5, 23]), np.asarray([10, 58])], - [np.asarray([71, 81]), np.asarray([83, 30])], - ), - ], -) -def test_trace_image( - image: str, skeletonisation_method: str, cores: int, statistics, ordered_trace_start: list, ordered_trace_end: list -) -> None: - """Tests the processing of an image using trace_image() function. - - NB - This test isn't complete, there is only limited testing of the results["ordered_traces"]. - The results["image_trace"] that are not tested either, these are large arrays and constructing them in the test - is cumbersome. - Initial attempts at using SMALL_ARRAY/SMALL_MASK were unsuccessful as they were not traced because the grains - are < min_skeleton_size, adjusting this to 1 didn't help they still weren't skeletonised. - """ - results = trace_image( - image=MULTIGRAIN_IMAGE, - grains_mask=MULTIGRAIN_MASK, - filename=image, - pixel_to_nm_scaling=PIXEL_SIZE, - min_skeleton_size=MIN_SKELETON_SIZE, - skeletonisation_method=skeletonisation_method, - pad_width=PAD_WIDTH, - cores=cores, - ) - statistics.set_index(["molecule_number"], inplace=True) - pd.testing.assert_frame_equal(results["statistics"], statistics) - for ordered_trace, start, end in zip( - results["all_ordered_traces"].values(), ordered_trace_start, ordered_trace_end - ): - np.testing.assert_array_equal(ordered_trace[1], start) - np.testing.assert_array_equal(ordered_trace[-1], end) diff --git a/tests/tracing/test_dnatracing_single_grain.py b/tests/tracing/test_dnatracing_single_grain.py deleted file mode 100644 index 5515ba03b2..0000000000 --- a/tests/tracing/test_dnatracing_single_grain.py +++ /dev/null @@ -1,592 +0,0 @@ -"""Tests for tracing single molecules.""" - -from pathlib import Path - -import numpy as np -import pytest - -from topostats.tracing.dnatracing import ( - crop_array, - dnaTrace, - grain_anchor, - pad_bounding_box, - trace_grain, -) - -# This is required because of the inheritance used throughout -# pylint: disable=redefined-outer-name -BASE_DIR = Path.cwd() -RESOURCES = BASE_DIR / "tests" / "resources" -PIXEL_SIZE = 0.4940029296875 -MIN_SKELETON_SIZE = 10 -PAD_WIDTH = 30 - -LINEAR_IMAGE = np.load(RESOURCES / "dnatracing_image_linear.npy") -LINEAR_MASK = np.load(RESOURCES / "dnatracing_mask_linear.npy") -CIRCULAR_IMAGE = np.load(RESOURCES / "dnatracing_image_circular.npy") -CIRCULAR_MASK = np.load(RESOURCES / "dnatracing_mask_circular.npy") - - -@pytest.fixture() -def dnatrace_linear() -> dnaTrace: - """dnaTrace object instantiated with a single linear grain.""" # noqa: D403 - return dnaTrace( - image=LINEAR_IMAGE, - grain=LINEAR_MASK, - filename="linear", - pixel_to_nm_scaling=PIXEL_SIZE, - min_skeleton_size=MIN_SKELETON_SIZE, - skeletonisation_method="topostats", - ) - - -@pytest.fixture() -def dnatrace_circular() -> dnaTrace: - """dnaTrace object instantiated with a single linear grain.""" # noqa: D403 - return dnaTrace( - image=CIRCULAR_IMAGE, - grain=CIRCULAR_MASK, - filename="circular", - pixel_to_nm_scaling=PIXEL_SIZE, - min_skeleton_size=MIN_SKELETON_SIZE, - skeletonisation_method="topostats", - ) - - -@pytest.mark.parametrize( - ("dnatrace", "gauss_image_sum"), - [ - pytest.param("dnatrace_linear", 5.517763534147536e-06, id="linear"), - pytest.param("dnatrace_circular", 6.126947266262167e-06, id="circular"), - ], -) -def test_gaussian_filter(dnatrace: dnaTrace, gauss_image_sum: float, request) -> None: - """Test of the method.""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - assert dnatrace.gauss_image.sum() == pytest.approx(gauss_image_sum) - - -@pytest.mark.parametrize( - ("dnatrace", "skeletonisation_method", "length", "start", "end"), - [ - pytest.param( - "dnatrace_linear", "topostats", 120, np.asarray([28, 47]), np.asarray([106, 87]), id="linear topostats" - ), - pytest.param( - "dnatrace_circular", "topostats", 150, np.asarray([59, 59]), np.asarray([113, 54]), id="circular topostats" - ), - pytest.param("dnatrace_linear", "zhang", 170, np.asarray([28, 47]), np.asarray([106, 87]), id="linear zhang"), - pytest.param( - "dnatrace_circular", "zhang", 184, np.asarray([43, 95]), np.asarray([113, 54]), id="circular zhang" - ), - pytest.param("dnatrace_linear", "lee", 130, np.asarray([27, 45]), np.asarray([106, 87]), id="linear lee"), - pytest.param("dnatrace_circular", "lee", 177, np.asarray([45, 93]), np.asarray([114, 53]), id="circular lee"), - pytest.param("dnatrace_linear", "thin", 187, np.asarray([27, 45]), np.asarray([106, 83]), id="linear thin"), - pytest.param("dnatrace_circular", "thin", 190, np.asarray([38, 85]), np.asarray([115, 52]), id="circular thin"), - ], -) -def test_get_disordered_trace( # pylint: disable=too-many-positional-arguments - dnatrace: dnaTrace, skeletonisation_method: str, length: int, start: tuple, end: tuple, request -) -> None: - """Test of get_disordered_trace the method.""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.skeletonisation_method = skeletonisation_method - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - assert isinstance(dnatrace.disordered_trace, np.ndarray) - assert len(dnatrace.disordered_trace) == length - np.testing.assert_array_equal( - dnatrace.disordered_trace[0,], - start, - ) - np.testing.assert_array_equal( - dnatrace.disordered_trace[-1,], - end, - ) - - -# Currently linear molecule isn't detected as linear, although it was when selecting and extracting in a Notebook -@pytest.mark.parametrize( - ("dnatrace", "mol_is_circular"), - [ - # pytest.param("dnatrace_linear"), False, id="linear"), - pytest.param("dnatrace_circular", True, id="circular"), - ], -) -def test_linear_or_circular(dnatrace: dnaTrace, mol_is_circular: int, request) -> None: - """Test of the linear_or_circular method.""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.min_skeleton_size = MIN_SKELETON_SIZE - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - assert dnatrace.mol_is_circular == mol_is_circular - - -@pytest.mark.parametrize( - ("dnatrace", "length", "start", "end"), - [ - pytest.param("dnatrace_linear", 118, np.asarray([28, 48]), np.asarray([88, 70]), id="linear"), - pytest.param("dnatrace_circular", 151, np.asarray([59, 59]), np.asarray([59, 59]), id="circular"), - ], -) -def test_get_ordered_traces(dnatrace: dnaTrace, length: int, start: np.array, end: np.array, request) -> None: - """Test of the get_ordered_traces method. - - Note the coordinates at the start and end differ from the fixtures for test_get_disordered_trace, but that the - circular molecule starts and ends in the same place but the linear doesn't (even though it is currently reported as - being circular!). - """ - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - assert isinstance(dnatrace.ordered_trace, np.ndarray) - assert len(dnatrace.ordered_trace) == length - np.testing.assert_array_equal( - dnatrace.ordered_trace[0,], - start, - ) - np.testing.assert_array_almost_equal( - dnatrace.ordered_trace[-1,], - end, - ) - - -@pytest.mark.parametrize( - ("dnatrace", "length", "start", "end"), - [ - pytest.param("dnatrace_linear", 118, 8.8224769e-10, 1.7610771e-09, id="linear"), - pytest.param("dnatrace_circular", 151, 2.5852866e-09, 2.5852866e-09, id="circular"), - ], -) -def test_get_ordered_trace_heights(dnatrace: dnaTrace, length: int, start: float, end: float, request) -> None: - """Test of the get_trace_heights method.""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - dnatrace.get_ordered_trace_heights() - assert isinstance(dnatrace.ordered_trace_heights, np.ndarray) - assert len(dnatrace.ordered_trace_heights) == length - assert dnatrace.ordered_trace_heights[0] == pytest.approx(start, abs=1e-12) - assert dnatrace.ordered_trace_heights[-1] == pytest.approx(end, abs=1e-12) - - -@pytest.mark.parametrize( - ("dnatrace", "length", "start", "end"), - [ - pytest.param("dnatrace_linear", 118, 0.0, 6.8234101e-08, id="linear"), - pytest.param("dnatrace_circular", 151, 0.0, 8.3513084e-08, id="circular"), - ], -) -def test_ordered_get_trace_cumulative_distances( - dnatrace: dnaTrace, length: int, start: float, end: float, request -) -> None: - """Test of the get_trace_cumulative_distances method.""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - dnatrace.get_ordered_trace_heights() - dnatrace.get_ordered_trace_cumulative_distances() - assert isinstance(dnatrace.ordered_trace_cumulative_distances, np.ndarray) - assert len(dnatrace.ordered_trace_cumulative_distances) == length - assert dnatrace.ordered_trace_cumulative_distances[0] == pytest.approx(start, abs=1e-11) - assert dnatrace.ordered_trace_cumulative_distances[-1] == pytest.approx(end, abs=1e-11) - # Check that the cumulative distance is always increasing - assert np.all(np.diff(dnatrace.ordered_trace_cumulative_distances) > 0) - - -@pytest.mark.parametrize( - ("coordinate_list", "pixel_to_nm_scaling", "target_list"), - [ - pytest.param(np.asarray([[1, 1], [1, 2]]), 1.0, np.asarray([0.0, 1.0]), id="Horizontal line; scaling 1.0"), - pytest.param(np.asarray([[1, 1], [1, 2]]), 0.5, np.asarray([0.0, 0.5]), id="Horizontal line; scaling 0.5"), - pytest.param(np.asarray([[1, 1], [2, 2]]), 1.0, np.asarray([0.0, np.sqrt(2)]), id="Diagonal line; scaling 1.0"), - pytest.param( - np.asarray([[1, 1], [2, 2], [3, 2], [4, 2], [4, 3]]), - 1.0, - np.asarray([0.0, np.sqrt(2), np.sqrt(2) + 1.0, np.sqrt(2) + 2.0, np.sqrt(2) + 3.0]), - id="Complex line; scaling 1.0", - ), - ], -) -def test_coord_dist(coordinate_list: list, pixel_to_nm_scaling: float, target_list: list) -> None: - """Test of the coord_dist method.""" - cumulative_distance_list = dnaTrace.coord_dist(coordinate_list, pixel_to_nm_scaling) - - assert isinstance(cumulative_distance_list, np.ndarray) - assert cumulative_distance_list.shape[0] == target_list.shape[0] - np.testing.assert_array_almost_equal(cumulative_distance_list, target_list) - - -@pytest.mark.parametrize( - ("dnatrace", "length", "start", "end"), - [ - pytest.param("dnatrace_linear", 118, np.asarray([28, 49]), np.asarray([88, 75]), id="linear"), - pytest.param("dnatrace_circular", 151, np.asarray([58, 58]), np.asarray([58, 58]), id="circular"), - ], -) -def test_get_fitted_traces(dnatrace: dnaTrace, length: int, start: np.array, end: np.array, request) -> None: - """Test of the method get_fitted_traces().""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_fitted_traces() - assert isinstance(dnatrace.fitted_trace, np.ndarray) - assert len(dnatrace.fitted_trace) == length - np.testing.assert_array_equal( - dnatrace.fitted_trace[0,], - start, - ) - np.testing.assert_array_almost_equal( - dnatrace.fitted_trace[-1,], - end, - ) - - -@pytest.mark.parametrize( - ("dnatrace", "length", "start", "end"), - [ - pytest.param( - "dnatrace_linear", - 1652, - np.asarray([35.357143, 46.714286]), - np.asarray([35.357143, 46.714286]), - id="linear", - ), - pytest.param( - "dnatrace_circular", - 2114, - np.asarray([59.285714, 65.428571]), - np.asarray([59.285714, 65.428571]), - id="circular", - ), - ], -) -def test_get_splined_traces(dnatrace: dnaTrace, length: int, start: np.array, end: np.array, request) -> None: - """Test of the method for get_splined_traces().""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_fitted_traces() - dnatrace.get_splined_traces() - assert isinstance(dnatrace.splined_trace, np.ndarray) - assert len(dnatrace.splined_trace) == length - np.testing.assert_array_almost_equal( - dnatrace.splined_trace[0,], - start, - ) - np.testing.assert_array_almost_equal( - dnatrace.splined_trace[-1,], - end, - ) - - -@pytest.mark.parametrize( - ("dnatrace", "contour_length"), - [ - pytest.param("dnatrace_linear", 9.040267985905398e-08, id="linear"), - pytest.param("dnatrace_circular", 7.617314045334366e-08, id="circular"), - ], -) -def test_measure_contour_length(dnatrace: dnaTrace, contour_length: float, request) -> None: - """Test of the method measure_contour_length().""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_fitted_traces() - dnatrace.get_splined_traces() - dnatrace.measure_contour_length() - assert dnatrace.contour_length == pytest.approx(contour_length) - - -# Currently need an actual linear grain to test this. -@pytest.mark.parametrize( - ("dnatrace", "end_to_end_distance"), - [ - pytest.param("dnatrace_linear", 0, id="linear"), - pytest.param("dnatrace_circular", 0, id="circular"), - ], -) -def test_measure_end_to_end_distance(dnatrace: dnaTrace, end_to_end_distance: float, request) -> None: - """Test of the method measure_end_to_end_distance().""" - dnatrace = request.getfixturevalue(dnatrace) - dnatrace.gaussian_filter() - dnatrace.get_disordered_trace() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_ordered_traces() - dnatrace.linear_or_circular(dnatrace.disordered_trace) - dnatrace.get_fitted_traces() - dnatrace.get_splined_traces() - dnatrace.measure_end_to_end_distance() - assert dnatrace.end_to_end_distance == pytest.approx(end_to_end_distance) - - -@pytest.mark.parametrize( - ("bounding_box", "pad_width", "target_array"), - [ - ( - # Top right, no padding, does not extend beyond border - (1, 1, 4, 4), - 0, - np.asarray( - [ - [1, 0, 0], - [0, 0, 0], - [0, 0, 1], - ] - ), - ), - ( - # Top right, 1 cell padding, does not extend beyond border - (1, 1, 4, 4), - 1, - np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 0], - ] - ), - ), - ( - # Top right, 2 cell padding, extends beyond border - (1, 1, 4, 4), - 2, - np.asarray( - [ - [0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - ] - ), - ), - ( - # Bottom left, no padding, does not extend beyond border - (6, 6, 9, 9), - 0, - np.asarray( - [ - [1, 0, 0], - [0, 0, 0], - [0, 0, 1], - ] - ), - ), - ( - # Bottom left, one cell padding, does not extend beyond border - (6, 6, 9, 9), - 1, - np.asarray( - [ - [0, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 0], - ] - ), - ), - ( - # Bottom left, two cell padding, extends beyond border - (6, 6, 9, 9), - 2, - np.asarray( - [ - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0], - ] - ), - ), - ( - # Bottom left, three cell padding, extends beyond border - (6, 6, 9, 9), - 3, - np.asarray( - [ - [1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 0], - ] - ), - ), - ], -) -def test_crop_array(bounding_box: tuple, pad_width: int, target_array: list) -> None: - """Test cropping of arrays.""" - array = np.zeros((10, 10)) - upper_box = np.asarray([[1, 1], [3, 3]]) - lower_box = np.asarray([[6, 6], [8, 8]]) - array[upper_box[:, 0], upper_box[:, 1]] = 1 - array[lower_box[:, 0], lower_box[:, 1]] = 1 - cropped_array = crop_array(array, bounding_box, pad_width) - np.testing.assert_array_equal(cropped_array, target_array) - - -@pytest.mark.parametrize( - ("array_shape", "bounding_box", "pad_width", "target_coordinates"), - [ - ((10, 10), [1, 1, 5, 5], 1, [0, 0, 6, 6]), - ((10, 10), [1, 1, 5, 5], 3, [0, 0, 8, 8]), - ((10, 10), [4, 4, 5, 5], 1, [3, 3, 6, 6]), - ((10, 10), [4, 4, 5, 5], 3, [1, 1, 8, 8]), - ((10, 10), [4, 4, 5, 5], 6, [0, 0, 10, 10]), - ], -) -def test_pad_bounding_box(array_shape: tuple, bounding_box: list, pad_width: int, target_coordinates: tuple) -> None: - """Test the padding ofbounding boxes.""" - padded_bounding_box = pad_bounding_box(array_shape, bounding_box, pad_width) - assert padded_bounding_box == target_coordinates - - -@pytest.mark.parametrize( - ("array_shape", "bounding_box", "pad_width", "target_coordinates"), - [ - ((10, 10), [1, 1, 5, 5], 1, (0, 0)), - ((10, 10), [1, 1, 5, 5], 3, (0, 0)), - ((10, 10), [4, 4, 5, 5], 1, (3, 3)), - ((10, 10), [4, 4, 5, 5], 3, (1, 1)), - ((10, 10), [4, 4, 5, 5], 6, (0, 0)), - ], -) -def test_grain_anchor(array_shape: tuple, bounding_box: list, pad_width: int, target_coordinates: tuple) -> None: - """Test the extraction of padded bounding boxes.""" - padded_grain_anchor = grain_anchor(array_shape, bounding_box, pad_width) - assert padded_grain_anchor == target_coordinates - - -@pytest.mark.parametrize( - ( - "cropped_image", - "cropped_mask", - "filename", - "skeletonisation_method", - "end_to_end_distance", - "circular", - "contour_length", - ), - [ - ( - LINEAR_IMAGE, - LINEAR_MASK, - "linear_test_topostats", - "topostats", - 3.115753758716346e-08, - False, - 5.684734982126664e-08, - ), - ( - CIRCULAR_IMAGE, - CIRCULAR_MASK, - "circular_test_topostats", - "topostats", - 0, - True, - 7.617314045334366e-08, - ), - ( - LINEAR_IMAGE, - LINEAR_MASK, - "linear_test_zhang", - "zhang", - 2.6964685842539566e-08, - False, - 6.194694383968303e-08, - ), - ( - CIRCULAR_IMAGE, - CIRCULAR_MASK, - "circular_test_zhang", - "zhang", - 9.636691058914389e-09, - False, - 8.187508931608563e-08, - ), - ( - LINEAR_IMAGE, - LINEAR_MASK, - "linear_test_lee", - "lee", - 3.197879765453915e-08, - False, - 5.655032001817721e-08, - ), - ( - CIRCULAR_IMAGE, - CIRCULAR_MASK, - "circular_test_lee", - "lee", - 8.261640682714017e-09, - False, - 8.062559919860788e-08, - ), - ( - LINEAR_IMAGE, - LINEAR_MASK, - "linear_test_thin", - "thin", - 4.068855894099921e-08, - False, - 5.518856387362746e-08, - ), - ( - CIRCULAR_IMAGE, - CIRCULAR_MASK, - "circular_test_thin", - "thin", - 3.638262839374549e-08, - False, - 3.6512544238919716e-08, - ), - ], -) -def test_trace_grain( # pylint: disable=too-many-positional-arguments - cropped_image: np.ndarray, - cropped_mask: np.ndarray, - filename: str, - skeletonisation_method: str, - end_to_end_distance: float, - circular: bool, - contour_length: float, -) -> None: - """Test trace_grain function for tracing a single grain.""" - trace_stats = trace_grain( - cropped_image=cropped_image, - cropped_mask=cropped_mask, - pixel_to_nm_scaling=PIXEL_SIZE, - filename=filename, - min_skeleton_size=MIN_SKELETON_SIZE, - skeletonisation_method=skeletonisation_method, - ) - assert trace_stats["image"] == filename - assert trace_stats["end_to_end_distance"] == pytest.approx(end_to_end_distance) - assert trace_stats["circular"] == circular - assert trace_stats["contour_length"] == pytest.approx(contour_length) diff --git a/tests/tracing/test_nodestats.py b/tests/tracing/test_nodestats.py new file mode 100644 index 0000000000..4791a259db --- /dev/null +++ b/tests/tracing/test_nodestats.py @@ -0,0 +1,1034 @@ +# Disable ruff 301 - pickle loading is unsafe, but we don't care for tests. +# ruff: noqa: S301 +"""Test the nodestats module.""" + +import pickle +from pathlib import Path + +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest + +# pylint: disable=import-error +# pylint: disable=no-name-in-module +from topostats.io import dict_almost_equal +from topostats.tracing.nodestats import nodeStats, nodestats_image + +BASE_DIR = Path.cwd() +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" +DISORDERED_TRACING_RESOURCES = GENERAL_RESOURCES / "tracing" / "disordered_tracing" +NODESTATS_RESOURCES = GENERAL_RESOURCES / "tracing" / "nodestats" +# from topostats.tracing.nodestats import nodeStats + +# pylint: disable=unnecessary-pass +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +# pylint: disable=too-many-lines + + +# @pytest.mark.parametrize() +def test_get_node_stats() -> None: + """Test of get_node_stats() method of nodeStats class.""" + pass + + +def test_check_node_errorless() -> None: + """Test of check_node_errorless() method of nodeStats class.""" + pass + + +def test_skeleton_image_to_graph() -> None: + """Test of skeleton_image_to_graph() method of nodeStats class.""" + pass + + +def test_graph_to_skeleton_image() -> None: + """Test of graph_to_skeleton_image() method of nodeStats class.""" + pass + + +def test_tidy_branches() -> None: + """Test of tidy_branches() method of nodeStats class.""" + pass + + +def test_keep_biggest_object() -> None: + """Test of keep_biggest_object() method of nodeStats class.""" + pass + + +def test_connect_close_nodes() -> None: + """Test of connect_close_nodes() method of nodeStats class.""" + pass + + +def test_highlight_node_centres() -> None: + """Test of highlight_node_centres() method of nodeStats class.""" + pass + + +def test_connect_extended_nodes() -> None: + """Test of connect_extended_nodes() method of nodeStats class.""" + pass + + +@pytest.mark.parametrize( + ("connected_nodes", "expected_nodes"), + [ + pytest.param( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 3, 1, 1, 1, 1, 1, 1, 1, 3, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="theta_grain", + ), + pytest.param( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 3, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 3, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="figure_8", + ), + ], +) +def test_connect_extended_nodes_nearest( + connected_nodes: npt.NDArray[np.number], expected_nodes: npt.NDArray[np.number] +) -> None: + """Test of connect_extended_nodes_nearest() method of nodeStats class. + + Needs a test for theta topology and figure 8. + """ + nodestats = nodeStats( + filename="dummy", + image=np.array([[0, 0, 0], [0, 1.5, 0], [0, 0, 0]]), + mask=np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]), + smoothed_mask=np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]), + skeleton=connected_nodes.astype(bool), + pixel_to_nm_scaling=np.float64(1.0), + n_grain=0, + node_joining_length=0.0, + node_extend_dist=14.0, + branch_pairing_length=20.0, + pair_odd_branches=True, + ) + nodestats.whole_skel_graph = nodestats.skeleton_image_to_graph(nodestats.skeleton) + result = nodestats.connect_extended_nodes_nearest(connected_nodes, node_extend_dist=8.0) + + assert np.array_equal(result, expected_nodes) + + +def test_find_branch_starts() -> None: + """Test of find_branch_starts() method of nodeStats class.""" + pass + + +# Create nodestats class using the cats image - will allow running the code for diagnostics +def test_analyse_nodes( + nodestats_catenane: nodeStats, +) -> None: + """Test of analyse_nodes() method of nodeStats class.""" + nodestats_catenane.analyse_nodes(max_branch_length=20) + + node_dict_result = nodestats_catenane.node_dicts + image_dict_result = nodestats_catenane.image_dict + + # Debugging + # Save the results to overwrite expected results + # with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_node_dict.pkl").open("wb") as f: + # pickle.dump(node_dict_result, f) + + # with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_image_dict.pkl").open("wb") as f: + # pickle.dump(image_dict_result, f) + + # np.save( + # GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_all_connected_nodes.npy", + # nodestats_catenane.all_connected_nodes, + # ) + + # Load the nodestats catenane node dict from pickle + with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_node_dict.pkl").open("rb") as f: + expected_nodestats_catenane_node_dict = pickle.load(f) + + # Load the nodestats catenane image dict from pickle + with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_image_dict.pkl").open("rb") as f: + expected_nodestats_catenane_image_dict = pickle.load(f) + + # Load the nodestats catenane all connected nodes from pickle + with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_all_connected_nodes.npy").open("rb") as f: + expected_nodestats_catenane_all_connected_nodes = np.load(f) + + assert dict_almost_equal(node_dict_result, expected_nodestats_catenane_node_dict) + assert dict_almost_equal(image_dict_result, expected_nodestats_catenane_image_dict) + np.testing.assert_array_equal( + nodestats_catenane.all_connected_nodes, expected_nodestats_catenane_all_connected_nodes + ) + + +@pytest.mark.parametrize( + ( + "branch_under_over_order", + "matched_branches_filename", + "masked_image_filename", + "branch_start_coords", + "ordered_branches_filename", + "pairs", + "average_trace_advised", + "image_shape", + "expected_branch_image_filename", + "expected_average_image_filename", + ), + [ + pytest.param( + np.array([0, 1]), + "catenane_node_0_matched_branches_analyse_node_branches.pkl", + "catenane_node_0_masked_image.pkl", + np.array([np.array([278, 353]), np.array([279, 351]), np.array([281, 352]), np.array([281, 354])]), + "catenane_node_0_ordered_branches.pkl", + np.array([(1, 3), (2, 0)]), + True, + (755, 621), + "catenane_node_0_branch_image.npy", + "catenane_node_0_avg_image.npy", + ) + ], +) +def test_add_branches_to_labelled_image( + branch_under_over_order: npt.NDArray[np.int32], + matched_branches_filename: str, + masked_image_filename: str, + branch_start_coords: npt.NDArray[np.int32], + ordered_branches_filename: str, + pairs: npt.NDArray[np.int32], + average_trace_advised: bool, + image_shape: tuple[int, int], + expected_branch_image_filename: str, + expected_average_image_filename: str, +) -> None: + """Test of add_branches_to_labelled_image() method of nodeStats class.""" + # Load the matched branches + with Path(GENERAL_RESOURCES / f"{matched_branches_filename}").open("rb") as f: + matched_branches: dict[int, dict[str, npt.NDArray[np.number]]] = pickle.load(f) + + # Load the masked image + with Path(GENERAL_RESOURCES / f"{masked_image_filename}").open("rb") as f: + masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = pickle.load(f) + + # Load the ordered branches + with Path(GENERAL_RESOURCES / f"{ordered_branches_filename}").open("rb") as f: + ordered_branches: list[npt.NDArray[np.int32]] = pickle.load(f) + + # Load the branch image + expected_branch_image: npt.NDArray[np.int32] = np.load(GENERAL_RESOURCES / expected_branch_image_filename) + + # Load the average image + expected_average_image: npt.NDArray[np.float64] = np.load(GENERAL_RESOURCES / expected_average_image_filename) + + result_branch_image, result_average_image = nodeStats.add_branches_to_labelled_image( + branch_under_over_order=branch_under_over_order, + matched_branches=matched_branches, + masked_image=masked_image, + branch_start_coords=branch_start_coords, + ordered_branches=ordered_branches, + pairs=pairs, + average_trace_advised=average_trace_advised, + image_shape=image_shape, + ) + + np.testing.assert_equal(result_branch_image, expected_branch_image) + np.testing.assert_equal(result_average_image, expected_average_image) + + +# FIXME Need a test for not pairing odd branches. Will need a test image with 3-nodes. +@pytest.mark.parametrize( + ( + "p_to_nm", + "reduced_node_area_filename", + "branch_start_coords", + "max_length_px", + "reduced_skeleton_graph_filename", + "image", + "average_trace_advised", + "node_coord", + "pair_odd_branches", + "filename", + "resolution_threshold", + "expected_pairs", + "expected_matched_branches_filename", + "expected_ordered_branches_filename", + "expected_masked_image_filename", + "expected_branch_under_over_order", + "expected_conf", + "expected_singlet_branch_vectors", + ), + [ + pytest.param( + 0.18124609, + "catenane_node_0_reduced_node_area.npy", + np.array([np.array([278, 353]), np.array([279, 351]), np.array([281, 352]), np.array([281, 354])]), + 110.34720989566988, + "catenane_node_0_reduced_skeleton_graph.pkl", + "catenane_image", + True, + (np.int32(280), np.int32(353)), + True, + "catenane_test_image", + np.float64(1000 / 512), + np.array([(1, 3), (2, 0)]), + "catenane_node_0_matched_branches_analyse_node_branches.pkl", + "catenane_node_0_ordered_branches.pkl", + "catenane_node_0_masked_image.pkl", + np.array([0, 1]), + 0.48972025484111525, + [ + np.array([-0.97044686, -0.24131493]), + np.array([0.10375883, -0.99460249]), + np.array([0.98972257, -0.14300081]), + np.array([0.46367343, 0.88600618]), + ], + id="node 0", + ) + ], +) +def test_analyse_node_branches( + p_to_nm: float, + reduced_node_area_filename: npt.NDArray[np.int32], + branch_start_coords: npt.NDArray[np.int32], + max_length_px: np.int32, + reduced_skeleton_graph_filename: npt.NDArray[np.int32], + image: npt.NDArray[np.float64], + average_trace_advised: bool, + node_coord: tuple[np.int32, np.int32], + pair_odd_branches: np.bool_, + filename: str, + resolution_threshold: np.float64, + expected_pairs: npt.NDArray[np.int32], + expected_matched_branches_filename: str, + expected_ordered_branches_filename: str, + expected_masked_image_filename: str, + expected_branch_under_over_order: npt.NDArray[np.int32], + expected_conf: float, + expected_singlet_branch_vectors: list[npt.NDArray[np.int32]], + request, +) -> None: + """Test of analyse_node_branches() method of nodeStats class.""" + # Load the fixtures + image = request.getfixturevalue(image) + + # Load the reduced node area + reduced_node_area = np.load(GENERAL_RESOURCES / f"{reduced_node_area_filename}") + + # Load the reduced skeleton graph + with Path(GENERAL_RESOURCES / f"{reduced_skeleton_graph_filename}").open("rb") as f: + reduced_skeleton_graph = pickle.load(f) + + ( + result_pairs, + result_matched_branches, + result_ordered_branches, + result_masked_image, + result_branch_idx_order, + result_conf, + result_singlet_branch_vectors, + ) = nodeStats.analyse_node_branches( + p_to_nm=np.float64(p_to_nm), + reduced_node_area=reduced_node_area, + branch_start_coords=branch_start_coords, + max_length_px=max_length_px, + reduced_skeleton_graph=reduced_skeleton_graph, + image=image, + average_trace_advised=average_trace_advised, + node_coord=node_coord, + pair_odd_branches=pair_odd_branches, + filename=filename, + resolution_threshold=resolution_threshold, + ) + + # Load expected matched branches + with Path(GENERAL_RESOURCES / f"{expected_matched_branches_filename}").open("rb") as f: + expected_matched_branches = pickle.load(f) + + # Load expected masked image + with Path(GENERAL_RESOURCES / f"{expected_masked_image_filename}").open("rb") as f: + expected_masked_image = pickle.load(f) + # Load expected ordered branches + with Path(GENERAL_RESOURCES / f"{expected_ordered_branches_filename}").open("rb") as f: + expected_ordered_branches = pickle.load(f) + + np.testing.assert_equal(result_pairs, expected_pairs) + np.testing.assert_equal(result_matched_branches, expected_matched_branches) + np.testing.assert_equal(result_ordered_branches, expected_ordered_branches) + np.testing.assert_equal(result_masked_image, expected_masked_image) + np.testing.assert_equal(result_branch_idx_order, expected_branch_under_over_order) + np.testing.assert_almost_equal(result_conf, expected_conf, decimal=6) + np.testing.assert_almost_equal(result_singlet_branch_vectors, expected_singlet_branch_vectors, decimal=6) + + +@pytest.mark.parametrize( + ( + "pairs", + "ordered_branches_filename", + "reduced_skeleton_graph_filename", + "image", + "average_trace_advised", + "node_coords", + "filename", + "expected_matched_branches_filename", + "expected_masked_image_filename", + ), + [ + pytest.param( + np.array([[1, 3], [2, 0]]), + "catenane_node_0_ordered_branches.pkl", + "catenane_node_0_reduced_skeleton_graph.pkl", + "catenane_image", + True, + (280, 353), + "catenane_test_image", + "catenane_node_0_matched_branches_join_matching_branches_through_node.pkl", + "catenane_node_0_masked_image.pkl", + id="node 0", + ), + pytest.param( + np.array([[0, 3], [2, 1]]), + "catenane_node_1_ordered_branches.pkl", + "catenane_node_1_reduced_skeleton_graph.pkl", + "catenane_image", + True, + (312, 237), + "catenane_test_image", + "catenane_node_1_matched_branches_join_matching_branches_through_node.pkl", + "catenane_node_1_masked_image.pkl", + id="node 1", + ), + pytest.param( + np.array([[3, 1], [0, 2]]), + "catenane_node_2_ordered_branches.pkl", + "catenane_node_2_reduced_skeleton_graph.pkl", + "catenane_image", + True, + (407, 438), + "catenane_test_image", + "catenane_node_2_matched_branches_join_matching_branches_through_node.pkl", + "catenane_node_2_masked_image.pkl", + id="node 2", + ), + pytest.param( + np.array([[1, 3], [2, 0]]), + "catenane_node_3_ordered_branches.pkl", + "catenane_node_3_reduced_skeleton_graph.pkl", + "catenane_image", + True, + (451, 224), + "catenane_test_image", + "catenane_node_3_matched_branches_join_matching_branches_through_node.pkl", + "catenane_node_3_masked_image.pkl", + id="node 3", + ), + pytest.param( + np.array([[1, 3], [2, 0]]), + "catenane_node_4_ordered_branches.pkl", + "catenane_node_4_reduced_skeleton_graph.pkl", + "catenane_image", + True, + (558, 194), + "catenane_test_image", + "catenane_node_4_matched_branches_join_matching_branches_through_node.pkl", + "catenane_node_4_masked_image.pkl", + id="node 4", + ), + ], +) +def test_join_matching_branches_through_node( + pairs: npt.NDArray[np.int32], + ordered_branches_filename: str, + reduced_skeleton_graph_filename: str, + image: npt.NDArray[np.float64], + average_trace_advised: bool, + node_coords: tuple[np.int32, np.int32], + filename: str, + expected_matched_branches_filename: str, + expected_masked_image_filename: str, + request, +) -> None: + """Test of join_matching_branches_through_node() method of nodeStats class.""" + # Load the fixtures + image = request.getfixturevalue(image) + + # Load the ordered branches + with Path(GENERAL_RESOURCES / f"{ordered_branches_filename}").open("rb") as f: + ordered_branches = pickle.load(f) + + # Load the reduced skeleton graph + with Path(GENERAL_RESOURCES / f"{reduced_skeleton_graph_filename}").open("rb") as f: + reduced_skeleton_graph = pickle.load(f) + + # Load expected matched branches + with Path(GENERAL_RESOURCES / f"{expected_matched_branches_filename}").open("rb") as f: + expected_matched_branches = pickle.load(f) + + # Load expected masked image + with Path(GENERAL_RESOURCES / f"{expected_masked_image_filename}").open("rb") as f: + expected_masked_image = pickle.load(f) + + result_matched_branches, result_masked_image = nodeStats.join_matching_branches_through_node( + pairs=pairs, + ordered_branches=ordered_branches, + reduced_skeleton_graph=reduced_skeleton_graph, + image=image, + average_trace_advised=average_trace_advised, + node_coords=node_coords, + filename=filename, + ) + + np.testing.assert_equal(result_matched_branches, expected_matched_branches) + np.testing.assert_equal(result_masked_image, expected_masked_image) + + +@pytest.mark.parametrize( + ("reduced_node_area", "branch_start_coords", "max_length_px", "expected_ordered_branches", "expected_vectors"), + [ + pytest.param( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 1, 1, 3, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 3, 1, 0, 0], + [0, 0, 0, 0, 0, 3, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.array([[4, 3], [3, 4], [6, 5], [5, 6]]), + 2, + [ + np.array([[4, 3], [4, 2]]), + np.array([[3, 4], [2, 4]]), + np.array([[6, 5], [5, 6]]), + np.array([[5, 6], [6, 7]]), + ], + [ + np.array([0.0, -1.0]), + np.array([-1.0, 0.0]), + np.array([-0.70710678, 0.70710678]), + np.array([0.70710678, 0.70710678]), + ], + ), + ], +) +def test_get_ordered_branches_and_vectors( + reduced_node_area: npt.NDArray[np.int32], + branch_start_coords: npt.NDArray[np.int32], + max_length_px: np.int32, + expected_ordered_branches: list[npt.NDArray[np.int32]], + expected_vectors: list[npt.NDArray[np.int32]], +) -> None: + """Test of get_ordered_branches_and_vectors() method of nodeStats class.""" + result_ordered_branches, result_vectors = nodeStats.get_ordered_branches_and_vectors( + reduced_node_area=reduced_node_area, branch_start_coords=branch_start_coords, max_length_px=max_length_px + ) + np.testing.assert_equal(result_ordered_branches, expected_ordered_branches) + np.testing.assert_almost_equal(result_vectors, expected_vectors, decimal=6) + + +def test_sq() -> None: + """Test of sq() method of nodeStats class.""" + pass + + +def test_tri() -> None: + """Test of tri() method of nodeStats class.""" + pass + + +def test_auc() -> None: + """Test of auc() method of nodeStats class.""" + pass + + +def test_cross_confidence() -> None: + """Test of cross_confidence() method of nodeStats class.""" + pass + + +def test_recip() -> None: + """Test of recip() method of nodeStats class.""" + pass + + +def test_per_diff() -> None: + """Test of per_diff() method of nodeStats class.""" + pass + + +def test_detect_ridges() -> None: + """Test of detect_ridges() method of nodeStats class.""" + pass + + +def test_get_box_lims() -> None: + """Test of get_box_lims() method of nodeStats class.""" + pass + + +def test_order_branch() -> None: + """Test of order_branch() method of nodeStats class.""" + pass + + +def test_order_branch_from_start() -> None: + """Test of order_branch_from_start() method of nodeStats class.""" + pass + + +def test_local_area_sum() -> None: + """Test of local_area_sum() method of nodeStats class.""" + pass + + +def test_get_vector() -> None: + """Test of get_vector() method of nodeStats class.""" + pass + + +def test_calc_angles() -> None: + """Test of calc_angles() method of nodeStats class.""" + pass + + +def test_pair_vectors() -> None: + """Test of pair_vectors() method of nodeStats class.""" + pass + + +def test_best_matches() -> None: + """Test of best_matches() method of nodeStats class.""" + pass + + +def test_create_weighted_graph() -> None: + """Test of create_weighted_graph() method of nodeStats class.""" + pass + + +def test_pair_angles() -> None: + """Test of pair_angles() method of nodeStats class.""" + pass + + +def test_gaussian() -> None: + """Test of gaussian() method of nodeStats class.""" + pass + + +def test_fwhm() -> None: + """Test of fwhm() method of nodeStats class.""" + pass + + +def test_fwhm2() -> None: + """Test of fwhm2() method of nodeStats class.""" + pass + + +def test_peak_height() -> None: + """Test of peak_height() method of nodeStats class.""" + pass + + +def test_lin_interp() -> None: + """Test of lin_interp() method of nodeStats class.""" + pass + + +def test_close_coords() -> None: + """Test of close_coords() method of nodeStats class.""" + pass + + +def test_order_branches() -> None: + """Test of order_branches() method of nodeStats class.""" + pass + + +def test_binary_line() -> None: + """Test of binary_line() method of nodeStats class.""" + pass + + +def test_coord_dist() -> None: + """Test of coord_dist() method of nodeStats class.""" + pass + + +def test_coord_dist_rad() -> None: + """Test of coord_dist_rad() method of nodeStats class.""" + pass + + +def test_above_below_value_idx() -> None: + """Test of above_below_value_idx() method of nodeStats class.""" + pass + + +def test_average_height_trace() -> None: + """Test of average_height_trace() method of nodeStats class.""" + pass + + +def test_fill_holes() -> None: + """Test of fill_holes() method of nodeStats class.""" + pass + + +def test_remove_re_entering_branches() -> None: + """Test of remove_re_entering_branches() method of nodeStats class.""" + pass + + +@pytest.mark.parametrize( + ("node_image", "node_coordinate", "expected_node_image"), + [ + pytest.param( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 3, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 1, 1, 0, 3, 0, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 0, 3, 1, 0, 0], + [0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 1, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.array([6, 7]), + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 3, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 1, 1, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + ) + ], +) +def test_only_centre_branches( + node_image: npt.NDArray[np.int32], + node_coordinate: npt.NDArray[np.int32], + expected_node_image: npt.NDArray[np.int32], +) -> None: + """Test of only_centre_branches() method of nodeStats class.""" + result_node_image = nodeStats.only_centre_branches(node_image, node_coordinate) + + np.testing.assert_equal(result_node_image, expected_node_image) + + +def test_average_uniques() -> None: + """Test of average_uniques() method of nodeStats class.""" + pass + + +def test_compile_trace() -> None: + """Test of compile_trace() method of nodeStats class.""" + pass + + +def test_get_minus_img() -> None: + """Test of get_minus_img() method of nodeStats class.""" + pass + + +def test_remove_common_values() -> None: + """Test of remove_common_values() method of nodeStats class.""" + pass + + +def test_trace() -> None: + """Test of trace() method of nodeStats class.""" + pass + + +def test_reduce_rows() -> None: + """Test of reduce_rows() method of nodeStats class.""" + pass + + +def test_get_trace_segment() -> None: + """Test of get_trace_segment() method of nodeStats class.""" + pass + + +def test_comb_xyzs() -> None: + """Test of comb_xyzs() method of nodeStats class.""" + pass + + +def test_remove_duplicates() -> None: + """Test of remove_duplicates() method of nodeStats class.""" + pass + + +def test_order_from_end() -> None: + """Test of order_from_end() method of nodeStats class.""" + pass + + +def test_get_trace_idxs() -> None: + """Test of get_trace_idxs() method of nodeStats class.""" + pass + + +def test_get_visual_img() -> None: + """Test of get_visual_img() method of nodeStats class.""" + pass + + +def test_average_crossing_confs() -> None: + """Test of average_crossing_confs() method of nodeStats class.""" + pass + + +def test_minimum_crossing_confs() -> None: + """Test minimum_crossing_confs() method of nodeStats class.""" + pass + + +@pytest.mark.parametrize( + ( + "image_filename", + "pixel_to_nm_scaling", + "disordered_tracing_crop_data_filename", + "node_joining_length", + "node_extend_dist", + "branch_pairing_length", + "pair_odd_branches", + "expected_nodestats_data_filename", + "expected_nodestats_grainstats_filename", + "expected_nodestats_all_images_filename", + "expected_nodestats_branch_images_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + # Pixel to nm scaling + 0.488, + "catenanes_disordered_tracing_crop_data.pkl", + # Node joining length + 7.0, + # Node extend distance + 14.0, + # Branch pairing length + 20.0, + # Pair odd branches + True, + "catenanes_nodestats_data.pkl", + "catenanes_nodestats_grainstats.csv", + "catenanes_nodestats_all_images.pkl", + "catenanes_nodestats_branch_images.pkl", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + # Pixel to nm scaling + 0.488, + "rep_int_disordered_tracing_crop_data.pkl", + # Node joining length + 7.0, + # Node extend distance + 14.0, + # Branch pairing length + 20.0, + # Pair odd branches + False, + "rep_int_nodestats_data_no_pair_odd_branches.pkl", + "rep_int_nodestats_grainstats_no_pair_odd_branches.csv", + "rep_int_nodestats_all_images_no_pair_odd_branches.pkl", + "rep_int_nodestats_branch_images_no_pair_odd_branches.pkl", + id="replication_intermediate, not pairing odd branches", + ), + pytest.param( + "example_rep_int.npy", + # Pixel to nm scaling + 0.488, + "rep_int_disordered_tracing_crop_data.pkl", + # Node joining length + 7.0, + # Node extend distance + 14.0, + # Branch pairing length + 20.0, + # Pair odd branches + True, + "rep_int_nodestats_data_pair_odd_branches.pkl", + "rep_int_nodestats_grainstats_pair_odd_branches.csv", + "rep_int_nodestats_all_images_pair_odd_branches.pkl", + "rep_int_nodestats_branch_images_pair_odd_branches.pkl", + id="replication_intermediate, pairing odd branches", + ), + ], +) +def test_nodestats_image( + image_filename: str, + pixel_to_nm_scaling: float, + disordered_tracing_crop_data_filename: str, + node_joining_length: float, + node_extend_dist: float, + branch_pairing_length: float, + pair_odd_branches: bool, + expected_nodestats_data_filename: str, + expected_nodestats_grainstats_filename: str, + expected_nodestats_all_images_filename: str, + expected_nodestats_branch_images_filename: str, +) -> None: + """Test of nodestats_image() method of nodeStats class.""" + # Load the image + image = np.load(GENERAL_RESOURCES / image_filename) + # load disordered_tracing_crop_data from pickle + with Path(DISORDERED_TRACING_RESOURCES / disordered_tracing_crop_data_filename).open("rb") as f: + disordered_tracing_crop_data = pickle.load(f) + + ( + result_nodestats_data, + result_nodestats_grainstats, + result_nodestats_all_images, + result_nodestats_branch_images, + ) = nodestats_image( + image=image, + disordered_tracing_direction_data=disordered_tracing_crop_data, + filename="test_image", + pixel_to_nm_scaling=pixel_to_nm_scaling, + node_joining_length=node_joining_length, + node_extend_dist=node_extend_dist, + branch_pairing_length=branch_pairing_length, + pair_odd_branches=pair_odd_branches, + pad_width=1, + ) + + # # DEBUGGING (For viewing images) + # convolved_skeletons = result_all_images["convolved_skeletons"] + # node_centres = result_all_images["node_centres"] + # connected_nodes = result_all_images["connected_nodes"] + + # Save the results + + # Save the result_nodestats_data + with Path(NODESTATS_RESOURCES / expected_nodestats_data_filename).open("wb") as f: + pickle.dump(result_nodestats_data, f) + + # Save the result_stats_df as a csv + # result_nodestats_grainstats.to_csv(NODESTATS_RESOURCES / expected_nodestats_grainstats_filename) + + # # Save the result_all_images + # with Path(NODESTATS_RESOURCES / expected_nodestats_all_images_filename).open("wb") as f: + # pickle.dump(result_nodestats_all_images, f) + + # # Save the result_nodestats_branch_images + # with Path(NODESTATS_RESOURCES / expected_nodestats_branch_images_filename).open("wb") as f: + # pickle.dump(result_nodestats_branch_images, f) + + # Load expected data + + # Load the expected nodestats data + with Path(NODESTATS_RESOURCES / expected_nodestats_data_filename).open("rb") as f: + expected_nodestats_data = pickle.load(f) + + # Load the expected grainstats additions + expected_nodestats_grainstats = pd.read_csv( + NODESTATS_RESOURCES / expected_nodestats_grainstats_filename, index_col=0 + ) + + # Load the expected all images + with Path(NODESTATS_RESOURCES / expected_nodestats_all_images_filename).open("rb") as f: + expected_all_images = pickle.load(f) + + # Load the expected nodestats branch images + with Path(NODESTATS_RESOURCES / expected_nodestats_branch_images_filename).open("rb") as f: + expected_nodestats_branch_images = pickle.load(f) + + assert dict_almost_equal(result_nodestats_data, expected_nodestats_data, abs_tol=1e-3) + pd.testing.assert_frame_equal(result_nodestats_grainstats, expected_nodestats_grainstats) + assert dict_almost_equal(result_nodestats_all_images, expected_all_images) + assert dict_almost_equal(result_nodestats_branch_images, expected_nodestats_branch_images) diff --git a/tests/tracing/test_ordered_tracing.py b/tests/tracing/test_ordered_tracing.py new file mode 100644 index 0000000000..400430f535 --- /dev/null +++ b/tests/tracing/test_ordered_tracing.py @@ -0,0 +1,360 @@ +# Disable ruff 301 - pickle loading is unsafe, but we don't care for tests +# ruff: noqa: S301 +"""Test the ordered tracing module.""" + +import pickle +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from topostats.tracing.ordered_tracing import linear_or_circular, ordered_tracing_image + +BASE_DIR = Path.cwd() +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" +ORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "ordered_tracing" +NODESTATS_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "nodestats" +DISORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "disordered_tracing" + +# pylint: disable=unspecified-encoding +# pylint: disable=too-many-locals +# pylint: disable=too-many-arguments + +GRAINS = {} +GRAINS["vertical"] = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] +) +GRAINS["horizontal"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["diagonal1"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["diagonal2"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["diagonal3"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["circle"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["blob"] = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 0, 0, 0, 0], + ] +) +GRAINS["cross"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["single_L"] = np.asarray( + [ + [0, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 1, 0], + [0, 0, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0], + ] +) +GRAINS["double_L"] = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 0, 0], + [0, 0, 0, 0, 0], + ] +) +GRAINS["diagonal_end_single_L"] = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 1, 0, 0], + [0, 0, 0, 0, 0], + ] +) +GRAINS["diagonal_end_straight"] = np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + ] +) +GRAINS["figure8"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["three_ends"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) +GRAINS["six_ends"] = np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ] +) + + +@pytest.mark.parametrize( + ("grain", "mol_is_circular"), + [ + pytest.param(GRAINS["vertical"], False, id="vertical"), + pytest.param(GRAINS["horizontal"], False, id="horizontal"), + pytest.param(GRAINS["diagonal1"], True, id="diagonal1"), # This is wrong, this IS a linear molecule + pytest.param(GRAINS["diagonal2"], False, id="diagonal2"), + pytest.param(GRAINS["diagonal3"], False, id="diagonal3"), + pytest.param(GRAINS["circle"], True, id="circle"), + pytest.param(GRAINS["blob"], True, id="blob"), + pytest.param(GRAINS["cross"], False, id="cross"), + pytest.param(GRAINS["single_L"], False, id="singl_L"), + pytest.param(GRAINS["double_L"], True, id="double_L"), # This is wrong, this IS a linear molecule + pytest.param(GRAINS["diagonal_end_single_L"], False, id="diagonal_end_single_L"), + pytest.param(GRAINS["diagonal_end_straight"], False, id="diagonal_end_straight"), + pytest.param(GRAINS["figure8"], True, id="figure8"), + pytest.param(GRAINS["three_ends"], False, id="three_ends"), + pytest.param(GRAINS["six_ends"], False, id="six_ends"), + ], +) +def test_linear_or_circular(grain: np.ndarray, mol_is_circular: bool) -> None: + """Test the linear_or_circular method with a range of different structures.""" + linear_coordinates = np.argwhere(grain == 1) + result = linear_or_circular(linear_coordinates) + assert result == mol_is_circular + + +@pytest.mark.parametrize( + ( + "image_filename", + "disordered_tracing_direction_data_filename", + "nodestats_data_filename", + "nodestats_branch_images_filename", + "filename", + "expected_ordered_tracing_data_filename", + "expected_ordered_tracing_grainstats_filename", + "expected_molstats_filename", + "expected_ordered_tracing_full_images_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + "catenanes_disordered_tracing_crop_data.pkl", + "catenanes_nodestats_data.pkl", + "catenanes_nodestats_branch_images.pkl", + "catenane", # filename + "catenanes_ordered_tracing_data.pkl", + "catenanes_ordered_tracing_grainstats.csv", + "catenanes_ordered_tracing_molstats.csv", + "catenanes_ordered_tracing_full_images.pkl", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + "rep_int_disordered_tracing_crop_data.pkl", + "rep_int_nodestats_data_no_pair_odd_branches.pkl", + "rep_int_nodestats_branch_images_no_pair_odd_branches.pkl", + "replication_intermediate", # filename + "rep_int_ordered_tracing_data.pkl", + "rep_int_ordered_tracing_grainstats.csv", + "rep_int_ordered_tracing_molstats.csv", + "rep_int_ordered_tracing_full_images.pkl", + id="replication_intermediate", + ), + ], +) +def test_ordered_tracing_image( + image_filename: str, + disordered_tracing_direction_data_filename: str, + nodestats_data_filename: str, + nodestats_branch_images_filename: str, + filename: str, + expected_ordered_tracing_data_filename: str, + expected_ordered_tracing_grainstats_filename: str, + expected_molstats_filename: str, + expected_ordered_tracing_full_images_filename: str, +) -> None: + """Test the ordered tracing image method of ordered tracing.""" + # disordered_tracing_direction_data is the disordered tracing data + # for a particular threshold direction. + + # nodestats_direction_data contains both nodestats_data and nodestats_branch_images + + # Load the required data + image = np.load(GENERAL_RESOURCES / image_filename) + + with Path.open(DISORDERED_TRACING_RESOURCES / disordered_tracing_direction_data_filename, "rb") as f: + disordered_tracing_direction_data = pickle.load(f) + + with Path.open(NODESTATS_RESOURCES / nodestats_data_filename, "rb") as f: + nodestats_data = pickle.load(f) + + with Path.open(NODESTATS_RESOURCES / nodestats_branch_images_filename, "rb") as f: + nodestats_branch_images = pickle.load(f) + + nodestats_whole_data = {"stats": nodestats_data, "images": nodestats_branch_images} + + ( + result_ordered_tracing_data, + result_ordered_tracing_grainstats, + result_molstats_df, + result_ordered_tracing_full_images, + ) = ordered_tracing_image( + image=image, + disordered_tracing_direction_data=disordered_tracing_direction_data, + nodestats_direction_data=nodestats_whole_data, + filename=filename, + ordering_method="nodestats", + pad_width=1, + ) + + # # Debugging - grab variables to show images + # variable_ordered_traces = result_ordered_tracing_full_images["ordered_traces"] + # variable_all_molecules = result_ordered_tracing_full_images["all_molecules"] + # variable_over_under = result_ordered_tracing_full_images["over_under"] + # variable_trace_segments = result_ordered_tracing_full_images["trace_segments"] + + # # Save result ordered tracing data as pickle + # with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_data_filename, "wb") as f: + # pickle.dump(result_ordered_tracing_data, f) + + # # Save result grainstats additions as csv + # result_ordered_tracing_grainstats.to_csv(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_grainstats_filename) + + # # Save the molstats dataframe as csv + # result_molstats_df.to_csv(ORDERED_TRACING_RESOURCES / expected_molstats_filename) + + # # Save result ordered tracing full images as pickle + # with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_full_images_filename, "wb") as f: + # pickle.dump(result_ordered_tracing_full_images, f) + + # Load the expected results + with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_data_filename, "rb") as f: + expected_ordered_tracing_data = pickle.load(f) + + expected_ordered_tracing_grainstats = pd.read_csv( + ORDERED_TRACING_RESOURCES / expected_ordered_tracing_grainstats_filename, index_col=0 + ) + + expected_molstats_df = pd.read_csv(ORDERED_TRACING_RESOURCES / expected_molstats_filename, index_col=0) + + with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_full_images_filename, "rb") as f: + expected_ordered_tracing_full_images = pickle.load(f) + + # Check the results + np.testing.assert_equal(result_ordered_tracing_data, expected_ordered_tracing_data) + pd.testing.assert_frame_equal(result_ordered_tracing_grainstats, expected_ordered_tracing_grainstats) + pd.testing.assert_frame_equal(result_molstats_df, expected_molstats_df) + np.testing.assert_equal(result_ordered_tracing_full_images, expected_ordered_tracing_full_images) diff --git a/tests/tracing/test_pruning.py b/tests/tracing/test_pruning.py new file mode 100644 index 0000000000..82f840dff6 --- /dev/null +++ b/tests/tracing/test_pruning.py @@ -0,0 +1,2832 @@ +"""Test the skeletonize module.""" + +from __future__ import annotations + +import numpy as np +import numpy.typing as npt +import pytest + +from topostats.tracing.pruning import ( + heightPruning, + local_area_sum, + order_branch_from_end, + rm_nibs, + topostatsPrune, +) + +# pylint: disable=too-many-lines +# pylint: disable=protected-access +# pylint: disable=too-many-arguments +# pylint: disable=too-many-positional-arguments + + +@pytest.mark.parametrize( + ( + "img_skeleton", + "max_length", + "height_threshold", + "method_values", + "method_outliers", + "target_ends", + "target_pruned_coords", + ), + [ + pytest.param( + "skeleton_loop1", + 49, + 90, + "min", + "abs", + np.asarray([[6, 33], [34, 18], [89, 29], [104, 77], [109, 105]]), + np.asarray( + [ + [43, 50], + [43, 51], + [43, 52], + [43, 53], + [43, 54], + [44, 55], + [45, 56], + [46, 57], + [47, 58], + [48, 59], + [49, 60], + [49, 61], + [49, 62], + [50, 60], + [50, 63], + [51, 60], + [51, 64], + [52, 59], + [52, 65], + [53, 59], + [53, 66], + [54, 59], + [54, 67], + [55, 58], + [55, 68], + [56, 58], + [56, 69], + [57, 58], + [57, 70], + [58, 58], + [58, 71], + [59, 58], + [59, 72], + [60, 58], + [60, 59], + [60, 73], + [61, 60], + [61, 74], + [62, 61], + [62, 75], + [63, 62], + [63, 76], + [64, 63], + [64, 77], + [65, 64], + [65, 78], + [66, 65], + [66, 79], + [67, 66], + [67, 80], + [68, 67], + [68, 81], + [69, 68], + [69, 82], + [70, 69], + [70, 83], + [71, 70], + [71, 84], + [72, 71], + [72, 85], + [73, 72], + [73, 86], + [74, 73], + [74, 87], + [75, 74], + [75, 88], + [76, 75], + [76, 89], + [77, 76], + [77, 90], + [78, 77], + [78, 91], + [79, 77], + [79, 92], + [80, 77], + [80, 92], + [81, 77], + [81, 93], + [82, 77], + [82, 94], + [83, 77], + [83, 95], + [84, 77], + [84, 96], + [85, 77], + [85, 97], + [86, 77], + [86, 98], + [87, 77], + [87, 99], + [88, 77], + [88, 100], + [89, 77], + [89, 101], + [90, 77], + [90, 102], + [91, 77], + [91, 103], + [92, 77], + [92, 104], + [93, 77], + [93, 105], + [94, 77], + [94, 106], + [95, 77], + [95, 106], + [96, 77], + [96, 106], + [97, 78], + [97, 106], + [98, 78], + [98, 106], + [99, 79], + [99, 106], + [100, 80], + [100, 106], + [101, 81], + [101, 105], + [102, 82], + [102, 83], + [102, 84], + [102, 85], + [102, 86], + [102, 87], + [102, 88], + [102, 89], + [102, 90], + [102, 91], + [102, 92], + [102, 93], + [102, 94], + [102, 95], + [102, 96], + [102, 97], + [102, 98], + [102, 99], + [102, 100], + [102, 101], + [102, 102], + [102, 103], + [102, 104], + [102, 105], + ] + ), + id="skeleton loop 1", + # marks=pytest.mark.skip(), + ), + pytest.param( + "skeleton_loop2", + 37.8, + 90, + "min", + "abs", + np.asarray([[3, 16], [5, 125], [38, 31], [42, 81], [85, 86]]), + np.asarray( + [ + [1, 47], + [1, 48], + [1, 49], + [1, 50], + [1, 51], + [1, 52], + [1, 53], + [1, 54], + [1, 55], + [1, 56], + [1, 57], + [2, 46], + [2, 58], + [2, 102], + [2, 103], + [2, 104], + [2, 105], + [2, 106], + [2, 107], + [2, 108], + [2, 109], + [2, 110], + [2, 111], + [2, 112], + [2, 113], + [2, 114], + [2, 115], + [2, 116], + [2, 117], + [2, 118], + [2, 119], + [2, 120], + [3, 45], + [3, 59], + [3, 101], + [3, 121], + [4, 44], + [4, 60], + [4, 100], + [4, 122], + [4, 123], + [5, 43], + [5, 61], + [5, 99], + [5, 124], + [5, 125], + [6, 42], + [6, 62], + [6, 98], + [7, 41], + [7, 63], + [7, 97], + [8, 40], + [8, 64], + [8, 96], + [9, 39], + [9, 65], + [9, 95], + [10, 38], + [10, 66], + [10, 94], + [11, 37], + [11, 67], + [11, 93], + [12, 36], + [12, 68], + [12, 92], + [13, 35], + [13, 69], + [13, 91], + [14, 34], + [14, 70], + [14, 90], + [15, 33], + [15, 71], + [15, 89], + [16, 33], + [16, 72], + [16, 88], + [17, 33], + [17, 73], + [17, 87], + [18, 33], + [18, 74], + [18, 86], + [19, 33], + [19, 75], + [19, 85], + [20, 33], + [20, 76], + [20, 84], + [21, 33], + [21, 77], + [21, 83], + [22, 33], + [22, 78], + [22, 82], + [23, 33], + [23, 79], + [23, 80], + [23, 81], + [24, 33], + [24, 80], + [25, 33], + [25, 80], + [26, 33], + [26, 80], + [27, 33], + [27, 80], + [28, 33], + [28, 80], + [29, 33], + [29, 80], + [30, 33], + [30, 79], + [31, 33], + [31, 79], + [32, 33], + [32, 79], + [33, 33], + [33, 79], + [34, 33], + [34, 79], + [35, 33], + [35, 79], + [36, 33], + [36, 79], + [37, 33], + [37, 79], + [38, 33], + [38, 34], + [38, 35], + [38, 79], + [39, 36], + [39, 79], + [40, 37], + [40, 79], + [41, 38], + [41, 39], + [41, 79], + [42, 40], + [42, 79], + [42, 80], + [43, 41], + [43, 42], + [43, 79], + [44, 43], + [44, 44], + [44, 78], + [45, 45], + [45, 46], + [45, 77], + [46, 47], + [46, 76], + [47, 48], + [47, 49], + [47, 75], + [48, 50], + [48, 51], + [48, 52], + [48, 74], + [49, 53], + [49, 54], + [49, 73], + [50, 55], + [50, 56], + [50, 72], + [51, 57], + [51, 58], + [51, 71], + [52, 59], + [52, 60], + [52, 61], + [52, 70], + [53, 62], + [53, 63], + [53, 64], + [53, 69], + [54, 65], + [54, 66], + [54, 67], + [54, 68], + [55, 68], + [56, 68], + [57, 69], + [58, 69], + [59, 69], + [60, 70], + [61, 70], + [62, 70], + [63, 71], + [64, 71], + [65, 72], + [66, 72], + [67, 73], + [68, 73], + [69, 73], + [70, 74], + [71, 74], + [72, 75], + [73, 75], + [74, 76], + [75, 76], + [76, 77], + [77, 78], + [78, 78], + [79, 79], + [80, 79], + [81, 79], + [82, 79], + [83, 79], + [84, 80], + [85, 81], + [85, 85], + [85, 86], + [86, 82], + [86, 83], + [86, 84], + ] + ), + id="skeleton loop 2", + # marks=pytest.mark.skip(), + ), + pytest.param( + "skeleton_linear1", + 42, + 90, + "min", + "abs", + np.asarray( + [ + [2, 71], + [5, 115], + [10, 31], + [12, 126], + [14, 123], + [22, 122], + [29, 22], + [45, 16], + [88, 23], + [112, 102], + ] + ), + np.asarray( + [ + [10, 115], + [10, 116], + [10, 117], + [10, 118], + [10, 119], + [11, 114], + [11, 120], + [11, 121], + [12, 114], + [12, 121], + [13, 114], + [14, 113], + [15, 112], + [16, 111], + [17, 110], + [18, 109], + [19, 108], + [20, 107], + [21, 107], + [22, 106], + [22, 107], + [23, 105], + [24, 31], + [24, 104], + [25, 32], + [25, 33], + [25, 103], + [26, 34], + [26, 35], + [26, 36], + [26, 102], + [27, 37], + [27, 38], + [27, 101], + [28, 39], + [28, 40], + [28, 41], + [28, 42], + [28, 43], + [28, 100], + [29, 44], + [29, 99], + [30, 45], + [30, 98], + [31, 46], + [31, 97], + [32, 47], + [32, 96], + [33, 48], + [33, 95], + [34, 49], + [34, 94], + [35, 50], + [35, 93], + [36, 51], + [36, 92], + [37, 52], + [37, 91], + [38, 53], + [38, 90], + [39, 54], + [39, 87], + [39, 88], + [39, 89], + [40, 55], + [40, 80], + [40, 81], + [40, 82], + [40, 83], + [40, 84], + [40, 85], + [40, 86], + [41, 56], + [41, 78], + [41, 79], + [42, 57], + [42, 77], + [43, 58], + [43, 76], + [44, 59], + [44, 76], + [45, 60], + [45, 76], + [46, 61], + [46, 76], + [47, 62], + [47, 76], + [48, 63], + [48, 76], + [49, 64], + [49, 76], + [50, 65], + [50, 76], + [51, 66], + [51, 74], + [51, 75], + [51, 76], + [52, 67], + [52, 72], + [52, 73], + [52, 77], + [53, 68], + [53, 69], + [53, 70], + [53, 71], + [53, 78], + [54, 36], + [54, 67], + [54, 79], + [55, 35], + [55, 36], + [55, 66], + [55, 80], + [56, 35], + [56, 37], + [56, 64], + [56, 65], + [56, 81], + [57, 36], + [57, 37], + [57, 63], + [57, 82], + [58, 37], + [58, 62], + [58, 82], + [59, 38], + [59, 61], + [59, 83], + [60, 39], + [60, 58], + [60, 59], + [60, 60], + [60, 83], + [61, 40], + [61, 55], + [61, 56], + [61, 57], + [61, 83], + [62, 41], + [62, 53], + [62, 54], + [62, 83], + [63, 42], + [63, 51], + [63, 52], + [63, 83], + [64, 43], + [64, 50], + [64, 84], + [65, 43], + [65, 44], + [65, 45], + [65, 46], + [65, 47], + [65, 48], + [65, 49], + [65, 84], + [66, 31], + [66, 32], + [66, 33], + [66, 34], + [66, 35], + [66, 36], + [66, 37], + [66, 38], + [66, 39], + [66, 40], + [66, 41], + [66, 42], + [66, 84], + [67, 30], + [67, 84], + [68, 85], + [69, 85], + [70, 85], + [71, 85], + [72, 86], + [73, 86], + [74, 86], + [75, 86], + [76, 87], + [77, 87], + [78, 88], + [79, 88], + [80, 89], + [81, 89], + [82, 90], + [83, 90], + [84, 91], + [85, 92], + [86, 92], + [87, 93], + [88, 93], + [89, 93], + [90, 93], + [91, 93], + [92, 93], + [93, 93], + [94, 93], + [95, 94], + [96, 94], + [97, 94], + [98, 94], + [99, 95], + [100, 95], + [101, 96], + [102, 96], + [103, 97], + [104, 97], + [105, 98], + [106, 99], + [107, 99], + [108, 100], + [109, 100], + [110, 100], + [111, 101], + [112, 102], + ] + ), + id="skeleton linear 1", + # marks=pytest.mark.skip(), + ), + pytest.param( + "skeleton_linear2", + 10, + 90, + "min", + "abs", + np.asarray([[75, 84], [88, 76], [102, 18]]), + np.asarray( + [ + [69, 67], + [69, 68], + [69, 69], + [69, 70], + [70, 66], + [70, 71], + [70, 72], + [71, 66], + [71, 73], + [71, 74], + [71, 75], + [72, 65], + [72, 76], + [72, 77], + [73, 65], + [73, 78], + [73, 79], + [73, 80], + [74, 64], + [74, 65], + [74, 81], + [74, 82], + [75, 62], + [75, 63], + [75, 66], + [75, 83], + [75, 84], + [76, 59], + [76, 60], + [76, 61], + [76, 67], + [77, 57], + [77, 58], + [77, 68], + [78, 54], + [78, 55], + [78, 56], + [78, 69], + [79, 52], + [79, 53], + [79, 70], + [80, 50], + [80, 51], + [80, 71], + [81, 49], + [81, 72], + [82, 48], + [82, 73], + [83, 47], + [83, 73], + [84, 46], + [84, 74], + [85, 45], + [85, 74], + [86, 44], + [86, 75], + [87, 43], + [87, 75], + [88, 42], + [88, 76], + [89, 41], + [90, 40], + [91, 39], + [92, 38], + [93, 37], + [94, 36], + [95, 35], + [96, 34], + [97, 33], + [98, 32], + [99, 31], + [100, 30], + [101, 29], + [102, 18], + [102, 19], + [102, 20], + [102, 21], + [102, 22], + [102, 28], + [103, 23], + [103, 24], + [103, 25], + [103, 26], + [103, 27], + ] + ), + id="skeleton linear 2", + # marks=pytest.mark.skip(), + ), + pytest.param( + "skeleton_linear3", + 31, + 90, + "min", + "abs", + np.asarray( + [ + [9, 88], + [13, 64], + [15, 35], + [15, 61], + [17, 81], + [19, 109], + [34, 31], + [40, 98], + [40, 121], + [68, 33], + [68, 92], + [93, 36], + [116, 114], + ] + ), + np.asarray( + [ + [15, 62], + [15, 63], + [16, 63], + [17, 63], + [17, 82], + [17, 83], + [18, 62], + [18, 83], + [19, 62], + [19, 83], + [20, 62], + [20, 83], + [21, 62], + [21, 83], + [22, 62], + [22, 83], + [23, 62], + [23, 83], + [24, 62], + [24, 83], + [25, 62], + [25, 83], + [26, 62], + [26, 83], + [27, 62], + [27, 83], + [28, 62], + [28, 83], + [29, 62], + [29, 83], + [30, 62], + [30, 83], + [31, 62], + [31, 83], + [32, 62], + [32, 83], + [33, 62], + [33, 82], + [34, 62], + [34, 81], + [34, 109], + [35, 62], + [35, 80], + [35, 109], + [36, 62], + [36, 79], + [37, 63], + [37, 77], + [37, 78], + [38, 64], + [38, 75], + [38, 76], + [39, 65], + [39, 73], + [39, 74], + [40, 66], + [40, 72], + [41, 67], + [41, 71], + [42, 68], + [42, 69], + [42, 70], + [43, 69], + [44, 69], + [45, 68], + [46, 68], + [47, 68], + [48, 68], + [49, 68], + [50, 67], + [51, 67], + [52, 66], + [53, 64], + [53, 65], + [54, 62], + [54, 63], + [55, 61], + [56, 60], + [57, 60], + [58, 61], + [59, 61], + [60, 61], + [61, 61], + [62, 61], + [63, 61], + [64, 61], + [65, 61], + [66, 61], + [67, 61], + [68, 61], + [69, 61], + [70, 61], + [71, 61], + [72, 61], + [73, 61], + [74, 61], + [75, 61], + [76, 61], + [77, 61], + [78, 61], + [79, 61], + [80, 61], + [81, 62], + [82, 63], + [82, 64], + [83, 65], + [83, 66], + [84, 67], + [84, 68], + [85, 69], + [85, 70], + [86, 71], + [86, 72], + [87, 73], + [87, 74], + [87, 75], + [87, 76], + [87, 77], + [87, 78], + [87, 79], + [87, 80], + [88, 81], + [89, 82], + [90, 83], + [91, 84], + [91, 85], + [91, 86], + [91, 87], + [91, 88], + [91, 89], + [91, 90], + [91, 91], + [91, 92], + [91, 93], + [91, 94], + [91, 95], + [91, 96], + [91, 97], + [91, 98], + [91, 99], + [91, 100], + [91, 101], + [92, 102], + [93, 102], + [94, 102], + [95, 102], + [96, 102], + [97, 102], + [98, 102], + [99, 102], + [100, 102], + [101, 102], + [102, 102], + [103, 103], + [104, 104], + [105, 105], + [106, 106], + [107, 107], + [108, 108], + [109, 109], + [110, 110], + [110, 111], + [110, 112], + [110, 113], + [110, 114], + [111, 115], + [112, 115], + [113, 115], + [114, 115], + [115, 115], + [116, 114], + ] + ), + id="skeleton linear 3", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + 10, + 90, + "min", + "abs", + np.asarray([[4, 28], [6, 26], [20, 11]]), + np.asarray( + [ + [3, 21], + [3, 22], + [3, 23], + [3, 24], + [3, 25], + [3, 26], + [4, 20], + [4, 27], + [4, 28], + [5, 19], + [6, 18], + [7, 17], + [8, 17], + [9, 16], + [10, 16], + [11, 15], + [12, 15], + [13, 14], + [14, 13], + [15, 13], + [16, 12], + [17, 12], + [18, 12], + [19, 12], + [20, 11], + ] + ), + id="Linear array with two forks at one end", + marks=pytest.mark.skip( + reason="two branches at the end does not have 'nibs' removed by the rm_nibs() " + "function, instead a T-shaped junction remains. This is only removed by calling " + "getSkeleton(method='zhang').get_skeleton() which is done under the " + "prune_all_skeletons() method but this method is is to be removed since looping over " + "all skeletons is outside of the scope/remit of the code and handled in process_scan()." + ), + ), + ], +) +class TestTopoStatsPrune: + """Tests of topostatsPrune() class.""" + + def topostats_pruner( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + request, + ) -> None: + """Instantiate a topostatsPrune object.""" + img_skeleton = request.getfixturevalue(img_skeleton) + return topostatsPrune( + img_skeleton["img"], + img_skeleton["skeleton"], + 1, + max_length, + height_threshold, + method_values, + method_outliers, + ) + + def test_find_branch_ends( + self, + img_skeleton: dict, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + target_ends: npt.NDArray, + target_pruned_coords: npt.NDArray, # pylint: disable=unused-argument + request, + ) -> None: + """ + Test of topostats_find_branch_ends() method of topostatsPrune class. + + Currently have to convert the coordinates of the skeleton to a list otherwise + genTracingFuncs.count_and_get_neighbours() always returns 8 as assessing whether the coordinates which are a + list are within the 2D Numpy array always returns True Once tests are in place we can look at refactoring all + these classes to work with Numpy arrays rather than flipping back and forth between Numpy arrays and lists as is + the current situation. (Took 2 hrs to work this out!) + """ + pruner = self.topostats_pruner( + img_skeleton, + max_length, + height_threshold, + method_values, + method_outliers, + request, + ) + coords = np.argwhere(request.getfixturevalue(img_skeleton)["skeleton"] == 1).tolist() + ends = pruner._find_branch_ends(coords) + np.testing.assert_array_equal(ends, target_ends) + + def test_prune_by_length( + self, + img_skeleton: dict, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + target_ends: npt.NDArray, # pylint: disable=unused-argument + target_pruned_coords: npt.NDArray, + request, + ) -> None: + """ + Test of topostats_prune_by_length() method of topostatsPrune class. + + Relies on rm_nib() and tests and currently the below tests fail and nibs appear _not_ to be removed. + + This is a large function which does a lot and should probably be refactored into smaller, the branch checking + could be a method of its own. + """ + pruner = self.topostats_pruner( + img_skeleton, + max_length, + height_threshold, + method_values, + method_outliers, + request, + ) + pruned_skeleton = pruner._prune_by_length(pruner.skeleton, pruner.max_length) + pruned_coords = np.argwhere(pruned_skeleton == 1) + np.testing.assert_array_equal(pruned_coords, target_pruned_coords) + + +@pytest.mark.parametrize( + ( + "img_skeleton", + "max_length", + "height_threshold", + "method_values", + "method_outliers", + "target_skeleton", + ), + [ + pytest.param( + "pruning_skeleton", + None, + None, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Pruning by length and height disabled", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + 25, + 90, + "mid", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Length pruning enabled (25) removes everything bar the branching node.", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + None, + 5.0e-19, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Height pruning based on minimum, height threshold 5.0e-19", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + None, + 8.0e-19, + "median", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Height pruning based on median, height threshold 8.0e-19", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + None, + 7.7e-19, + "mid", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Height pruning based on mid(dle), height threshold 7.7e-19", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + None, + 1.0e-19, + "min", + "mean_abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Height pruning based on minimum, height threshold mean - threshold (1.0e-19) difference", + # marks=pytest.mark.skip(), + ), + pytest.param( + "pruning_skeleton", + None, + 1.0e-19, + "min", + "iqr", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.uint8, + ), + id="Height pruning based on minimum, height threshold lower quartile - 1.5 x Interquartile range.", + # marks=pytest.mark.skip(), + ), + ], +) +class TestTopoStatsPruneMethods: + """Tests of topostatsPrune() class.""" + + def topostats_pruner( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + request, + ) -> topostatsPrune: + """Instantiate a topostatsPrune object.""" + img_skeleton = request.getfixturevalue(img_skeleton) + return topostatsPrune( + img_skeleton["img"], + img_skeleton["skeleton"], + 1, + max_length, + height_threshold, + method_values, + method_outliers, + ) + + def test_prune_skeleton( + self, + img_skeleton: dict, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + target_skeleton: npt.NDArray, + request, + ) -> None: + """Test of topostats_prune_all_skeletons() method of topostatsPrune class.""" + pruner = self.topostats_pruner( + img_skeleton, + max_length, + height_threshold, + method_values, + method_outliers, + request, + ) + pruned_skeleton = pruner.prune_skeleton() + np.testing.assert_array_equal(pruned_skeleton, target_skeleton) + + +# Tests for heightPruning class +@pytest.mark.parametrize( + ( + "img_skeleton", + "max_length", + "height_threshold", + "method_values", + "method_outliers", + "convolved_skeleton_target", + "segmented_skeleton_target", + "labelled_skeleton_target", + "branch_mins_target", + "branch_medians_target", + "branch_middles_target", + "abs_thresh_idx_target", + "mean_abs_thresh_idx_target", + "iqr_thresh_idx_target", + "check_skeleton_one_object_target", + "height_prune_target", + ), + [ + pytest.param( + { + "skeleton": np.asarray( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + "img": np.asarray( + [ + [0, 0, 0, 0, 0, 0], + [0, 9, 0, 0, 0, 0], + [0, 9, 0, 0, 0, 0], + [0, 10, 0, 0, 0, 0], + [0, 0, 11, 8, 8, 0], + [0, 13, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + }, + 10, + 9, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0], + [0, 2, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 3, 1, 2, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 2, 2, 0], + [0, 3, 0, 0, 0, 0], + [0, 3, 0, 0, 0, 0], + [0, 3, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + np.asarray([9, 8, 7]), # branch_mins + np.asarray([9, 8, 7]), # branch_medians : Used in _abs_thresh_idx with threshold of 9 + np.asarray([9, 8, 7]), # branch_middles + np.asarray([2, 3]), # abs_thresh_idx : index + 1 in above noted array that are < 9 + np.asarray([2, 3]), # mean_abs_thresh_idx + np.asarray([]), # iqr_thresh_idx + False, + np.asarray( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + id="basic T-shape", + ), + pytest.param( + { + "skeleton": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + "img": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 10, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 9, 0, 0, 0, 0, 0, 10, 0], + [0, 0, 0, 9, 0, 0, 0, 0, 8, 0], + [0, 0, 7, 0, 9, 0, 0, 0, 8, 0], + [0, 7, 0, 0, 0, 9, 0, 10, 0, 0], + [0, 7, 0, 0, 0, 0, 9, 0, 0, 0], + [0, 7, 0, 0, 0, 0, 0, 9, 0, 0], + [0, 7, 0, 0, 0, 0, 0, 0, 9, 0], + [0, 10, 0, 0, 0, 0, 0, 0, 10, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + }, + 10, + 9, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 2, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 2, 0], + [0, 0, 0, 3, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 3, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 2, 0, 0, 0, 0, 0, 0, 2, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 2, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 2, 0], + [0, 0, 3, 0, 4, 0, 0, 0, 2, 0], + [0, 3, 0, 0, 0, 4, 0, 2, 0, 0], + [0, 3, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 3, 0, 0, 0, 0, 0, 5, 0, 0], + [0, 3, 0, 0, 0, 0, 0, 0, 5, 0], + [0, 3, 0, 0, 0, 0, 0, 0, 5, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray([9, 8, 7, 9, 9]), # branch_mins + np.asarray([9.5, 9, 7, 9, 9]), # branch_medians : Used in _abs_thresh_idx with threshold of 9 + np.asarray([9.5, 8, 7, 9, 9]), # branch_middles + np.asarray([3]), # abs_thresh_idx : index + 1 in above noted array that are < 9 + np.asarray([3]), # mean_abs_thresh_idx + np.asarray([]), # iqr_thresh_idx + False, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="disjointed x-shape", + ), + pytest.param( + { + "skeleton": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + "img": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 9, 8, 8, 9, 0, 0, 0, 0, 0], + [0, 0, 9, 0, 0, 0, 0, 9, 0, 0, 0, 0], + [0, 9, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0], + [0, 9, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0], + [0, 0, 9, 7, 7, 9, 0, 0, 9, 0, 0, 0], + [0, 9, 0, 0, 0, 0, 0, 0, 9, 0, 6, 0], + [0, 9, 0, 0, 0, 0, 0, 0, 9, 0, 6, 0], + [0, 0, 9, 0, 0, 0, 0, 9, 0, 9, 0, 0], + [0, 0, 0, 9, 8, 8, 9, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + }, + 10, + 9, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 3, 1, 1, 2, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 2, 2, 2, 0, 0, 1, 0, 0, 0], + [0, 3, 0, 0, 0, 0, 0, 0, 1, 0, 4, 0], + [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0], + [0, 0, 3, 0, 0, 0, 0, 3, 0, 4, 0, 0], + [0, 0, 0, 3, 3, 3, 3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray([8, 7, 8, 6]), # branch_mins + np.asarray([9, 7, 9, 6]), # branch_medians : Used in _abs_thresh_idx with threshold of 9 + np.asarray([8.5, 7, 8.5, 6]), # branch_middles + np.asarray([2, 4]), # abs_thresh_idx : index + 1 in above noted array that are < 9 + np.asarray([2, 4]), # mean_abs_thresh_idx + np.asarray([4]), # iqr_thresh_idx + False, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="ring with branches", + # marks=pytest.mark.xfail(reason="splits ring in two places, should this be so?"), + ), + pytest.param( + { + "skeleton": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 1, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + "img": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 15, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 10, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 10, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 10, 9, 0, 0, 0, 14, 0, 0], + [0, 0, 10, 0, 9, 0, 8, 0, 0, 0], + [0, 0, 10, 0, 0, 9, 0, 0, 0, 0], + [0, 0, 10, 0, 0, 9, 0, 0, 0, 0], + [0, 0, 10, 0, 0, 9, 0, 0, 0, 0], + [0, 0, 10, 0, 0, 0, 12, 0, 0, 0], + [0, 0, 10, 0, 0, 0, 0, 12, 0, 0], + [0, 0, 0, 11, 0, 0, 0, 0, 12, 0], + [0, 0, 0, 10, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 10, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 10, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 10, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + }, + 10, + 10, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 2, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 3, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 3, 3, 0, 0, 0, 2, 0, 0], + [0, 0, 3, 0, 1, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 2, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], + [0, 0, 0, 0, 3, 0, 2, 0, 0, 0], + [0, 0, 4, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 4, 0, 0, 5, 0, 0, 0, 0], + [0, 0, 4, 0, 0, 5, 0, 0, 0, 0], + [0, 0, 4, 0, 0, 0, 5, 0, 0, 0], + [0, 0, 4, 0, 0, 0, 0, 5, 0, 0], + [0, 0, 0, 4, 0, 0, 0, 0, 5, 0], + [0, 0, 0, 4, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 4, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 4, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 4, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray([10, 8, 9, 10, 9]), # branch_mins + np.asarray([12.5, 11.0, 9.0, 10.0, 12.0]), # branch_medians : Used in _abs_thresh_idx with threshold of 9 + np.asarray([12.5, 11, 9, 10.5, 12]), # branch_middles + np.asarray([3]), # abs_thresh_idx : index + 1 in above noted array that are < 10 + np.asarray([3]), # mean_abs_thresh_idx + np.asarray([2]), # iqr_thresh_idx + False, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="long straight skeleton with forked branch", + # marks=pytest.mark.xfail(reason="Not sure middles are correct, arbitrarily takes point left or right of" + # "even lengthed branches see region 1"), + ), + pytest.param( + { + "skeleton": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + "img": np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 9, 9, 9, 0, 0], + [0, 9, 0, 0, 0, 9, 0], + [0, 9, 0, 0, 0, 9, 0], + [0, 0, 10, 5, 7, 0, 0], + [0, 9, 0, 0, 0, 9, 0], + [0, 9, 0, 0, 0, 9, 0], + [0, 0, 9, 9, 9, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + }, + 10, + 9, + "min", + "abs", + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 3, 1, 3, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 0, 2, 0, 0, 0], + [0, 3, 0, 0, 0, 3, 0], + [0, 3, 0, 0, 0, 3, 0], + [0, 0, 3, 3, 3, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray([9, 5, 9]), # branch_mins + np.asarray( + [ + 9, + 5, + 9, + ] + ), # branch_medians : Used in _abs_thresh_idx with threshold of 9 + np.asarray([9, 5, 9]), # branch_middles + np.asarray([2]), # abs_thresh_idx : index + 1 in above noted array that are < 10 + np.asarray([2]), # mean_abs_thresh_idx + np.asarray([2]), # iqr_thresh_idx + False, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + id="figure 8", + # marks=pytest.mark.xfail(reason="Not sure middles are correct, arbitrarily takes point left or gith of even + # lengthed branches see region 1"), + ), + ], +) +class TestHeightPruningBasic: + """Tests for heightPruning() class using very basic shapes.""" + + # pylint: disable=unused-argument + # pylint: disable=too-many-locals + def topostats_height_pruner( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + ) -> None: + """Instantiate a topostatsPrune object.""" + return heightPruning( + img_skeleton["img"], + img_skeleton["skeleton"], + max_length, + height_threshold, + method_values, + method_outliers, + ) + + def test_convolve_skeleton( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test convolve_skeleton() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + np.testing.assert_array_equal(height_pruning.skeleton_convolved, convolved_skeleton_target) + + def test_segment_skeleton( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test segment_skeleton() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + np.testing.assert_array_equal(height_pruning.skeleton_branches, segmented_skeleton_target) + + def test_label_branches( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test label_branches() method of HeightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + np.testing.assert_array_equal(height_pruning.skeleton_branches_labelled, labelled_skeleton_target) + + def test_get_branch_mins( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of get_branch_mins() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + branch_mins = height_pruning._get_branch_mins(height_pruning.skeleton_branches_labelled) + np.testing.assert_array_equal(branch_mins, branch_mins_target) + + def test_get_branch_medians( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of get_branch_medians() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + branch_medians = height_pruning._get_branch_medians(height_pruning.skeleton_branches_labelled) + np.testing.assert_array_equal(branch_medians, branch_medians_target) + + def test_get_branch_middles( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """ + Test of get_branch_middles() method of heightPruning class. + + NB - Surprised that when the branch has an even number of points in it that the median height of the two middle + points is not returned. + """ + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + branch_middles = height_pruning._get_branch_middles(height_pruning.skeleton_branches_labelled) + np.testing.assert_array_equal(branch_middles, branch_middles_target) + + def test_get_abs_thresh_idx( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of get_abs_thresh_idx(self) method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + abs_thresh_idx = height_pruning._get_abs_thresh_idx(branch_medians_target, height_pruning.height_threshold) + np.testing.assert_array_equal(abs_thresh_idx, abs_thresh_idx_target) + + def test_get_mean_abs_thresh_idx( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of get_mean_abs_thresh_idx() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + mean_abs_thresh_idx = height_pruning._get_mean_abs_thresh_idx( + branch_medians_target, + height_pruning.height_threshold / 9, + height_pruning.image, + height_pruning.skeleton, + ) + np.testing.assert_array_equal(mean_abs_thresh_idx, mean_abs_thresh_idx_target) + + def test_get_iqr_thresh_idx( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of get_iqr_thresh_idx() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + iqr_thresh_idx = height_pruning._get_iqr_thresh_idx( + height_pruning.image, height_pruning.skeleton_branches_labelled + ) + np.testing.assert_array_equal(iqr_thresh_idx, iqr_thresh_idx_target) + + def test_check_skeleton_one_object( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of check_skeleton_one_object() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + check_skeleton_one_object = height_pruning.check_skeleton_one_object(height_pruning.skeleton_branches_labelled) + assert check_skeleton_one_object == check_skeleton_one_object_target + + # @pytest.mark.skip(reason="No actual test yet!") + # def test_filter_segments( + # self, + # img_skeleton: str, + # max_length: float, + # height_threshold: float, + # method_values: str, + # method_outliers: str, + # convolved_skeleton_target: npt.NDArray, + # segmented_skeleton_target: npt.NDArray, + # labelled_skeleton_target: npt.NDArray, + # branch_mins_target: npt.NDArray, + # branch_medians_target: npt.NDArray, + # branch_middles_target: npt.NDArray, + # abs_thresh_idx_target: npt.NDArray, + # mean_abs_thresh_idx_target: npt.NDArray, + # iqr_thresh_idx_target: npt.NDArray, + # check_skeleton_one_object_target: bool, + # remove_bridges_target: npt.NDArray, + # ) -> None: + # """Test of filter_segments() method of heightPruning class.""" + + def test_height_prune( + self, + img_skeleton: str, + max_length: float, + height_threshold: float, + method_values: str, + method_outliers: str, + convolved_skeleton_target: npt.NDArray, + segmented_skeleton_target: npt.NDArray, + labelled_skeleton_target: npt.NDArray, + branch_mins_target: npt.NDArray, + branch_medians_target: npt.NDArray, + branch_middles_target: npt.NDArray, + abs_thresh_idx_target: npt.NDArray, + mean_abs_thresh_idx_target: npt.NDArray, + iqr_thresh_idx_target: npt.NDArray, + check_skeleton_one_object_target: bool, + height_prune_target: npt.NDArray, + ) -> None: + """Test of remove_bridges() method of heightPruning class.""" + height_pruning = self.topostats_height_pruner( + img_skeleton, max_length, height_threshold, method_values, method_outliers + ) + height_prune = height_pruning.height_prune() + np.testing.assert_array_equal(height_prune, height_prune_target) + + +# @pytest.mark.parametrize( +# ("img_skeleton", "max_length", "height_threshold", "method_values", "method_outliers"), +# [pytest.param("skeleton_loop1", id="skeleton loop1")], +# ) +# class TestHeightPruningImages: +# """Tests for heightPruning() class using dummy skeletons and heights.""" + +# img_skeleton = request.getfixturevalue(img_skeleton) +# img = img_skeleton["height"] +# skeleton = img_skeleton["skeleton"] +# self.height_pruning = heightPruning(img, skeleton, height_threshold, method_values, method_outlier) + +# @pytest.mark.skip(reason="awaiting test development") +# def test_get_branch_mins(self) -> None: +# """Test of conve_get_branch_mins() method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_get_branch_medians(self) -> None: +# """Test of conv_get_branch_medians() method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_get_branch_middles(self) -> None: +# """Test of conv_get_branch_middles() method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_get_abs_thresh_idx(self) -> None: +# """Test of conv_get_abs_thresh_idx(self) method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_get_mean_abs_thresh_idx(self) -> None: +# """Test of conv_get_mean_abs_thresh_idx() method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_get_iqr_thresh_idx(self) -> None: +# """Test of conv_get_iqr_thresh_idx() method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_check_skeleton_one_object(self) -> None: +# """Test of conv_check_skeleton_one_object() method of heightPruning class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_remove_bridges(self) -> None: +# """Test of remove_bridges() method of heightPruning class.""" + + +@pytest.mark.parametrize( + ("skeleton_nodeless", "start", "max_length", "order_branch_target"), + [ + pytest.param( + np.asarray( + [ + [0, 0, 0], + [0, 2, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 0, 0], + ] + ), + (1, 1), + 9, + np.asarray( + [ + [1, 1], + [2, 1], + [3, 1], + [4, 1], + [5, 1], + ] + ), + id="single vertical branch", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 2, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 1, 0], + [0, 0, 0, 0, 0], + ] + ), + (1, 1), + 9, + np.asarray( + [ + [1, 1], + [2, 1], + [3, 1], + [4, 1], + [5, 2], + [5, 3], + ] + ), + id="single bent branch, start top left", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 2, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 1, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ), + (1, 3), + 9, + np.asarray( + [ + [1, 3], + [2, 3], + [3, 3], + [4, 3], + [5, 2], + [5, 1], + ] + ), + id="single bent branch, start top right", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 2, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ), + (5, 1), + 9, + np.asarray( + [ + [5, 1], + [5, 2], + [4, 3], + [3, 3], + [2, 3], + [1, 3], + ] + ), + id="single bent branch, start bottom left", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 2, 1, 0, 0], + [0, 0, 0, 0, 0], + ] + ), + (5, 1), + 9, + np.asarray( + [ + [5, 1], + [5, 2], + [4, 3], + [3, 3], + [2, 3], + [1, 2], + [2, 1], + ] + ), + id="single curled branch, start bottom right", + ), + ], +) +def test_order_branch_from_end( + skeleton_nodeless: npt.NDArray, start: list, max_length: int, order_branch_target: npt.NDArray +) -> None: + """Test of order_branch_from_start() function.""" + order_branch = order_branch_from_end(skeleton_nodeless, start, max_length) + np.testing.assert_array_equal(order_branch, order_branch_target) + + +# Tests for pruneSkeleton class +# @pytest.mark.parametrize( +# ("img_skeleton", "max_length", "height_threshold", "method_values", "method_outliers"), +# [pytest.param("skeleton_loop1", 10, 90, "min", "abs", id="skeleton loop1")], +# ) +# class TestPruneSkeleton: +# img_skeleton = request.getfixturevalue(img_skeleton) +# img = img_skeleton["height"] +# skeleton = img_skeleton["skeleton"] +# topostats_pruner = pruneSkeleton(img, skeleton, height_threshold, method_values, method_outlier) + +# @pytest.mark.skip(reason="awaiting test development") +# def test_prune_skeleton() -> None: +# """Test of prune_skeleton() method of pruneSkeleton class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_prune_method() -> None: +# """Test of prune_method() method of pruneSkeleton class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_prune_topostats() -> None: +# """Test of prune_topostats() method of pruneSkeleton class.""" + +# @pytest.mark.skip(reason="awaiting test development") +# def test_prune_conv() -> None: +# """Test of prune_conv() method of pruneSkeleton class.""" + + +@pytest.mark.parametrize( + ("img", "target"), + [ + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 1, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ), + id="linear with single 1-pixel branch", + ), + pytest.param( + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 1, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 1, 0, 1, 0], + [0, 1, 0, 0, 1, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 1, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 1, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + id="circular with single 1-pixel branch internally", + ), + ], +) +def test_rm_nibs(img: npt.NDArray, target: npt.NDArray) -> None: + """Test of rm_nibs() function.""" + clean_skeleton = rm_nibs(img) + np.testing.assert_array_equal(clean_skeleton, target) + + +@pytest.mark.parametrize( + ("img", "point", "local_pixels_target", "local_pixels_sum_target"), + [ + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [1, 1], + np.asarray([1, 1, 1, 1, 0, 1, 1, 1, 1]), + 8, + id="3x3 binary all 1's; point as list", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + (1, 1), + np.asarray([1, 1, 1, 1, 0, 1, 1, 1, 1]), + 8, + id="3x3 binary all 1's; point as tuple", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + np.asarray([1, 1]), + np.asarray([1, 1, 1, 1, 0, 1, 1, 1, 1]), + 8, + id="3x3 binary all 1's; point as npt.NDArray", + ), + pytest.param( + np.asarray([[0, 1, 0], [1, 1, 1], [0, 1, 0]]), + [1, 1], + np.asarray([0, 1, 0, 1, 0, 1, 0, 1, 0]), + 4, + id="3x3 binary 0's and 1's 50/50; point as list", + ), + pytest.param( + np.asarray([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), + [1, 1], + np.asarray([0, 1, 1, 1, 0, 1, 1, 1, 0]), + 6, + id="3x3 binary 0's and 1's 30/70; point as list", + ), + pytest.param( + np.asarray([[1, 2, 1], [1, 1, 1], [1, 1, 1]]), + [1, 1], + np.asarray([1, 2, 1, 1, 0, 1, 1, 1, 1]), + 9, + id="3x3 non-binary array; point as list", + marks=pytest.mark.xfail(reason="Array is not binary."), + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [0, 1], + np.asarray([0, 1, 1, 1, 0, 1, 1, 1, 0]), + 6, + id="3x3 1's; point on top edge", + marks=pytest.mark.xfail(reason="Point on top edge of image."), + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [1, 0], + np.asarray([0, 1, 1, 1, 0, 1, 1, 1, 0]), + 6, + id="3x3 1's; point on left edge", + marks=pytest.mark.xfail(reason="Point on left edge of image."), + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [1, 2], + np.asarray([0, 1, 1, 1, 0, 1, 1, 1, 0]), + 6, + id="3x3 1's; point on right edge", + marks=pytest.mark.xfail(reason="Point on right edge of image."), + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [2, 1], + np.asarray([0, 1, 1, 1, 0, 1, 1, 1, 0]), + 6, + id="3x3 1's; point on bottom edge", + marks=pytest.mark.xfail(reason="Point on bottom edge of image."), + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [2, 2], + np.asarray([0, 1, 1, 1, 0, 1, 1, 1, 0]), + 6, + id="3x3 1's; point on corner", + marks=pytest.mark.xfail(reason="Point on corner of image."), + ), + ], +) +def test_local_area_sum( + img: npt.NDArray, + point: list | tuple | npt.NDArray, + local_pixels_target: npt.NDArray, + local_pixels_sum_target: int, +) -> None: + """Test of local_area_sum() function.""" + local_pixels, local_pixels_sum = local_area_sum(img, point) + np.testing.assert_array_equal(local_pixels, local_pixels_target) + assert local_pixels_sum == local_pixels_sum_target + + +@pytest.mark.parametrize( + ("img", "point", "exception"), + [ + pytest.param( + np.asarray([[1, 2, 1], [1, 1, 1], [1, 1, 1]]), + [1, 1], + ValueError, + id="3x3 non-binary array in cells adjacent to point", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 2, 1], [1, 1, 1]]), + [1, 1], + ValueError, + id="3x3 non-binary array in point", + ), + ], +) +def test_local_area_sum_value_error(img: npt.NDArray, point: list | tuple | npt.NDArray, exception: ValueError) -> None: + """Test local_area_sum() function raises error if non-binary array is passed.""" + with pytest.raises(exception): + local_area_sum(img, point) + + +@pytest.mark.parametrize( + ("img", "point", "exception"), + [ + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [0, 1], + IndexError, + id="3x3 1's; point on top edge", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [1, 0], + IndexError, + id="3x3 1's; point on left edge", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [1, 2], + IndexError, + id="3x3 1's; point on right edge", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [2, 1], + IndexError, + id="3x3 1's; point on bottom edge", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [2, 2], + IndexError, + id="3x3 1's; point on top left corner", + ), + pytest.param( + np.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), + [2, 2], + IndexError, + id="3x3 1's; point on bottom right corner", + ), + ], +) +def test_local_area_sum_index_error(img: npt.NDArray, point: list | tuple | npt.NDArray, exception: IndexError) -> None: + """Test local_area_sum() function raises error if point is on edge of array.""" + with pytest.raises(exception): + local_area_sum(img, point) diff --git a/tests/tracing/test_skeletonize.py b/tests/tracing/test_skeletonize.py index d7512b60ac..6ec062cbf6 100644 --- a/tests/tracing/test_skeletonize.py +++ b/tests/tracing/test_skeletonize.py @@ -1,209 +1,660 @@ """Test the skeletonize module.""" import numpy as np +import numpy.typing as npt import pytest -# pytest: disable= +from topostats.tracing.skeletonize import getSkeleton, topostatsSkeletonize + +# pylint: disable=unnecessary-pass # pytest: disable=import-error -from topostats.tracing.skeletonize import get_skeleton +# pylint: disable=too-many-positional-arguments + + +def test_skeletonize_method(skeletonize_get_skeleton: getSkeleton) -> None: + """Test unsupported method raises the appropriate error.""" + skeletonize_get_skeleton.method = "nonsense" + with pytest.raises(ValueError): # noqa: PT011 + skeletonize_get_skeleton.get_skeleton() + -CIRCULAR_TARGET = np.array( +@pytest.mark.parametrize( + ("image", "mask", "method", "height_bias", "shape", "array_sum", "target"), [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "zhang", + 0.6, + (21, 21), + 36, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Zhang, circular", + ), + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "lee", + 0.6, + (21, 21), + 36, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Lee, circular", + ), + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "medial_axis", + 0.6, + (21, 21), + 48, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Medial axis, circular", + ), + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "thin", + 0.6, + (21, 21), + 28, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Thin, circular", + ), + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "topostats", + 0.6, + (21, 21), + 36, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="TopoStats, circular, height_bias 0.6", + ), + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "topostats", + 0.1, + (21, 21), + 36, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ), + id="TopoStats, circular, height_bias 0.1", + ), + pytest.param( + "skeletonize_circular", + "skeletonize_circular_bool_int", + "topostats", + 0.9, + (21, 21), + 36, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="TopoStats, circular, height_bias 0.9", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "zhang", + 0.6, + (24, 20), + 20, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Zhang, linear", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "lee", + 0.6, + (24, 20), + 17, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Lee, linear", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "medial_axis", + 0.6, + (24, 20), + 35, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Medial axis, linear", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "thin", + 0.6, + (24, 20), + 11, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="Thin, linear", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "topostats", + 0.6, + (24, 20), + 17, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="TopoStats, linear, height_bias 0.6", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "topostats", + 0.1, + (24, 20), + 13, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="TopoStats, linear, height_bias 0.1", + ), + pytest.param( + "skeletonize_linear", + "skeletonize_linear_bool_int", + "topostats", + 0.9, + (24, 20), + 14, + np.asarray( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="TopoStats, linear, height_bias 0.9", + ), + ], ) +def test_get_skeleton( # pylint: disable=too-many-arguments + skeletonize_get_skeleton: getSkeleton, + image: npt.NDArray, + mask: npt.NDArray, + method: str, + height_bias: float, + shape: tuple, + array_sum: int, + target: npt.NDArray, + request, +) -> None: + """Test the getSkeleton.get_skeleton() method.""" + skeletonize_get_skeleton.image = request.getfixturevalue(image) + skeletonize_get_skeleton.mask = request.getfixturevalue(mask) + skeletonize_get_skeleton.method = method + skeletonize_get_skeleton.height_bias = height_bias + skeleton = skeletonize_get_skeleton.get_skeleton() + assert isinstance(skeleton, np.ndarray) + assert skeleton.ndim == 2 + assert skeleton.shape == shape + assert skeleton.sum() == array_sum + np.testing.assert_array_equal(skeleton, target) -def test_skeletonize_method(skeletonize_circular_bool_int: np.ndarray) -> None: - """Test unsupported method raises the appropriate error.""" - with pytest.raises(ValueError): # noqa: PT011 - get_skeleton(skeletonize_circular_bool_int, method="nonsense") - - -def test_skeletonize_circular_zhang(skeletonize_circular_bool_int: np.ndarray) -> None: - """Test the Zhang method of skeletionzation on a circular object.""" - test = get_skeleton(skeletonize_circular_bool_int, method="zhang").astype(int) - assert isinstance(test, np.ndarray) - assert test.ndim == 2 - assert test.shape == (21, 21) - assert test.sum() == 36 - np.testing.assert_array_equal(test, CIRCULAR_TARGET) - - -def test_skeletonize_circular_lee(skeletonize_circular_bool_int: np.ndarray) -> None: - """Test the Lee method of skeletonization on a circular object.""" - test = get_skeleton(skeletonize_circular_bool_int, method="lee").astype(int) - assert isinstance(test, np.ndarray) - assert test.ndim == 2 - assert test.shape == (21, 21) - assert test.sum() == 36 - np.testing.assert_array_equal(test, CIRCULAR_TARGET) - - -def test_skeletonize_circular_thin(skeletonize_circular_bool_int: np.ndarray) -> None: - """Test the thin method of skeletonization on a circular object.""" - CIRCULAR_THIN_TARGET = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ) - test = get_skeleton(skeletonize_circular_bool_int, method="thin").astype(int) - assert isinstance(test, np.ndarray) - assert test.ndim == 2 - assert test.shape == (21, 21) - assert test.sum() == 28 - np.testing.assert_array_equal(test, CIRCULAR_THIN_TARGET) - - -def test_skeletonize_linear_zhang(skeletonize_linear_bool_int: np.ndarray) -> None: - """Test the Zhang method of skeletonization on a linear object.""" - LINEAR_ZHANG_TARGET = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ) - test = get_skeleton(skeletonize_linear_bool_int, method="zhang").astype(int) - assert isinstance(test, np.ndarray) - assert test.ndim == 2 - assert test.shape == (24, 20) - assert test.sum() == 20 - np.testing.assert_array_equal(test, LINEAR_ZHANG_TARGET) - - -def test_skeletonize_linear_lee(skeletonize_linear_bool_int: np.ndarray) -> None: - """Test the Lee method of skeletonization on a linear object.""" - LINEAR_LEE_TARGET = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ) - test = get_skeleton(skeletonize_linear_bool_int, method="lee").astype(int) - assert isinstance(test, np.ndarray) - assert test.ndim == 2 - assert test.shape == (24, 20) - assert test.sum() == 17 - np.testing.assert_array_equal(test, LINEAR_LEE_TARGET) - - -def test_skeletonize_linear_thin(skeletonize_linear_bool_int: np.ndarray) -> None: - """Test the thin method of skeletonization on a linear object.""" - LINEAR_THIN_TARGET = np.array( - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ) - test = get_skeleton(skeletonize_linear_bool_int, method="thin").astype(int) - assert isinstance(test, np.ndarray) - assert test.ndim == 2 - assert test.shape == (24, 20) - assert test.sum() == 11 - np.testing.assert_array_equal(test, LINEAR_THIN_TARGET) +# Tests for topopstatsSkeletonize class +def test_do_skeletonising_iteration() -> None: + """Test of do skeletonising iteration.""" + pass + + +def test_delete_pixel_subit1() -> None: + """Test of method.""" + pass + + +def test_delete_pixel_subit2() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_a() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_b_returncount() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_c() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_d() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_csharp() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_dsharp() -> None: + """Test of method.""" + pass + + +def test_final_skeletonisation_iteration() -> None: + """Test of method.""" + pass + + +def test_binary_final_thin_check_a() -> None: + """Test of method.""" + pass + + +def test_binary_final_thin_check_b() -> None: + """Test of method.""" + pass + + +def test_binary_thin_check_diag() -> None: + """Test of method.""" + pass + + +def test_get_local_pixels_binary() -> None: + """Test of method.""" + pass + + +def test_do_skeletonising() -> None: + """Test of method.""" + pass + + +# Miscellaneous functions +def test_order_branch_from_start() -> None: + """Test of order_branch_from_start().""" + pass + + +def test_rm_nibs() -> None: + """Test of rm_nibs().""" + pass + + +def test_local_area_sum() -> None: + """Test local_area_sum().""" + pass + + +@pytest.mark.parametrize( + ("array", "seed", "target_array", "target_indicies"), + [ + ( + np.array([1, 1, 1, 1, 3, 3, 2, 1]), + 2024, + np.array([1, 1, 1, 1, 1, 2, 3, 3]), + np.array([3, 7, 1, 2, 0, 6, 4, 5]), + ) + ], +) +def test_sort_and_shuffle( + topostats_skeletonise: topostatsSkeletonize, + array: npt.NDArray, + seed: int, + target_array: npt.NDArray, + target_indicies: npt.NDArray, +) -> None: + """Tests the topostatsSkeletonize.sort_and_shuffle() method.""" + sort_and_shuffled_array, sort_and_shuffled_indicies = topostats_skeletonise.sort_and_shuffle(array, seed) + np.testing.assert_array_equal(sort_and_shuffled_array, target_array) + np.testing.assert_array_equal(sort_and_shuffled_indicies, target_indicies) diff --git a/tests/tracing/test_splining.py b/tests/tracing/test_splining.py new file mode 100644 index 0000000000..9a6502be93 --- /dev/null +++ b/tests/tracing/test_splining.py @@ -0,0 +1,434 @@ +# Disable ruff 301 - pickle loading is unsafe, but we don't care for tests +# ruff: noqa: S301 +"""Test the splining module.""" + +import pickle +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest + +from topostats.io import dict_almost_equal +from topostats.tracing.splining import splineTrace, splining_image, windowTrace + +BASE_DIR = Path.cwd() +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" +SPLINING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "splining" +ORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "ordered_tracing" + +# pylint: disable=unspecified-encoding +# pylint: disable=too-many-locals +# pylint: disable=too-many-arguments +# pylint: disable=too-many-positional-arguments + +PIXEL_TRACE = np.array( + [[0, 0], [0, 1], [0, 2], [0, 3], [1, 3], [2, 3], [3, 3], [3, 2], [3, 1], [3, 0], [2, 0], [1, 0]] +).astype(np.int32) + + +def plot_spline_debugging( + image: npt.NDArray[np.float32], + result_all_splines_data: dict, + pixel_to_nm_scaling: float, +) -> None: + """ + Plot splines of an image overlaid on the image. + + Used for debugging changes to the splining code & visually ensuring the splines are correct. + + Parameters + ---------- + image : npt.NDArray[np.float32] + Image to plot the splines on. + result_all_splines_data : dict + Dictionary containing the spline coordinates for each molecule. + pixel_to_nm_scaling : float + Pixel to nm scaling factor. + """ + _, ax = plt.subplots(figsize=(10, 10)) + ax.imshow(image, cmap="gray") + # Array of lots of matplotlib colours + lots_of_colours = [ + "blue", + "green", + "red", + "cyan", + "magenta", + "yellow", + "black", + "white", + "orange", + "purple", + ] + for grain_key_index, grain_key in enumerate(result_all_splines_data.keys()): + print(f"Grain key: {grain_key}") + for mol_key_index, mol_key in enumerate(result_all_splines_data[grain_key].keys()): + spline_coords: npt.NDArray[np.float32] = result_all_splines_data[grain_key][mol_key]["spline_coords"] + bbox = result_all_splines_data[grain_key][mol_key]["bbox"] + bbox_min_col = bbox[0] + bbox_min_row = bbox[1] + previous_point = spline_coords[0] + colour = lots_of_colours[mol_key_index + grain_key_index * 3 % len(lots_of_colours)] + for point in spline_coords[1:]: + ax.plot( + [ + previous_point[1] / pixel_to_nm_scaling + bbox_min_row, + point[1] / pixel_to_nm_scaling + bbox_min_row, + ], + [ + previous_point[0] / pixel_to_nm_scaling + bbox_min_col, + point[0] / pixel_to_nm_scaling + bbox_min_col, + ], + color=colour, + linewidth=2, + ) + previous_point = point + plt.show() + + +@pytest.mark.parametrize( + ("tuple_list", "expected_result"), + [ + ( + [(1, 2, 3), (1, 2, 3), (1, 2, 3), (1, 2, 3)], + [(1, 2, 3)], + ), + ( + [(1, 2, 3), (1, 2, 3), (4, 5, 6), (4, 5, 6), (7, 8, 9), (10, 11, 12), (10, 11, 12)], + [(1, 2, 3), (4, 5, 6), (7, 8, 9), (10, 11, 12)], + ), + ([np.array((1, 2, 3)), np.array((1, 2, 3)), np.array((1, 2, 3)), np.array((1, 2, 3))], [(1, 2, 3)]), + ], +) +def test_remove_duplicate_consecutive_tuples(tuple_list: list[tuple], expected_result: list[tuple]) -> None: + """Test the remove_duplicate_consecutive_tuples function of splining.py.""" + result = splineTrace.remove_duplicate_consecutive_tuples(tuple_list) + + np.testing.assert_array_equal(result, expected_result) + + +@pytest.mark.parametrize( + ( + "image_filename", + "ordered_tracing_direction_data_filename", + "pixel_to_nm_scaling", + "splining_method", + "rolling_window_size", + "spline_step_size", + "spline_linear_smoothing", + "spline_circular_smoothing", + "spline_degree", + "filename", + "expected_all_splines_data_filename", + "expected_splining_grainstats_filename", + "expected_molstats_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + "catenanes_ordered_tracing_data.pkl", + 1.0, # pixel_to_nm_scaling + # Splining parameters + "rolling_window", # splining_method + 20e-9, # rolling_window_size + 7.0e-9, # spline_step_size + 5.0, # spline_linear_smoothing + 5.0, # spline_circular_smoothing + 3, # spline_degree + "catenane", # filename + "catenanes_splining_data.pkl", + "catenanes_splining_grainstats.csv", + "catenanes_splining_molstats.csv", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + "rep_int_ordered_tracing_data.pkl", + 1.0, # pixel_to_nm_scaling + # Splining parameters + "rolling_window", # splining_method + 20e-9, # rolling_window_size + 7.0e-9, # spline_step_size + 5.0, # spline_linear_smoothing + 5.0, # spline_circular_smoothing + 3, # spline_degree + "replication_intermediate", # filename + "rep_int_splining_data.pkl", + "rep_int_splining_grainstats.csv", + "rep_int_splining_molstats.csv", + id="replication_intermediate", + ), + ], +) +def test_splining_image( # pylint: disable=too-many-positional-arguments + image_filename: str, + ordered_tracing_direction_data_filename: str, + pixel_to_nm_scaling: float, + splining_method: str, + rolling_window_size: float, + spline_step_size: float, + spline_linear_smoothing: float, + spline_circular_smoothing: float, + spline_degree: int, + filename: str, + expected_all_splines_data_filename: str, + expected_splining_grainstats_filename: str, + expected_molstats_filename: str, +) -> None: + """Test the splining_image function of the splining module.""" + # Load the data + image = np.load(GENERAL_RESOURCES / image_filename) + + # Load the ordered tracing direction data + with Path.open(ORDERED_TRACING_RESOURCES / ordered_tracing_direction_data_filename, "rb") as file: + ordered_tracing_direction_data = pickle.load(file) + + result_all_splines_data, result_splining_grainstats, result_molstats_df = splining_image( + image=image, + ordered_tracing_direction_data=ordered_tracing_direction_data, + pixel_to_nm_scaling=pixel_to_nm_scaling, + filename=filename, + method=splining_method, + rolling_window_size=rolling_window_size, + spline_step_size=spline_step_size, + spline_linear_smoothing=spline_linear_smoothing, + spline_circular_smoothing=spline_circular_smoothing, + spline_degree=spline_degree, + ) + + # When updating the test, you will want to verify that the splines are correct. Use + # plot_spline_debugging to plot the splines on the image. + + # # Save the results to update the test data + # # Save result splining data as pickle + # with Path.open(SPLINING_RESOURCES / expected_all_splines_data_filename, "wb") as file: + # pickle.dump(result_all_splines_data, file) + + # # Save result grainstats additions as csv + # result_splining_grainstats.to_csv(SPLINING_RESOURCES / expected_splining_grainstats_filename) + + # # Save result molstats as csv + # result_molstats_df.to_csv(SPLINING_RESOURCES / expected_molstats_filename) + + # Load the expected results + with Path.open(SPLINING_RESOURCES / expected_all_splines_data_filename, "rb") as file: + expected_all_splines_data = pickle.load(file) + + expected_splining_grainstats = pd.read_csv(SPLINING_RESOURCES / expected_splining_grainstats_filename, index_col=0) + expected_molstats_df = pd.read_csv(SPLINING_RESOURCES / expected_molstats_filename, index_col=0) + + # Check the results + assert dict_almost_equal(result_all_splines_data, expected_all_splines_data) + pd.testing.assert_frame_equal(result_splining_grainstats, expected_splining_grainstats) + pd.testing.assert_frame_equal(result_molstats_df, expected_molstats_df) + + +@pytest.mark.parametrize( + ("pixel_trace", "rolling_window_size", "pixel_to_nm_scaling", "expected_pooled_trace"), + [ + pytest.param( + PIXEL_TRACE, + np.float64(1.0), + 1.0, + np.array( + [ + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [1.0, 3.0], + [2.0, 3.0], + [3.0, 3.0], + [3.0, 2.0], + [3.0, 1.0], + [3.0, 0.0], + [2.0, 0.0], + [1.0, 0.0], + [0.0, 0.0], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 1", + ), + pytest.param( + PIXEL_TRACE, + np.float64(2.0), + 1.0, + np.array( + [ + [0.0, 1.5], + [0.0, 2.5], + [0.5, 3.0], + [1.5, 3.0], + [2.5, 3.0], + [3.0, 2.5], + [3.0, 1.5], + [3.0, 0.5], + [2.5, 0.0], + [1.5, 0.0], + [0.5, 0.0], + [0.0, 0.5], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 2", + ), + pytest.param( + PIXEL_TRACE, + np.float64(5.5), + 1.0, + np.array( + [ + [1.0, 2.5], + [1.5, 2.666666666], + [2.0, 2.5], + [2.5, 2.0], + [2.666666666, 1.5], + [2.5, 1.0], + [2.0, 0.5], + [1.5, 0.333333333], + [1.0, 0.5], + [0.5, 1.0], + [0.333333333, 1.5], + [0.5, 2.0], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 5.5", + ), + pytest.param( + PIXEL_TRACE, + np.float64(2.0), + 0.5, + np.array( + [ + [0.25, 2.25], + [0.75, 2.75], + [1.5, 3.0], + [2.25, 2.75], + [2.75, 2.25], + [3.0, 1.5], + [2.75, 0.75], + [2.25, 0.25], + [1.5, 0.0], + [0.75, 0.25], + [0.25, 0.75], + [0.0, 1.5], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 2 and scaling 2", + ), + ], +) +def test_pool_trace_circular( + pixel_trace: npt.NDArray[np.int32], + rolling_window_size: np.float64, + pixel_to_nm_scaling: float, + expected_pooled_trace: npt.NDArray[np.float64], +) -> None: + """Test of the pool_trace_circular function of the windowTrace class.""" + result_pooled_trace = windowTrace.pool_trace_circular(pixel_trace, rolling_window_size, pixel_to_nm_scaling) + + np.testing.assert_allclose(result_pooled_trace, expected_pooled_trace, atol=1e-6) + + +@pytest.mark.parametrize( + ("pixel_trace", "rolling_window_size", "pixel_to_nm_scaling", "expected_pooled_trace"), + [ + pytest.param( + PIXEL_TRACE, + np.float64(1.0), + 1.0, + np.array( + [ + [0.0, 0.0], + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [1.0, 3.0], + [2.0, 3.0], + [3.0, 3.0], + [3.0, 2.0], + [3.0, 1.0], + [3.0, 0.0], + [2.0, 0.0], + [1.0, 0.0], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 1", + ), + pytest.param( + PIXEL_TRACE, + np.float64(2.0), + 1.0, + np.array( + [ + [0.0, 0.0], + [0.0, 0.5], + [0.0, 1.5], + [0.0, 2.5], + [0.5, 3.0], + [1.5, 3.0], + [2.5, 3.0], + [3.0, 2.5], + [3.0, 1.5], + [3.0, 0.5], + [2.5, 0.0], + [1.5, 0.0], + [1.0, 0.0], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 2", + ), + pytest.param( + PIXEL_TRACE, + np.float64(5.5), + 1.0, + np.array( + [ + [0.0, 0.0], + [0.5, 2.0], + [1.0, 2.5], + [1.5, 2.66666666666], + [2.0, 2.5], + [2.5, 2.0], + [2.666666666, 1.5], + [2.5, 1.0], + [1.0, 0.0], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 5.5", + ), + pytest.param( + PIXEL_TRACE, + np.float64(2.0), + 2.0, + np.array( + [ + [0.0, 0.0], + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [1.0, 3.0], + [2.0, 3.0], + [3.0, 3.0], + [3.0, 2.0], + [3.0, 1.0], + [3.0, 0.0], + [2.0, 0.0], + [1.0, 0.0], + ] + ).astype(np.float64), + id="4x4 box starting at 0, 0 with window size 2 and scaling 2", + ), + ], +) +def test_pool_trace_linear( + pixel_trace: npt.NDArray[np.int32], + rolling_window_size: np.float64, + pixel_to_nm_scaling: float, + expected_pooled_trace: npt.NDArray[np.float64], +) -> None: + """Test of the pool_trace_circular function of the windowTrace class.""" + result_pooled_trace = windowTrace.pool_trace_linear(pixel_trace, rolling_window_size, pixel_to_nm_scaling) + + np.testing.assert_allclose(result_pooled_trace, expected_pooled_trace, atol=1e-6) diff --git a/tests/tracing/test_tracing_dna.py b/tests/tracing/test_tracing_dna.py index e69de29bb2..574ca14a1b 100644 --- a/tests/tracing/test_tracing_dna.py +++ b/tests/tracing/test_tracing_dna.py @@ -0,0 +1 @@ +"""Dummy file since CI ruff seems to think this file exists but it doesn't.""" diff --git a/topostats/default_config.yaml b/topostats/default_config.yaml index 971455dae1..358ae7633e 100644 --- a/topostats/default_config.yaml +++ b/topostats/default_config.yaml @@ -1,7 +1,10 @@ +# Config file generated 2024-07-29 11:11:05 +# # For more information on configuration and how to use it: +# https://afm-spm.github.io/TopoStats/main/configuration.html base_dir: ./ # Directory in which to search for data files output_dir: ./output # Directory to output results to log_level: info # Verbosity of output. Options: warning, error, info, debug -cores: 2 # Number of CPU cores to utilise for processing multiple files simultaneously. +cores: 1 # Number of CPU cores to utilise for processing multiple files simultaneously. file_ext: .spm # File extension of the data files. loading: channel: Height # Channel to pull data from in the data files. @@ -53,22 +56,50 @@ grains: grainstats: run: true # Options : true, false edge_detection_method: binary_erosion # Options: canny, binary erosion. Do not change this unless you are sure of what this will do. - cropped_size: 40.0 # Length (in nm) of square cropped images (can take -1 for grain-sized box) + cropped_size: -1 # Length (in nm) of square cropped images (can take -1 for grain-sized box) extract_height_profile: true # Extract height profiles along maximum feret of molecules -dnatracing: +disordered_tracing: run: true # Options : true, false min_skeleton_size: 10 # Minimum number of pixels in a skeleton for it to be retained. - skeletonisation_method: topostats # Options : zhang | lee | thin | topostats + pad_width: 1 # Pixels to pad grains by when tracing + mask_smoothing_params: + gaussian_sigma: 2 # Gaussian smoothing parameter 'sigma' in pixels. + dilation_iterations: 2 # Number of dilation iterations to use for grain smoothing. + holearea_min_max: [0, null] # Range (min, max) of a hole area in nm to refil in the smoothed masks. + skeletonisation_params: + method: topostats # Options : zhang | lee | thin | topostats + height_bias: 0.6 # Percentage of lowest pixels to remove each skeletonisation iteration. 1 equates to zhang. + pruning_params: + method: topostats # Method to clean branches of the skeleton. Options : topostats + max_length: 10.0 # Maximum length in nm to remove a branch containing an endpoint. + height_threshold: # The height to remove branches below. + method_values: mid # The method to obtain a branch's height for pruning. Options : min | median | mid. + method_outlier: mean_abs # The method to prune branches based on height. Options : abs | mean_abs | iqr. +nodestats: + run: true # Options : true, false + node_joining_length: 7.0 # The distance over which to join nearby crossing points. + node_extend_dist: 14.0 # The distance over which to join nearby odd-branched nodes. + branch_pairing_length: 20.0 # The length from the crossing point to pair and trace, obtaining FWHM's. + pair_odd_branches: false # Whether to try and pair odd-branched nodes. Options: true and false. + pad_width: 1 # Pixels to pad grains by when tracing (should be the same as disordered_tracing). +ordered_tracing: + run: true + ordering_method: nodestats # The method of ordering the disordered traces. + pad_width: 1 # Pixels to pad grains by when tracing (should be the same as disordered_tracing). +splining: + run: true # Options : true, false + method: "rolling_window" # Options : "spline", "rolling_window" + rolling_window_size: 20.0e-9 # size in nm of the rolling window. spline_step_size: 7.0e-9 # The sampling rate of the spline in metres. spline_linear_smoothing: 5.0 # The amount of smoothing to apply to linear splines. - spline_circular_smoothing: 0.0 # The amount of smoothing to apply to circular splines. - pad_width: 1 # Cells to pad grains by when tracing + spline_circular_smoothing: 5.0 # The amount of smoothing to apply to circular splines. + spline_degree: 3 # The polynomial degree of the spline. # cores: 1 # Number of cores to use for parallel processing plotting: run: true # Options : true, false style: topostats.mplstyle # Options : topostats.mplstyle or path to a matplotlibrc params file savefig_format: null # Options : null, png, svg or pdf. tif is also available although no metadata will be saved. (defaults to png) See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html - savefig_dpi: null # Options : null (defaults to the value in topostats/plotting_dictionary.yaml), see https://afm-spm.github.io/TopoStats/main/configuration.html#further-customisation and https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html + savefig_dpi: 100 # Options : null (defaults to the value in topostats/plotting_dictionary.yaml), see https://afm-spm.github.io/TopoStats/main/configuration.html#further-customisation and https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html pixel_interpolation: null # Options : https://matplotlib.org/stable/gallery/images_contours_and_fields/interpolation_methods.html image_set: core # Options : all, core zrange: [null, null] # low and high height range for core images (can take [null, null]). low <= high @@ -76,7 +107,7 @@ plotting: axes: true # Options : true, false (due to off being a bool when parsed) num_ticks: [null, null] # Number of ticks to have along the x and y axes. Options : null (auto) or integer > 1 cmap: null # Colormap/colourmap to use (default is 'nanoscope' which is used if null, other options are 'afmhot', 'viridis' etc.) - mask_cmap: blu # Options : blu, jet_r and any in matplotlib + mask_cmap: blue_purple_green # Options : blu, jet_r and any in matplotlib histogram_log_axis: false # Options : true, false summary_stats: run: true # Whether to make summary plots for output data diff --git a/topostats/entry_point.py b/topostats/entry_point.py index 6e65f19c9a..3a01780e0e 100644 --- a/topostats/entry_point.py +++ b/topostats/entry_point.py @@ -255,20 +255,6 @@ def create_parser() -> arg.ArgumentParser: help="Path to a YAML configuration file.", ) - dnatracing_parser = subparsers.add_parser( - "dnatracing", - description="Load images with grains from '.topostats' files and trace DNA molecules.", - help="Load images with grains from '.topostats' files and trace DNA molecules.", - ) - dnatracing_parser.add_argument( - "-c", - "--config-file", - dest="config_file", - type=Path, - required=False, - help="Path to a YAML configuration file.", - ) - tracingstats_parser = subparsers.add_parser( "tracingstats", description="Load images with grains from '.topostats' files and trace DNA molecules.", diff --git a/topostats/grains.py b/topostats/grains.py index e2302a798b..46bd08bba2 100644 --- a/topostats/grains.py +++ b/topostats/grains.py @@ -166,6 +166,11 @@ def __init__( self.grainstats = None self.unet_config = unet_config + # Hardcoded minimum pixel size for grains. This should not be able to be changed by the user as this is + # determined by what is processable by the rest of the pipeline. + self.minimum_grain_size_px = 10 + self.minimum_bbox_size_px = 5 + def tidy_border(self, image: npt.NDArray, **kwargs) -> npt.NDArray: """ Remove grains touching the border. @@ -295,6 +300,41 @@ def remove_small_objects(self, image: np.array, **kwargs) -> npt.NDArray: return small_objects_removed > 0.0 return image + def remove_objects_too_small_to_process( + self, image: npt.NDArray, minimum_size_px: int, minimum_bbox_size_px: int + ) -> npt.NDArray[np.bool_]: + """ + Remove objects whose dimensions in pixels are too small to process. + + Parameters + ---------- + image : npt.NDArray + 2-D Numpy array of image. + minimum_size_px : int + Minimum number of pixels for an object. + minimum_bbox_size_px : int + Limit for the minimum dimension of an object in pixels. Eg: 5 means the object's bounding box must be at + least 5x5. + + Returns + ------- + npt.NDArray + 2-D Numpy array of image with objects removed that are too small to process. + """ + labelled_image = label(image) + region_properties = self.get_region_properties(labelled_image) + for region in region_properties: + # If the number of true pixels in the region is less than the minimum number of pixels, remove the region + if region.area < minimum_size_px: + labelled_image[labelled_image == region.label] = 0 + bbox_width = region.bbox[2] - region.bbox[0] + bbox_height = region.bbox[3] - region.bbox[1] + # If the minimum dimension of the bounding box is less than the minimum dimension, remove the region + if min(bbox_width, bbox_height) < minimum_bbox_size_px: + labelled_image[labelled_image == region.label] = 0 + + return labelled_image.astype(bool) + def area_thresholding(self, image: npt.NDArray, area_thresholds: tuple) -> npt.NDArray: """ Remove objects larger and smaller than the specified thresholds. @@ -440,8 +480,15 @@ def find_grains(self): self.directions[direction]["removed_noise"], self.absolute_area_threshold[direction], ) + self.directions[direction]["removed_objects_too_small_to_process"] = ( + self.remove_objects_too_small_to_process( + image=self.directions[direction]["removed_small_objects"], + minimum_size_px=self.minimum_grain_size_px, + minimum_bbox_size_px=self.minimum_bbox_size_px, + ) + ) self.directions[direction]["labelled_regions_02"] = self.label_regions( - self.directions[direction]["removed_small_objects"] + self.directions[direction]["removed_objects_too_small_to_process"] ) self.region_properties[direction] = self.get_region_properties( diff --git a/topostats/grainstats.py b/topostats/grainstats.py index 23ecd4ff28..8644bdf7e7 100644 --- a/topostats/grainstats.py +++ b/topostats/grainstats.py @@ -37,7 +37,7 @@ GRAIN_STATS_COLUMNS = [ - "molecule_number", + "grain_number", "centre_x", "centre_y", "radius_min", @@ -344,7 +344,7 @@ def calculate_stats(self) -> tuple(pd.DataFrame, dict): grainstats_df = pd.DataFrame(data=stats_array) else: grainstats_df = create_empty_dataframe() - grainstats_df.index.name = "molecule_number" + grainstats_df.index.name = "grain_number" grainstats_df["image"] = self.image_name return grainstats_df, grains_plot_data, all_height_profiles diff --git a/topostats/io.py b/topostats/io.py index 06b3a669e8..a26c10a778 100644 --- a/topostats/io.py +++ b/topostats/io.py @@ -37,6 +37,54 @@ # pylint: disable=too-many-lines +# Sylvia: Ruff says too complex but I think breaking this out would be more complex. +def dict_almost_equal(dict1: dict, dict2: dict, abs_tol: float = 1e-9): # noqa: C901 + """ + Recursively check if two dictionaries are almost equal with a given absolute tolerance. + + Parameters + ---------- + dict1 : dict + First dictionary to compare. + dict2 : dict + Second dictionary to compare. + abs_tol : float + Absolute tolerance to check for equality. + + Returns + ------- + bool + True if the dictionaries are almost equal, False otherwise. + """ + if dict1.keys() != dict2.keys(): + return False + + LOGGER.info("Comparing dictionaries") + + for key in dict1: + LOGGER.info(f"Comparing key {key}") + if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + if not dict_almost_equal(dict1[key], dict2[key], abs_tol=abs_tol): + return False + elif isinstance(dict1[key], np.ndarray) and isinstance(dict2[key], np.ndarray): + if not np.allclose(dict1[key], dict2[key], atol=abs_tol): + LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") + return False + elif isinstance(dict1[key], float) and isinstance(dict2[key], float): + # Skip if both values are NaN + if not (np.isnan(dict1[key]) and np.isnan(dict2[key])): + # Check if both values are close + if not np.isclose(dict1[key], dict2[key], atol=abs_tol): + LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") + return False + + elif dict1[key] != dict2[key]: + LOGGER.info(f"Key {key} not equal: {dict1[key]} != {dict2[key]}") + return False + + return True + + def read_yaml(filename: str | Path) -> dict: """ Read a YAML file. @@ -1104,7 +1152,7 @@ def dict_to_hdf5(open_hdf5_file: h5py.File, group_path: str, dictionary: dict) - A dictionary of the data to save. """ for key, item in dictionary.items(): - LOGGER.debug(f"Saving key: {key}") + # LOGGER.info(f"Saving key: {key}") if item is None: LOGGER.warning(f"Item '{key}' is None. Skipping.") @@ -1135,7 +1183,7 @@ def dict_to_hdf5(open_hdf5_file: h5py.File, group_path: str, dictionary: dict) - try: open_hdf5_file[group_path + key] = item except Exception as e: - LOGGER.info(f"Cannot save key '{key}' to HDF5. Item type: {type(item)}. Skipping. {e}") + LOGGER.warning(f"Cannot save key '{key}' to HDF5. Item type: {type(item)}. Skipping. {e}") def hdf5_to_dict(open_hdf5_file: h5py.File, group_path: str) -> dict: diff --git a/topostats/measure/geometry.py b/topostats/measure/geometry.py new file mode 100644 index 0000000000..43a0117fed --- /dev/null +++ b/topostats/measure/geometry.py @@ -0,0 +1,334 @@ +"""Functions for measuring geometric properties of grains.""" + +from __future__ import annotations + +import math + +import networkx +import numpy as np +import numpy.typing as npt + + +def bounding_box_cartesian_points_float( + points: npt.NDArray[np.number], +) -> tuple[np.float64, np.float64, np.float64, np.float64]: + """ + Calculate the bounding box from a set of points. + + Parameters + ---------- + points : npt.NDArray[np.number] + Nx2 numpy array of points. + + Returns + ------- + Tuple[np.float64, np.float64, np.float64, np.float64] + Tuple of (min_x, min_y, max_x, max_y). + + Raises + ------ + ValueError + If the input array is not Nx2. + """ + if points.shape[1] != 2: + raise ValueError("Input array must be Nx2.") + x_coords, y_coords = points[:, 0].astype(np.float64), points[:, 1].astype(np.float64) + return (np.min(x_coords), np.min(y_coords), np.max(x_coords), np.max(y_coords)) + + +def bounding_box_cartesian_points_integer( + points: npt.NDArray[np.number], +) -> tuple[np.int32, np.int32, np.int32, np.int32]: + """ + Calculate the bounding box from a set of points. + + Parameters + ---------- + points : npt.NDArray[np.number] + Nx2 numpy array of points. + + Returns + ------- + Tuple[np.int32, np.int32, np.int32, np.int32] + Tuple of (min_x, min_y, max_x, max_y). + + Raises + ------ + ValueError + If the input array is not Nx2. + """ + if points.shape[1] != 2: + raise ValueError("Input array must be Nx2.") + x_coords, y_coords = points[:, 0].astype(np.int32), points[:, 1].astype(np.int32) + return (np.min(x_coords), np.min(y_coords), np.max(x_coords), np.max(y_coords)) + + +def do_points_in_arrays_touch( + points1: npt.NDArray[np.int32], points2: npt.NDArray[np.int32] +) -> tuple[bool, npt.NDArray[np.int32] | None, npt.NDArray[np.int32] | None]: + """ + Check if any points in two arrays are touching. + + Parameters + ---------- + points1 : npt.NDArray[np.int32] + Nx2 numpy array of points. + points2 : npt.NDArray[np.int32] + Mx2 numpy array of points. + + Returns + ------- + tuple[bool, npt.NDArray[np.int32] | None, npt.NDArray[np.int32] | None] + True if any points in the two arrays are touching, False otherwise, followed by the first touching point pair + that was found. If no points are touching, the second and third elements of the tuple will be None. + + Raises + ------ + ValueError + If the input arrays are not Nx2 and Mx2. + """ + if points1.shape[1] != 2 or points2.shape[1] != 2: + raise ValueError("Input arrays must be Nx2 and Mx2.") + + for point1 in points1: + for point2 in points2: + diff = np.abs(point1 - point2) + if np.all(diff <= 1): + return (True, point1, point2) + return (False, None, None) + + +# pylint: disable=too-many-locals +def calculate_shortest_branch_distances( + nodes_with_branch_starting_coords: dict[int, list[npt.NDArray[np.int32]]], + whole_skeleton_graph: networkx.classes.graph.Graph, +) -> tuple[npt.NDArray[np.number], npt.NDArray[np.int32], npt.NDArray[np.number]]: + """ + Calculate the shortest distances between branches emanating from nodes. + + Parameters + ---------- + nodes_with_branch_starting_coords : dict[int, list[npt.NDArray[np.int32]]] + Dictionary where the key is the node number and the value is an Nx2 numpy array of the starting coordinates + of its branches. + whole_skeleton_graph : networkx.classes.graph.Graph + Networkx graph representing the whole network. + + Returns + ------- + Tuple[npt.NDArray[np.number], npt.NDArray[np.int32], npt.NDArray[np.int32]] + - NxN numpy array of shortest distances between every node pair. Indexes of this array represent the nodes. + Eg for a 3x3 matrix, there are 3 nodes being compared with each other. + This matrix is diagonally symmetric and the diagonal values are 0 since a node is always 0 distance from itself. + - NxNx2 numpy array of indexes of the best branches to connect between each node pair. + Eg for node 1 and 3, the closest branches might be indexes 2 and 4, so the value at [1, 3] would be [2, 4]. + - NxNx2x2 numpy array of the coordinates of the branches to connect between each node pair. + Eg for node 1 and 3, the closest branches might be at coordinates [2, 3] and [4, 5], so the + value at [1, 3] would be [[2, 3], [4, 5]]. + """ + num_nodes = len(nodes_with_branch_starting_coords) + shortest_node_distances = np.zeros((num_nodes, num_nodes), dtype=np.float64) + # For storing the indexes of the branches that are the best candidate between two nodes. + # Eg: [[[0, 0], [1, 2]], [[1, 2], [0, 0]]] means that node 0's branch 0 connects with node 1's branch 2. + # Note that this matrix is symmetric about the diagonal as we double-iterate between all nodes. + shortest_distances_branch_indexes = np.zeros((num_nodes, num_nodes, 2), dtype=np.int32) + shortest_distances_branch_coordinates = np.zeros((num_nodes, num_nodes, 2, 2), dtype=object) + + # Iterate over the nodes twice to compare each combination of nodes. This double counts, so will create a symmetric + # matrix about the diagonal. + for node_index_i, (_node_i, node_branches_starts_coords_i) in enumerate(nodes_with_branch_starting_coords.items()): + for node_index_j, (_node_j, node_branches_starts_coords_j) in enumerate( + nodes_with_branch_starting_coords.items() + ): + # Don't compare the same node to itself + if node_index_i == node_index_j: + continue + # Store the shortest distance as we iterate. + shortest_distance = None + # For storing the pair of branch indexes that are the best candidate between the two nodes. + # Eg: (3, 2) means that node i's branch 3 connects with node j's branch 2. + shortest_distance_branch_indexes: tuple[int, int] | None = None + # Iteratively compare all branches from node1 to all branches from node2 + # to find the shortest distance between any two branches + for branch_index_i, branch_start_i in enumerate(node_branches_starts_coords_i): + for branch_index_j, branch_start_j in enumerate(node_branches_starts_coords_j): + shortest_path_length_between_branch_i_and_j = networkx.shortest_path_length( + whole_skeleton_graph, tuple(branch_start_i), tuple(branch_start_j) + ) + if shortest_distance is None or shortest_path_length_between_branch_i_and_j < shortest_distance: + shortest_distance = shortest_path_length_between_branch_i_and_j + # Store the indexes of the branches that are the shortest distance apart for node i and node j + shortest_distance_branch_indexes = (branch_index_i, branch_index_j) + + # Store the shortest distance between the two nodes + shortest_node_distances[node_index_i, node_index_j] = shortest_distance + # Store the indexes of the branches that are the shortest distance apart for node i and node j. + # Note this may be None as the nodes may not be connected? + shortest_distances_branch_indexes[node_index_i, node_index_j] = shortest_distance_branch_indexes + # Ensure that the nodes are connected before storing the coordinates of the branches starting coords + if shortest_distance_branch_indexes is not None: + # Add the coordinates of the branch pairs for each node-node combination. So for example, for + # node 0 and node 1, branches starting [6, 1] and [6, 11]: + # np.array([ [ [[0, 0][0, 0]] [[6 1][6 11]] ] [ [[6 11][6 1]] [[0, 0][0, 0]] ] ]) + # Where the square [0 0][0 0]s are for node 0 / node 0 and node 1 / node 1. + # And [6 1][6 11] is for node 0 / node 1, indicating that branches start at [6, 1] and [6, 11]. + shortest_distances_branch_coordinates[node_index_i, node_index_j] = ( + node_branches_starts_coords_i[shortest_distance_branch_indexes[0]], + node_branches_starts_coords_j[shortest_distance_branch_indexes[1]], + ) + else: + shortest_distances_branch_coordinates[node_index_i, node_index_j] = (None, None) + + return shortest_node_distances, shortest_distances_branch_indexes, shortest_distances_branch_coordinates + + +def connect_best_matches( + network_array_representation: npt.NDArray[np.int32], + whole_skeleton_graph: networkx.classes.graph.Graph, + match_indexes: npt.NDArray[np.int32], + shortest_distances_between_nodes: npt.NDArray[np.number], + shortest_distances_branch_indexes: npt.NDArray[np.int32], + emanating_branch_starts_by_node: dict[int, list[npt.NDArray[np.int32]]], + extend_distance: float = -1, +) -> npt.NDArray[np.int32]: + """ + Connect the branches between node pairs that have been deemed to be best matches. + + Parameters + ---------- + network_array_representation : npt.NDArray[np.int32] + 2D numpy array representing the network using integers to represent branches, nodes etc. + whole_skeleton_graph : networkx.classes.graph.Graph + Networkx graph representing the whole network. + match_indexes : npt.NDArray[np.int32] + Nx2 numpy array of indexes of the best matching nodes. + Eg: np.array([[1, 0], [2, 3]]) means that the best matching nodes are node 1 and node 0, and node 2 and node 3. + shortest_distances_between_nodes : npt.NDArray[np.number] + NxN numpy array of shortest distances between every node pair. + Index positions indicate which node it's referring to, so index 2, 3 will be the shortest distance + between nodes 2 and 3. + Values on the diagonal will be 0 because the shortest distance between a node and itself is 0. + Eg: np.array([[0.0, 6.0], [6.0, 0.0]]) means that the shortest distance between node 0 and node 1 is 6.0. + shortest_distances_branch_indexes : npt.NDArray[np.int32] + NxNx2 numpy array of indexes of the branches to connect between the best matching nodes. + Not entirely sure what it does so won't attempt to explain more to avoid confusion. + emanating_branch_starts_by_node : dict[int, list[npt.NDArray[np.int32]]] + Dictionary where the key is the node number and the value is an Nx2 numpy array of the starting coordinates + of the branches emanating from that node. Rather self-explanatory. + Eg: + ```python + { + 0: [np.array([6, 1]), np.array([7, 3]), np.array([8, 1])], + 1: [np.array([6, 11]), np.array([7, 9]), np.array([8, 11])], + }, + ```. + extend_distance : float + The distance to extend the branches to connect. If the shortest distance between two nodes is less than or equal + to this distance, the branches will be connected. If -1, the branches will be connected regardless of distance. + + Returns + ------- + npt.NDArray[np.int32] + 2D numpy array representing the network using integers to represent branches, nodes etc. + """ + for node_pair_index in match_indexes: + # Fetch the shortest distance between two nodes + shortest_distance = shortest_distances_between_nodes[node_pair_index[0], node_pair_index[1]] + if shortest_distance <= extend_distance or extend_distance == -1: + # Fetch the indexes of the branches to connect defined by the closest connecting + # branchees of the given nodes + indexes_of_branches_to_connect = shortest_distances_branch_indexes[node_pair_index[0], node_pair_index[1]] + node_numbers = list(emanating_branch_starts_by_node.keys()) + # Grab the coordinate of the starting branch position of the branch to connect + source = tuple( + emanating_branch_starts_by_node[node_numbers[node_pair_index[0]]][indexes_of_branches_to_connect[0]] + ) + # Grab the coordinate of the branch position to connect to + target = tuple( + emanating_branch_starts_by_node[node_numbers[node_pair_index[1]]][indexes_of_branches_to_connect[1]] + ) + # Get the path between the two branches using networkx + path = np.array(networkx.shortest_path(whole_skeleton_graph, source, target)) + # Set all the coordinates in the path to 3 in the network array representation + network_array_representation[path[:, 0], path[:, 1]] = 3 + + return network_array_representation + + +# pylint: disable=too-many-locals +def find_branches_for_nodes( + network_array_representation: npt.NDArray[np.int32], + labelled_nodes: npt.NDArray[np.int32], + labelled_branches: npt.NDArray[np.int32], +) -> dict[int, list[npt.NDArray[np.int32]]]: + """ + Locate branch starting positions for each node in a network. + + Parameters + ---------- + network_array_representation : npt.NDArray[np.int32] + 2D numpy array representing the network using integers to represent branches, nodes etc. + labelled_nodes : npt.NDArray[np.int32] + 2D numpy array representing the network using integers to represent nodes. + labelled_branches : npt.NDArray[np.int32] + 2D numpy array representing the network using integers to represent branches. + + Returns + ------- + dict[int, list[npt.NDArray[np.int32]]] + Dictionary where the key is the node number and the value is an Nx2 numpy array of the starting coordinates + of the branches emanating from that node. + """ + # Dictionary to store emanating branches for each labelled node + emanating_branch_starts_by_node = {} + + # Iterate over all the nodes in the labelled nodes image + for node_num in range(1, labelled_nodes.max() + 1): + num_branches = 0 + # makes lil box around node with 1 overflow + bounding_box = bounding_box_cartesian_points_integer(np.argwhere(labelled_nodes == node_num)) + crop_left = bounding_box[0] - 1 + crop_right = bounding_box[2] + 2 + crop_top = bounding_box[1] - 1 + crop_bottom = bounding_box[3] + 2 + cropped_matrix = network_array_representation[crop_left:crop_right, crop_top:crop_bottom] + # get coords of nodes and branches in box + node_coords = np.argwhere(cropped_matrix == 3) + branch_coords = np.argwhere(cropped_matrix == 1) + # iterate through node coords to see which are within 8 dirs + for node_coord in node_coords: + for branch_coord in branch_coords: + distance = math.dist(node_coord, branch_coord) + if distance <= math.sqrt(2): + num_branches = num_branches + 1 + + # All nodes with even branches are considered to be complete as they have one + # strand going in for each coming out. This assumes that no strands naturally terminate at nodes. + + # find the branch start point of odd branched nodes + if num_branches % 2 == 1: + emanating_branches: list[npt.NDArray[np.int32]] = ( + [] + ) # List to store emanating branches for the current label + for branch in range(1, labelled_branches.max() + 1): + # technically using labelled_branches when there's an end loop will only cause one + # of the end loop coords to be captured. This shopuldn't matter as the other + # label after the crossing should be closer to another node. + # The touching_point_1 and touching_point_2 can be None since the function returns None for both + # if no points touch. + touching, touching_point_1, _touching_point_2 = do_points_in_arrays_touch( + np.argwhere(labelled_branches == branch), + np.argwhere(labelled_nodes == node_num), + ) + if touching: + assert touching_point_1 is not None + # Above required for mypy to ensure that there are no Nones + # in the list, to prevent the return type being npt.NDArray[np.int32 | None] + emanating_branches.append(touching_point_1) + + assert len(emanating_branches) > 0, f"No branches found for node {node_num}" + emanating_branch_starts_by_node[node_num - 1] = ( + emanating_branches # Store emanating branches for this label + ) + + return emanating_branch_starts_by_node diff --git a/topostats/plotting.py b/topostats/plotting.py index 5088dcc61b..2487dc029b 100644 --- a/topostats/plotting.py +++ b/topostats/plotting.py @@ -10,6 +10,7 @@ import sys import yaml import matplotlib.pyplot as plt +import matplotlib.colors import numpy as np import numpy.typing as npt import pandas as pd @@ -18,6 +19,7 @@ from topostats.io import read_yaml, write_yaml, convert_basename_to_relative_paths from topostats.logs.logs import LOGGER_NAME from topostats.utils import update_config +from topostats.theme import Colormap LOGGER = logging.getLogger(LOGGER_NAME) @@ -77,7 +79,7 @@ def __init__( base_dir: str | Path = None, csv_file: str | Path = None, stat_to_sum: str = None, - molecule_id: str = "molecule_number", + molecule_id: str = "grain_number", image_id: str = "image", hist: bool = True, stat: str = "count", @@ -305,7 +307,7 @@ def melt_data(df: pd.DataFrame, stat_to_summarize: str, var_to_label: dict) -> p pd.DataFrame Data in long-format with descriptive variable names. """ - melted_data = pd.melt(df.reset_index(), id_vars=["molecule_number", "basename"], value_vars=stat_to_summarize) + melted_data = pd.melt(df.reset_index(), id_vars=["grain_number", "basename"], value_vars=stat_to_summarize) melted_data["variable"] = melted_data["variable"].map(var_to_label) LOGGER.info("[plotting] Data has been melted to long format for plotting.") @@ -471,6 +473,71 @@ def run_toposum(args=None) -> None: toposum(config) +def plot_crossing_linetrace_halfmax( + branch_stats_dict: dict, mask_cmap: matplotlib.colors.Colormap, title: str +) -> tuple: + """ + Plot the height-map line traces of the branches found in the 'branch_stats' dictionary, and their meetings. + + Parameters + ---------- + branch_stats_dict : dict + Dictionary containing branch height, distance and fwhm info. + mask_cmap : matplotlib.colors.Colormap + Colormap for plotting. + title : str + Title for the plot. + + Returns + ------- + fig, ax + Matplotlib fig and ax objects. + """ + fig, ax = plt.subplots(1, 1, figsize=(7, 4)) + cmp = Colormap(mask_cmap).get_cmap() + total_branches = len(branch_stats_dict) + # plot the highest first + fwhms = [] + for branch_idx, values in branch_stats_dict.items(): + fwhms.append(values["fwhm"]["fwhm"]) + branch_idx_order = np.array(list(branch_stats_dict.keys()))[np.argsort(np.array(fwhms))] + + for i, branch_idx in enumerate(branch_idx_order): + fwhm_dict = branch_stats_dict[branch_idx]["fwhm"] + if total_branches == 1: + cmap_ratio = 0 + else: + cmap_ratio = i / (total_branches - 1) + heights = branch_stats_dict[branch_idx]["heights"] + x = branch_stats_dict[branch_idx]["distances"] + ax.plot(x, heights, c=cmp(cmap_ratio)) # label=f"Branch: {branch_idx}" + + # plot the high point lines + plt.plot( + [-15, fwhm_dict["peaks"][1]], + [fwhm_dict["peaks"][2], fwhm_dict["peaks"][2]], + c=cmp(cmap_ratio), + label=f"FWHM: {fwhm_dict['fwhm']:.4f}", + ) + # plot the half max lines + plt.plot( + [fwhm_dict["half_maxs"][0], fwhm_dict["half_maxs"][0]], + [fwhm_dict["half_maxs"][2], heights.min()], + c=cmp(cmap_ratio), + ) + plt.plot( + [fwhm_dict["half_maxs"][1], fwhm_dict["half_maxs"][1]], + [fwhm_dict["half_maxs"][2], heights.min()], + c=cmp(cmap_ratio), + ) + + ax.set_xlabel("Distance from Node (nm)") + ax.set_ylabel("Height") + ax.set_title(title) + ax.legend() + return fig, ax + + def plot_height_profiles(height_profiles: list | npt.NDArray) -> tuple: """ Plot height profiles. diff --git a/topostats/plotting_dictionary.yaml b/topostats/plotting_dictionary.yaml index e7b0211880..c4cc867024 100644 --- a/topostats/plotting_dictionary.yaml +++ b/topostats/plotting_dictionary.yaml @@ -11,6 +11,8 @@ # | image_type | String | Whether the plot includes the height (non-binary) or the outline (binary) | # | savefig_dpi | int | Dots Per Inch for plotting | # | core_set | Boolean | Whether a plot is considered part of the core set of images that are plotted.| + +# Flattening Troubleshooting Images extracted_channel: filename: "00-raw_heightmap" title: "Raw Height" @@ -118,34 +120,47 @@ z_threshed: image_type: "non-binary" savefig_dpi: 100 core_set: true +# Grainfinding Troubleshooting Images mask_grains: filename: "17-mask_grains" title: "Mask for Grains" image_type: "binary" + mask_cmap: "binary" savefig_dpi: 100 core_set: false labelled_regions_01: filename: "18-labelled_regions" title: "Labelled Regions" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false tidied_border: filename: "19-tidy_borders" title: "Tidied Borders" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false removed_noise: filename: "20-noise_removed" title: "Noise removed" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false removed_small_objects: filename: "21-small_objects_removed" title: "Small Objects Removed" image_type: "binary" + mask_cmap: "rainbow" + savefig_dpi: 100 + core_set: false +removed_objects_too_small_to_process: + filename: "22-objects_too_small_removed" + title: "Objects too small to process removed" + image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false mask_overlay: @@ -154,34 +169,32 @@ mask_overlay: savefig_dpi: 100 core_set: true labelled_regions_02: - filename: "22-labelled_regions" + filename: "23-labelled_regions" title: "Labelled Regions" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false coloured_regions: - filename: "23-coloured_regions" + filename: "24-coloured_regions" title: "Coloured Regions" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false bounding_boxes: - filename: "24-bounding_boxes" + filename: "25-bounding_boxes" title: "Bounding Boxes" image_type: "binary" savefig_dpi: 100 core_set: false coloured_boxes: - filename: "25-labelled_image_bboxes" + filename: "26-labelled_image_bboxes" title: "Labelled Image with Bounding Boxes" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false -all_molecule_traces: - title: "Molecule Traces" - image_type: "non-binary" - savefig_dpi: 800 - core_set: true grain_image: image_type: "non-binary" savefig_dpi: 100 @@ -194,7 +207,127 @@ grain_mask_image: image_type: "non-binary" savefig_dpi: 100 core_set: false -single_molecule_trace: +# Disordered Tracing Troubleshooting Images +orig_grain: + filename: "20-original_grains" + title: "Image with Threshold Mask" image_type: "non-binary" - savefig_dpi: 100 + mask_cmap: "blue" + core_set: false +smoothed_grain: + filename: "21-smoothed_grains" + title: "Image with Gaussian-Smoothed Threshold Mask" + image_type: "non-binary" + mask_cmap: "blue" + core_set: false +skeleton: + filename: "22-original_skeletons" + title: "Original Skeletons" + image_type: "non-binary" + mask_cmap: "blue" + core_set: false + savefig_dpi: 600 +pruned_skeleton: + title: "Pruned Skeletons" + image_type: "non-binary" + mask_cmap: "blue" + core_set: false + savefig_dpi: 600 +branch_indexes: + filename: "23-segment_indexes" + title: "Skeleton Segment Indexes" + image_type: "non-binary" + mask_cmap: "viridis" + core_set: false + savefig_dpi: 600 +branch_types: + filename: "24-segment_types" + title: "Skeleton Segment Types" + image_type: "non-binary" + mask_cmap: "viridis" + core_set: false + savefig_dpi: 600 +# Nodestats troubleshooting images +convolved_skeletons: + filename: "25-convolved_skeleton" + title: "Skeletons and Nodes" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: false + savefig_dpi: 600 +node_centres: + filename: "26-node_centres" + title: "Skeletons and Highlighted Nodes" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: false + savefig_dpi: 600 +connected_nodes: + title: "Skeletons and Nodes" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: true + savefig_dpi: 600 +node_area_skeleton: + title: "Zoom of Node Skeleton" + image_type: "non-binary" + core_set: false + mask_cmap: "blue_purple_green" + savefig_dpi: 200 +node_branch_mask: + title: "Crossing and Skeleton Branches" + image_type: "non-binary" + core_set: false + mask_cmap: "blue_purple_green" + savefig_dpi: 200 +node_avg_mask: + title: "Main and Parallel Traces" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: false +node_line_trace: + title: "Heights of Crossing" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: false +# Ordered tracing troubleshooting images +ordered_traces: + title: "Ordered Traces" + image_type: "non-binary" + mask_cmap: "viridis" core_set: false + savefig_dpi: 600 +trace_segments: + filename: "27-trace_segments" + title: "Trace Segments" + image_type: "non-binary" + mask_cmap: "gist_rainbow" + savefig_dpi: 600 + core_set: false +over_under: + filename: "28-molecule_crossings" + title: "Visualised Molecule Crossings" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: false + savefig_dpi: 600 +all_molecules: + filename: "29-all_molecules" + title: "Individual Molecules" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + savefig_dpi: 600 + core_set: false +# Splining +fitted_trace: + filename: "30-fitted-traces" + title: "Fitted Trace" + image_type: "non-binary" + mask_cmap: "blue_purple_green" + core_set: false + savefig_dpi: 600 +splined_trace: + title: "Smoothed Traces" + image_type: "non-binary" + savefig_dpi: 300 + core_set: true diff --git a/topostats/plottingfuncs.py b/topostats/plottingfuncs.py index cd44c81bd1..07eba2060a 100644 --- a/topostats/plottingfuncs.py +++ b/topostats/plottingfuncs.py @@ -18,18 +18,13 @@ from topostats.logs.logs import LOGGER_NAME from topostats.theme import Colormap +# pylint: disable=dangerous-default-value +# pylint: disable=too-many-arguments # pylint: disable=too-many-instance-attributes # pylint: disable=too-many-locals -# pylint: disable=too-many-arguments -# pylint: disable=dangerous-default-value LOGGER = logging.getLogger(LOGGER_NAME) -# pylint: disable=too-many-instance-attributes -# pylint: disable=too-many-arguments -# pylint: disable=too-many-locals -# pylint: disable=dangerous-default-value - def add_pixel_to_nm_to_plotting_config(plotting_config: dict, pixel_to_nm_scaling: float) -> dict: """ @@ -100,32 +95,34 @@ class Images: Parameters ---------- - data : np.array + data : npt.NDarray Numpy array to plot. - output_dir : Union[str, Path] + output_dir : str | Path Output directory to save the file to. - filename : Union[str, Path] + filename : str Filename to save image as. - style : dict - Filename of matploglibrc Params. + style : str | Path + Filename of matplotlibrc parameters. pixel_to_nm_scaling : float - The scaling factor showing the real length of 1 pixel, in nm. - masked_array : npt.NDArray + The scaling factor showing the real length of 1 pixel in nanometers (nm). + masked_array : npt.NDarray Optional mask array to overlay onto an image. + plot_coords : npt.NDArray + ??? Needs defining. title : str Title for plot. image_type : str - The image data type - binary or non-binary. + The image data type, options are 'binary' or 'non-binary'. image_set : str - The set of images to process - core or all. + The set of images to process, options are 'core' or 'all'. core_set : bool Flag to identify image as part of the core image set or not. - pixel_interpolation : str | None - Interpolation to use (default: None). - cmap : str + pixel_interpolation : str, optional + Interpolation to use (default is 'None'). + cmap : str, optional Colour map to use (default 'nanoscope', 'afmhot' also available). mask_cmap : str - Colour map to use for the secondary (masked) data (default 'jet_r', 'blu' proivides more contrast). + Colour map to use for the secondary (masked) data (default 'jet_r', 'blu' provides more contrast). region_properties : dict Dictionary of region properties, adds bounding boxes if specified. zrange : list @@ -135,27 +132,28 @@ class Images: axes : bool Optionally add/remove axes from the image. num_ticks : tuple[int | None] - The number of x and y ticks to display on the image. + The number of x and y ticks to display on the iage. save : bool Whether to save the image. - savefig_format : str + savefig_format : str, optional Format to save the image as. histogram_log_axis : bool - Optionally use a logarithmic y axis for the histogram plots. - histogram_bins : int + Optionally use a loagrithmic y-axis for the histogram plots. + histogram_bins : int, optional Number of bins for histograms to use. - savefig_dpi : str | float | None + savefig_dpi : str | float, optional The resolution of the saved plot (default 'figure'). """ def __init__( self, - data: np.array, + data: npt.NDarray, output_dir: str | Path, filename: str, style: str | Path = None, pixel_to_nm_scaling: float = 1.0, - masked_array: np.array = None, + masked_array: npt.NDarray = None, + plot_coords: npt.NDArray = None, title: str = None, image_type: str = "non-binary", image_set: str = "core", @@ -177,39 +175,41 @@ def __init__( """ Initialise the class. - There are two key parameters that ensure whether and image is plotted that are passed in from the update + There are two key parameters that ensure whether an image is plotted that are passed in from the updated plotting dictionary. These are the `image_set` which defines whether to plot 'all' images or just the `core` - set. There is then the 'core_set' which defines whether an individual images belongs to the 'core_set' or not. - If it doesn't then it is not plotted when `image_set == "core"`. + set. There is then the 'core_set' which defines whether an individual images belongs to the 'core_set' or + not. If it doesn't then it is not plotted when `image_set == "core"`. Parameters ---------- - data : np.array + data : npt.NDarray Numpy array to plot. - output_dir : Union[str, Path] + output_dir : str | Path Output directory to save the file to. - filename : Union[str, Path] + filename : str Filename to save image as. - style : dict - Filename of matploglibrc Params. + style : str | Path + Filename of matplotlibrc parameters. pixel_to_nm_scaling : float - The scaling factor showing the real length of 1 pixel, in nm. - masked_array : npt.NDArray + The scaling factor showing the real length of 1 pixel in nanometers (nm). + masked_array : npt.NDarray Optional mask array to overlay onto an image. + plot_coords : npt.NDArray + ??? Needs defining. title : str Title for plot. image_type : str - The image data type - binary or non-binary. + The image data type, options are 'binary' or 'non-binary'. image_set : str - The set of images to process - core or all. + The set of images to process, options are 'core' or 'all'. core_set : bool Flag to identify image as part of the core image set or not. - pixel_interpolation : str | None - Interpolation to use (default: None). - cmap : str + pixel_interpolation : str, optional + Interpolation to use (default is 'None'). + cmap : str, optional Colour map to use (default 'nanoscope', 'afmhot' also available). mask_cmap : str - Colour map to use for the secondary (masked) data (default 'jet_r', 'blu' proivides more contrast). + Colour map to use for the secondary (masked) data (default 'jet_r', 'blu' provides more contrast). region_properties : dict Dictionary of region properties, adds bounding boxes if specified. zrange : list @@ -219,16 +219,16 @@ def __init__( axes : bool Optionally add/remove axes from the image. num_ticks : tuple[int | None] - The number of x and y ticks to display on the image. + The number of x and y ticks to display on the iage. save : bool Whether to save the image. - savefig_format : str + savefig_format : str, optional Format to save the image as. histogram_log_axis : bool - Optionally use a logarithmic y axis for the histogram plots. - histogram_bins : int + Optionally use a loagrithmic y-axis for the histogram plots. + histogram_bins : int, optional Number of bins for histograms to use. - savefig_dpi : str | float | None + savefig_dpi : str | float, optional The resolution of the saved plot (default 'figure'). """ if style is None: @@ -241,6 +241,7 @@ def __init__( self.filename = filename self.pixel_to_nm_scaling = pixel_to_nm_scaling self.masked_array = masked_array + self.plot_coords = plot_coords self.title = title self.image_type = image_type self.image_set = image_set @@ -304,7 +305,7 @@ def plot_and_save(self): # Only plot if image_set is "all" (i.e. user wants all images) or an image is in the core_set if self.image_set == "all" or self.core_set: fig, ax = self.save_figure() - LOGGER.info( + LOGGER.debug( f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}" f" | DPI: {self.savefig_dpi}" ) @@ -333,7 +334,6 @@ def save_figure(self): vmax=self.zrange[1], ) if isinstance(self.masked_array, np.ndarray): - self.masked_array[self.masked_array != 0] = 1 mask = np.ma.masked_where(self.masked_array == 0, self.masked_array) ax.imshow( mask, @@ -349,6 +349,15 @@ def save_figure(self): ) patch = [Patch(color=self.mask_cmap(1, 0.7), label="Mask")] plt.legend(handles=patch, loc="upper right", bbox_to_anchor=(1.02, 1.09)) + # if coordinates are provided (such as in splines, plot those) + elif self.plot_coords is not None: + for grain_coords in self.plot_coords: + ax.plot( + grain_coords[:, 1] * self.pixel_to_nm_scaling, + (shape[0] - grain_coords[:, 0]) * self.pixel_to_nm_scaling, + c="c", + linewidth=1, + ) plt.title(self.title) plt.xlabel("Nanometres") diff --git a/topostats/processing.py b/topostats/processing.py index 851cc39c2d..efd598059c 100644 --- a/topostats/processing.py +++ b/topostats/processing.py @@ -6,6 +6,7 @@ from pathlib import Path import numpy as np +import numpy.typing as npt import pandas as pd from topostats import __version__ @@ -14,38 +15,45 @@ from topostats.grainstats import GrainStats from topostats.io import get_out_path, save_topostats_file from topostats.logs.logs import LOGGER_NAME, setup_logger +from topostats.plotting import plot_crossing_linetrace_halfmax from topostats.plottingfuncs import Images, add_pixel_to_nm_to_plotting_config from topostats.statistics import image_statistics -from topostats.tracing.dnatracing import trace_image +from topostats.tracing.disordered_tracing import trace_image_disordered +from topostats.tracing.nodestats import nodestats_image +from topostats.tracing.ordered_tracing import ordered_tracing_image +from topostats.tracing.splining import splining_image from topostats.utils import create_empty_dataframe # pylint: disable=broad-except # pylint: disable=line-too-long # pylint: disable=too-many-arguments # pylint: disable=too-many-branches +# pylint: disable=too-many-lines # pylint: disable=too-many-locals # pylint: disable=too-many-statements # pylint: disable=too-many-nested-blocks +# pylint: disable=too-many-positional-arguments # pylint: disable=unnecessary-dict-index-lookup +# pylint: disable=too-many-lines LOGGER = setup_logger(LOGGER_NAME) def run_filters( - unprocessed_image: np.ndarray, + unprocessed_image: npt.NDArray, pixel_to_nm_scaling: float, filename: str, filter_out_path: Path, core_out_path: Path, filter_config: dict, plotting_config: dict, -) -> np.ndarray | None: +) -> npt.NDArray | None: """ Filter and flatten an image. Optionally plots the results, returning the flattened image. Parameters ---------- - unprocessed_image : np.ndarray + unprocessed_image : npt.NDArray Image to be flattened. pixel_to_nm_scaling : float Scaling factor for converting pixel length scales to nanometres. @@ -63,7 +71,7 @@ def run_filters( Returns ------- - Union[np.ndarray, None] + npt.NDArray | None Either a numpy array of the flattened image, or None if an error occurs or flattening is disabled in the configuration. """ @@ -120,20 +128,20 @@ def run_filters( def run_grains( # noqa: C901 - image: np.ndarray, + image: npt.NDArray, pixel_to_nm_scaling: float, filename: str, grain_out_path: Path, core_out_path: Path, plotting_config: dict, grains_config: dict, -): +) -> dict | None: """ Identify grains (molecules) and optionally plots the results. Parameters ---------- - image : np.ndarray + image : npt.NDArray 2d numpy array image to find grains in. pixel_to_nm_scaling : float Scaling factor for converting pixel length scales to nanometres. I.e. the number of pixels per nanometre. @@ -150,7 +158,7 @@ def run_grains( # noqa: C901 Returns ------- - Union[dict, None] + dict | None Either None in the case of error or grain finding being disabled or a dictionary with keys of "above" and or "below" containing binary masks depicting where grains have been detected. @@ -174,7 +182,7 @@ def run_grains( # noqa: C901 if len(grains.region_properties[direction]) == 0: LOGGER.warning(f"[{filename}] : No grains found for direction {direction}") except Exception as e: - LOGGER.error(f"[{filename}] : An error occurred during grain finding, skipping grainstats and dnatracing.") + LOGGER.error(f"[{filename}] : An error occurred during grain finding, skipping following steps.") LOGGER.error(f"[{filename}] : The error: {e}") else: for direction, region_props in grains.region_properties.items(): @@ -196,7 +204,9 @@ def run_grains( # noqa: C901 array = array[:, :, 1] LOGGER.info(f"[{filename}] : Plotting {plot_name} image") plotting_config["plot_dict"][plot_name]["output_dir"] = grain_out_path_direction - Images(array, **plotting_config["plot_dict"][plot_name]).plot_and_save() + Images( + data=np.zeros_like(array), masked_array=array, **plotting_config["plot_dict"][plot_name] + ).plot_and_save() # Make a plot of coloured regions with bounding boxes plotting_config["plot_dict"]["bounding_boxes"]["output_dir"] = grain_out_path_direction Images( @@ -207,7 +217,8 @@ def run_grains( # noqa: C901 plotting_config["plot_dict"]["coloured_boxes"]["output_dir"] = grain_out_path_direction # hard code to class index 1, as this implementation is not yet generalised. Images( - grains.directions[direction]["labelled_regions_02"][:, :, 1], + data=np.zeros_like(grains.directions[direction]["labelled_regions_02"][:, :, 1]), + masked_array=grains.directions[direction]["labelled_regions_02"][:, :, 1], **plotting_config["plot_dict"]["coloured_boxes"], region_properties=grains.region_properties[direction], ).plot_and_save() @@ -218,7 +229,7 @@ def run_grains( # noqa: C901 Images( image, filename=f"{filename}_{direction}_masked", - masked_array=grains.directions[direction]["removed_small_objects"][:, :, 1], + masked_array=grains.directions[direction]["removed_small_objects"][:, :, 1].astype(bool), **plotting_config["plot_dict"][plot_name], region_properties=grains.region_properties[direction], ).plot_and_save() @@ -236,16 +247,17 @@ def run_grains( # noqa: C901 return grain_masks # Otherwise, return None and warn grainstats is disabled - LOGGER.info(f"[{filename}] Detection of grains disabled, returning empty data frame.") + LOGGER.info(f"[{filename}] Detection of grains disabled, GrainStats will not be run.") return None def run_grainstats( - image: np.ndarray, + image: npt.NDArray, pixel_to_nm_scaling: float, grain_masks: dict, filename: str, + basename: Path, grainstats_config: dict, plotting_config: dict, grain_out_path: Path, @@ -255,7 +267,7 @@ def run_grainstats( Parameters ---------- - image : np.ndarray + image : npt.NDArray 2D numpy array image for grain statistics calculations. pixel_to_nm_scaling : float Scaling factor for converting pixel length scales to nanometres. @@ -265,6 +277,8 @@ def run_grainstats( boolean arrays indicating the pixels that have been masked as grains. filename : str Name of the image. + basename : Path + Path to directory containing the image. grainstats_config : dict Dictionary of configuration for the GrainStats class to be used when initialised. plotting_config : dict @@ -348,6 +362,7 @@ def run_grainstats( raise ValueError( "grainstats dictionary has neither 'above' nor 'below' keys. This should be impossible." ) + grainstats_df["basename"] = basename.parent return grainstats_df, height_profiles_dict @@ -363,157 +378,509 @@ def run_grainstats( return create_empty_dataframe(), {} -def run_dnatracing( # noqa: C901 - image: np.ndarray, +def run_disordered_trace( + image: npt.NDArray, grain_masks: dict, pixel_to_nm_scaling: float, - image_path: Path, filename: str, + basename: str, core_out_path: Path, - grain_out_path: Path, - dnatracing_config: dict, + tracing_out_path: Path, + disordered_tracing_config: dict, plotting_config: dict, - results_df: pd.DataFrame = None, -): + grainstats_df: pd.DataFrame = None, +) -> dict: """ - Trace DNA molecule for the supplied grains adding results to statistics data frames and optionally plot results. + Skeletonise and prune grains, adding results to statistics data frames and optionally plot results. Parameters ---------- - image : np.ndarray - Image containing the DNA to pass to the dna tracing function. + image : npt.ndarray + Image containing the grains to pass to the tracing function. grain_masks : dict - Dictionary of grain masks, keys "above" or "below" with values of 2d numpy - boolean arrays indicating the pixels that have been masked as grains. + Dictionary of grain masks, keys "above" or "below" with values of 2D Numpy boolean arrays indicating the pixels + that have been masked as grains. pixel_to_nm_scaling : float - Scaling factor for converting pixel length scales to nanometres. - ie the number of pixels per nanometre. - image_path : Path - Path to the image file. Used for DataFrame indexing. + Scaling factor for converting pixel length scales to nanometers, i.e. the number of pixesl per nanometres (nm). filename : str Name of the image. + basename : Path + Path to directory containing the image. core_out_path : Path - General output directory for outputs such as the grain statistics - DataFrame. - grain_out_path : Path - Directory to save optional dna tracing visual information to. - dnatracing_config : dict - Dictionary configuration for the dna tracing function. + Path to save the core disordered trace image to. + tracing_out_path : Path + Path to save the optional, diagnostic disordered trace images to. + disordered_tracing_config : dict + Dictionary configuration for obtaining a disordered trace representation of the grains. plotting_config : dict Dictionary configuration for plotting images. - results_df : pd.DataFrame - Pandas DataFrame containing grain statistics. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None. Returns ------- - pd.DataFrame - Pandas DataFrame containing grain statistics and dna tracing statistics. - Keys are file path and molecule number. + dict + Dictionary of "grain_" keys and Nx2 coordinate arrays of the disordered grain trace. """ - # Create empty dataframe is none is passed - if results_df is None: - results_df = create_empty_dataframe() - - # Run dnatracing - try: - grain_trace_data = None - if dnatracing_config["run"]: - dnatracing_config.pop("run") - LOGGER.info(f"[{filename}] : *** DNA Tracing ***") - tracing_stats = defaultdict() - grain_trace_data = defaultdict() + if disordered_tracing_config["run"]: + disordered_tracing_config.pop("run") + LOGGER.info(f"[{filename}] : *** Disordered Tracing ***") + disordered_traces = defaultdict() + disordered_trace_grainstats = pd.DataFrame() + disordered_tracing_stats_image = pd.DataFrame() + try: + # run image using directional grain masks for direction, _ in grain_masks.items(): - # Get the DNA class mask from the tensor - LOGGER.info(f"[{filename}] : Mask dimensions: {grain_masks[direction].shape}") + # Check if there are grains assert len(grain_masks[direction].shape) == 3, "Grain masks should be 3D tensors" dna_class_mask = grain_masks[direction][:, :, 1] - LOGGER.info(f"[{filename}] : DNA Mask dimensions: {dna_class_mask.shape}") - - tracing_results = trace_image( + if np.max(dna_class_mask) == 0: + LOGGER.warning( + f"[{filename}] : No grains exist for the {direction} direction. Skipping disordered_tracing for {direction}." + ) + raise ValueError(f"No grains exist for the {direction} direction") + + # if grains are found + ( + disordered_traces_cropped_data, + _disordered_trace_grainstats, + disordered_tracing_images, + disordered_tracing_stats, + ) = trace_image_disordered( image=image, grains_mask=dna_class_mask, filename=filename, pixel_to_nm_scaling=pixel_to_nm_scaling, - **dnatracing_config, + **disordered_tracing_config, ) - tracing_stats[direction] = tracing_results["statistics"] - ordered_traces = tracing_results["all_ordered_traces"] - cropped_images: dict[int, np.ndarray] = tracing_results["all_cropped_images"] - image_spline_trace = tracing_results["image_spline_trace"] - tracing_stats[direction]["threshold"] = direction - - grain_trace_data[direction] = { - "ordered_traces": ordered_traces, - "cropped_images": cropped_images, - "ordered_trace_heights": tracing_results["all_ordered_trace_heights"], - "ordered_trace_cumulative_distances": tracing_results["all_ordered_trace_cumulative_distances"], - "splined_traces": tracing_results["all_splined_traces"], - } - - # Plot traces for the whole image + # save per image new grainstats stats + _disordered_trace_grainstats["threshold"] = direction + disordered_trace_grainstats = pd.concat([disordered_trace_grainstats, _disordered_trace_grainstats]) + + disordered_tracing_stats["threshold"] = direction + disordered_tracing_stats["basename"] = basename.parent + disordered_tracing_stats_image = pd.concat([disordered_tracing_stats_image, disordered_tracing_stats]) + + # append direction results to dict + disordered_traces[direction] = disordered_traces_cropped_data + # save plots Images( image, + masked_array=disordered_tracing_images.pop("pruned_skeleton"), output_dir=core_out_path, - filename=f"{filename}_{direction}_traced", - masked_array=image_spline_trace, - **plotting_config["plot_dict"]["all_molecule_traces"], + filename=f"{filename}_{direction}_disordered_trace", + **plotting_config["plot_dict"]["pruned_skeleton"], ).plot_and_save() + for plot_name, image_value in disordered_tracing_images.items(): + Images( + image, + masked_array=image_value, + output_dir=tracing_out_path / direction, + **plotting_config["plot_dict"][plot_name], + ).plot_and_save() - # Plot traces on each grain individually - if plotting_config["image_set"] == "all": - for grain_index, grain_trace in ordered_traces.items(): - cropped_image = cropped_images[grain_index] - grain_trace_mask = np.zeros(cropped_image.shape) - # Grain traces can be None if they do not trace successfully. Eg if they are too small. - if grain_trace is not None: - for coordinate in grain_trace: - grain_trace_mask[coordinate[0], coordinate[1]] = 1 - Images( - cropped_image, - output_dir=grain_out_path / direction, - filename=f"{filename}_grain_trace_{grain_index}", - masked_array=grain_trace_mask, - **plotting_config["plot_dict"]["single_molecule_trace"], - ).plot_and_save() + # merge grainstats data with other dataframe + resultant_grainstats = ( + pd.merge(grainstats_df, disordered_trace_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else disordered_trace_grainstats + ) - # Set create tracing_stats_df from above and below results - if "above" in tracing_stats and "below" in tracing_stats: - tracing_stats_df = pd.concat([tracing_stats["below"], tracing_stats["above"]]) - elif "above" in tracing_stats: - tracing_stats_df = tracing_stats["above"] - elif "below" in tracing_stats: - tracing_stats_df = tracing_stats["below"] - else: - raise ValueError( - "tracing_stats dictionary has neither 'above' nor 'below' keys. This should be impossible." + return disordered_traces, resultant_grainstats, disordered_tracing_stats_image + + except Exception as e: + LOGGER.info( + f"[{filename}] : Disordered tracing failed - skipping. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + return {}, grainstats_df, None + + else: + LOGGER.info(f"[{filename}] Calculation of Disordered Tracing disabled, returning empty dictionary.") + return {}, grainstats_df, None + + +def run_nodestats( # noqa: C901 + image: npt.NDArray, + disordered_tracing_data: dict, + pixel_to_nm_scaling: float, + filename: str, + core_out_path: Path, + tracing_out_path: Path, + nodestats_config: dict, + plotting_config: dict, + grainstats_df: pd.DataFrame = None, +) -> tuple[dict, pd.DataFrame]: + """ + Analyse crossing points in grains adding results to statistics data frames and optionally plot results. + + Parameters + ---------- + image : npt.ndarray + Image containing the DNA to pass to the tracing function. + disordered_tracing_data : dict + Dictionary of skeletonised and pruned grain masks. Result from "run_disordered_tracing". + pixel_to_nm_scaling : float + Scaling factor for converting pixel length scales to nanometers, i.e. the number of pixels per nanometres (nm). + filename : str + Name of the image. + core_out_path : Path + Path to save the core NodeStats image to. + tracing_out_path : Path + Path to save optional, diagnostic NodeStats images to. + nodestats_config : dict + Dictionary configuration for analysing the crossing points. + plotting_config : dict + Dictionary configuration for plotting images. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to bee added to. by default None. + + Returns + ------- + tuple[dict, pd.DataFrame] + A NodeStats analysis dictionary and grainstats metrics dataframe. + """ + if nodestats_config["run"]: + nodestats_config.pop("run") + LOGGER.info(f"[{filename}] : *** Nodestats ***") + nodestats_whole_data = defaultdict() + nodestats_grainstats = pd.DataFrame() + try: + # run image using directional grain masks + for direction, disordered_tracing_direction_data in disordered_tracing_data.items(): + ( + nodestats_data, + _nodestats_grainstats, + nodestats_full_images, + nodestats_branch_images, + ) = nodestats_image( + image=image, + disordered_tracing_direction_data=disordered_tracing_direction_data, + filename=filename, + pixel_to_nm_scaling=pixel_to_nm_scaling, + **nodestats_config, ) + + # save per image new grainstats stats + _nodestats_grainstats["threshold"] = direction + nodestats_grainstats = pd.concat([nodestats_grainstats, _nodestats_grainstats]) + + # append direction results to dict + nodestats_whole_data[direction] = {"stats": nodestats_data, "images": nodestats_branch_images} + + # save whole image plots + Images( + filename=f"{filename}_{direction}_nodes", + data=image, + masked_array=nodestats_full_images.pop("connected_nodes"), + output_dir=core_out_path, + **plotting_config["plot_dict"]["connected_nodes"], + ).plot_and_save() + for plot_name, image_value in nodestats_full_images.items(): + Images( + image, + masked_array=image_value, + output_dir=tracing_out_path / direction, + **plotting_config["plot_dict"][plot_name], + ).plot_and_save() + + # plot single node images + for mol_no, mol_stats in nodestats_data.items(): + if mol_stats is not None: + for node_no, single_node_stats in mol_stats.items(): + # plot the node and branch_mask images + for cropped_image_type, cropped_image in nodestats_branch_images[mol_no]["nodes"][ + node_no + ].items(): + Images( + nodestats_branch_images[mol_no]["grain"]["grain_image"], + masked_array=cropped_image, + output_dir=tracing_out_path / direction / "nodes", + filename=f"{mol_no}_{node_no}_{cropped_image_type}", + **plotting_config["plot_dict"][cropped_image_type], + ).plot_and_save() + + # plot crossing height linetrace + if plotting_config["image_set"] == "all": + if not single_node_stats["error"]: + fig, _ = plot_crossing_linetrace_halfmax( + branch_stats_dict=single_node_stats["branch_stats"], + mask_cmap=plotting_config["plot_dict"]["node_line_trace"]["mask_cmap"], + title=plotting_config["plot_dict"]["node_line_trace"]["mask_cmap"], + ) + fig.savefig( + tracing_out_path + / direction + / "nodes" + / f"{mol_no}_{node_no}_linetrace_halfmax.svg", + format="svg", + ) + LOGGER.info(f"[{filename}] : Finished Plotting NodeStats Images") + + # merge grainstats data with other dataframe + resultant_grainstats = ( + pd.merge(grainstats_df, nodestats_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else nodestats_grainstats + ) + + # merge all image dictionaries + return nodestats_whole_data, resultant_grainstats + + except Exception as e: LOGGER.info( - f"[{filename}] : Combining {list(tracing_stats.keys())} grain statistics and dnatracing statistics" + f"[{filename}] : NodeStats failed - skipping. Consider raising an issue on GitHub. Error: ", exc_info=e ) - # NB - Merge on image, molecule and threshold because we may have above and below molecules which - # gives duplicate molecule numbers as they are processed separately, if tracing stats - # are not available (because skeleton was too small), grainstats are still retained. - results = results_df.merge(tracing_stats_df, on=["image", "threshold", "molecule_number"], how="left") - results["basename"] = image_path.parent + return nodestats_whole_data, nodestats_grainstats - return results, grain_trace_data + else: + LOGGER.info(f"[{filename}] : Calculation of nodestats disabled, returning empty dataframe.") + return None, grainstats_df - # Otherwise, return the passed in dataframe and warn that tracing is disabled - LOGGER.info(f"[{filename}] Calculation of DNA Tracing disabled, returning grainstats data frame.") - results = results_df - results["basename"] = image_path.parent - return results, grain_trace_data +# need to add in the molstats here +def run_ordered_tracing( + image: npt.NDArray, + disordered_tracing_data: dict, + nodestats_data: dict, + filename: str, + basename: Path, + core_out_path: Path, + tracing_out_path: Path, + ordered_tracing_config: dict, + plotting_config: dict, + grainstats_df: pd.DataFrame = None, +) -> tuple: + """ + Order coordinates of traces, adding results to statistics data frames and optionally plot results. - except Exception: - # If no results we need a dummy dataframe to return. - LOGGER.warning( - f"[{filename}] : Errors occurred whilst calculating DNA tracing statistics, " "returning grain statistics" - ) - results = results_df - results["basename"] = image_path.parent - grain_trace_data = None - return results, grain_trace_data + Parameters + ---------- + image : npt.ndarray + Image containing the DNA to pass to the tracing function. + disordered_tracing_data : dict + Dictionary of skeletonised and pruned grain masks. Result from "run_disordered_tracing". + nodestats_data : dict + Dictionary of images and statistics from the NodeStats analysis. Result from "run_nodestats". + filename : str + Name of the image. + basename : Path + The path of the files' parent directory. + core_out_path : Path + Path to save the core ordered tracing image to. + tracing_out_path : Path + Path to save optional, diagnostic ordered trace images to. + ordered_tracing_config : dict + Dictionary configuration for obtaining an ordered trace representation of the skeletons. + plotting_config : dict + Dictionary configuration for plotting images. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None. + + Returns + ------- + tuple[dict, pd.DataFrame] + A NodeStats analysis dictionary and grainstats metrics dataframe. + """ + if ordered_tracing_config["run"]: + ordered_tracing_config.pop("run") + LOGGER.info(f"[{filename}] : *** Ordered Tracing ***") + ordered_tracing_image_data = defaultdict() + ordered_tracing_molstats = pd.DataFrame() + ordered_tracing_grainstats = pd.DataFrame() + + try: + # run image using directional grain masks + for direction, disordered_tracing_direction_data in disordered_tracing_data.items(): + # Check if there are grains + if not disordered_tracing_direction_data: + LOGGER.warning( + f"[{filename}] : No grains exist for the {direction} direction. Skipping ordered_tracing for {direction}." + ) + raise ValueError(f"No grains exist for the {direction} direction") + + # if grains are found + ( + ordered_tracing_data, + _ordered_tracing_grainstats, + _ordered_tracing_molstats, + ordered_tracing_full_images, + ) = ordered_tracing_image( + image=image, + disordered_tracing_direction_data=disordered_tracing_direction_data, + nodestats_direction_data=nodestats_data[direction], + filename=filename, + **ordered_tracing_config, + ) + + # save per image new grainstats stats + _ordered_tracing_grainstats["threshold"] = direction + ordered_tracing_grainstats = pd.concat([ordered_tracing_grainstats, _ordered_tracing_grainstats]) + _ordered_tracing_molstats["threshold"] = direction + ordered_tracing_molstats = pd.concat([ordered_tracing_molstats, _ordered_tracing_molstats]) + + # append direction results to dict + ordered_tracing_image_data[direction] = ordered_tracing_data + + # save whole image plots + plotting_config["plot_dict"]["ordered_traces"]["core_set"] = True # fudge around core having own cmap + Images( + filename=f"{filename}_{direction}_ordered_traces", + data=image, + masked_array=ordered_tracing_full_images.pop("ordered_traces"), + output_dir=core_out_path, + **plotting_config["plot_dict"]["ordered_traces"], + ).plot_and_save() + # save optional diagnostic plots (those with core_set = False) + for plot_name, image_value in ordered_tracing_full_images.items(): + Images( + image, + masked_array=image_value, + output_dir=tracing_out_path / direction, + **plotting_config["plot_dict"][plot_name], + ).plot_and_save() + + LOGGER.info(f"[{filename}] : Finished Plotting Ordered Tracing Images") + + # merge grainstats data with other dataframe + resultant_grainstats = ( + pd.merge(grainstats_df, ordered_tracing_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else ordered_tracing_grainstats + ) + + ordered_tracing_molstats["basename"] = basename.parent + + # merge all image dictionaries + return ordered_tracing_image_data, resultant_grainstats, ordered_tracing_molstats + + except Exception as e: + LOGGER.info( + f"[{filename}] : Ordered Tracing failed - skipping. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + return ordered_tracing_image_data, grainstats_df, None + + return None, grainstats_df, None + + +def run_splining( + image: npt.NDArray, + ordered_tracing_data: dict, + pixel_to_nm_scaling: float, + filename: str, + core_out_path: Path, + splining_config: dict, + plotting_config: dict, + grainstats_df: pd.DataFrame = None, + molstats_df: pd.DataFrame = None, +) -> tuple: + """ + Smooth the ordered trace coordinates, adding results to statistics data frames and optionally plot results. + + Parameters + ---------- + image : npt.NDArray + Image containing the DNA to pass to the tracing function. + ordered_tracing_data : dict + Dictionary of ordered coordinates. Result from "run_ordered_tracing". + pixel_to_nm_scaling : float + Scaling factor for converting pixel length scales to nanometers, i.e. the number of pixels per nanometres (nm). + filename : str + Name of the image. + core_out_path : Path + Path to save the core ordered tracing image to. + splining_config : dict + Dictionary configuration for obtaining an ordered trace representation of the skeletons. + plotting_config : dict + Dictionary configuration for plotting images. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None. + molstats_df : pd.DataFrame, optional + The molecule statistics dataframe to be added to. by default None. + + Returns + ------- + tuple[dict, pd.DataFrame] + A smooth curve analysis dictionary and grainstats metrics dataframe. + """ + if splining_config["run"]: + splining_config.pop("run") + LOGGER.info(f"[{filename}] : *** Splining ***") + splined_image_data = defaultdict() + splining_grainstats = pd.DataFrame() + splining_molstats = pd.DataFrame() + + try: + # run image using directional grain masks + for direction, ordered_tracing_direction_data in ordered_tracing_data.items(): + if not ordered_tracing_direction_data: + LOGGER.warning( + f"[{filename}] : No grains exist for the {direction} direction. Skipping disordered_tracing for {direction}." + ) + splining_grainstats = create_empty_dataframe() + splining_molstats = create_empty_dataframe(columns=["image", "basename", "threshold"]) + raise ValueError(f"No grains exist for the {direction} direction") + + # if grains are found + ( + splined_data, + _splining_grainstats, + _splining_molstats, + ) = splining_image( + image=image, + ordered_tracing_direction_data=ordered_tracing_direction_data, + filename=filename, + pixel_to_nm_scaling=pixel_to_nm_scaling, + **splining_config, + ) + + # save per image new grainstats stats + _splining_grainstats["threshold"] = direction + splining_grainstats = pd.concat([splining_grainstats, _splining_grainstats]) + _splining_molstats["threshold"] = direction + splining_molstats = pd.concat([splining_molstats, _splining_molstats]) + + # append direction results to dict + splined_image_data[direction] = splined_data + + # Plot traces on each grain individually + all_splines = [] + for _, grain_dict in splined_data.items(): + for _, mol_dict in grain_dict.items(): + all_splines.append(mol_dict["spline_coords"] + mol_dict["bbox"][:2]) + Images( + data=image, + output_dir=core_out_path, + filename=f"{filename}_{direction}_all_splines", + plot_coords=all_splines, + **plotting_config["plot_dict"]["splined_trace"], + ).plot_and_save() + LOGGER.info(f"[{filename}] : Finished Plotting Splining Images") + + # merge grainstats data with other dataframe + resultant_grainstats = ( + pd.merge(grainstats_df, splining_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else splining_grainstats + ) + # merge molstats data with other dataframe + resultant_molstats = ( + pd.merge(molstats_df, splining_molstats, on=["image", "threshold", "grain_number", "molecule_number"]) + if molstats_df is not None + else splining_molstats + ) + + # merge all image dictionaries + return splined_image_data, resultant_grainstats, resultant_molstats + + except Exception as e: + LOGGER.error( + f"[{filename}] : Splining failed - skipping. Consider raising an issue on GitHub. Error: ", exc_info=e + ) + return splined_image_data, splining_grainstats, splining_molstats + + return None, grainstats_df, molstats_df def get_out_paths(image_path: Path, base_dir: Path, output_dir: Path, filename: str, plotting_config: dict): @@ -544,12 +911,17 @@ def get_out_paths(image_path: Path, base_dir: Path, output_dir: Path, filename: core_out_path.mkdir(parents=True, exist_ok=True) filter_out_path = core_out_path / filename / "filters" grain_out_path = core_out_path / filename / "grains" + tracing_out_path = core_out_path / filename / "dnatracing" if plotting_config["image_set"] == "all": filter_out_path.mkdir(exist_ok=True, parents=True) Path.mkdir(grain_out_path / "above", parents=True, exist_ok=True) Path.mkdir(grain_out_path / "below", parents=True, exist_ok=True) + Path.mkdir(tracing_out_path / "above", parents=True, exist_ok=True) + Path.mkdir(tracing_out_path / "below", parents=True, exist_ok=True) + Path.mkdir(tracing_out_path / "above" / "nodes", parents=True, exist_ok=True) + Path.mkdir(tracing_out_path / "below" / "nodes", parents=True, exist_ok=True) - return core_out_path, filter_out_path, grain_out_path + return core_out_path, filter_out_path, grain_out_path, tracing_out_path def process_scan( @@ -558,7 +930,10 @@ def process_scan( filter_config: dict, grains_config: dict, grainstats_config: dict, - dnatracing_config: dict, + disordered_tracing_config: dict, + nodestats_config: dict, + ordered_tracing_config: dict, + splining_config: dict, plotting_config: dict, output_dir: str | Path = "output", ) -> tuple[dict, pd.DataFrame, dict]: @@ -567,10 +942,10 @@ def process_scan( Parameters ---------- - topostats_object : dict[str, Union[np.ndarray, Path, float]] - A dictionary with keys 'image', 'img_path' and 'px_2_nm' containing a file or frames' image, it's path and it's + topostats_object : dict[str, Union[npt.NDArray, Path, float]] + A dictionary with keys 'image', 'img_path' and 'pixel_to_nm_scaling' containing a file or frames' image, it's path and it's pixel to namometre scaling value. - base_dir : Union[str, Path] + base_dir : str | Path Directory to recursively search for files, if not specified the current directory is scanned. filter_config : dict Dictionary of configuration options for running the Filter stage. @@ -578,11 +953,17 @@ def process_scan( Dictionary of configuration options for running the Grain detection stage. grainstats_config : dict Dictionary of configuration options for running the Grain Statistics stage. - dnatracing_config : dict - Dictionary of configuration options for running the DNA Tracing stage. + disordered_tracing_config : dict + Dictionary configuration for obtaining a disordered trace representation of the grains. + nodestats_config : dict + Dictionary of configuration options for running the NodeStats stage. + ordered_tracing_config : dict + Dictionary configuration for obtaining an ordered trace representation of the skeletons. + splining_config : dict + Dictionary of configuration options for running the splining stage. plotting_config : dict Dictionary of configuration options for plotting figures. - output_dir : Union[str, Path] + output_dir : str | Path Directory to save output to, it will be created if it does not exist. If it already exists then it is possible that output will be over-written. @@ -592,7 +973,7 @@ def process_scan( TopoStats dictionary object, DataFrame containing grain statistics and dna tracing statistics, and dictionary containing general image statistics. """ - core_out_path, filter_out_path, grain_out_path = get_out_paths( + core_out_path, filter_out_path, grain_out_path, tracing_out_path = get_out_paths( image_path=topostats_object["img_path"], base_dir=base_dir, output_dir=output_dir, @@ -634,36 +1015,82 @@ def process_scan( if "above" in topostats_object["grain_masks"].keys() or "below" in topostats_object["grain_masks"].keys(): # Grainstats : - results_df, height_profiles = run_grainstats( + grainstats_df, height_profiles = run_grainstats( image=topostats_object["image_flattened"], pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], grain_masks=topostats_object["grain_masks"], filename=topostats_object["filename"], + basename=topostats_object["img_path"], grainstats_config=grainstats_config, plotting_config=plotting_config, grain_out_path=grain_out_path, ) + topostats_object["height_profiles"] = height_profiles - # DNAtracing - results_df, grain_trace_data = run_dnatracing( + # Disordered Tracing + disordered_traces_data, grainstats_df, disordered_tracing_stats = run_disordered_trace( image=topostats_object["image_flattened"], - pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], grain_masks=topostats_object["grain_masks"], + pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], filename=topostats_object["filename"], + basename=topostats_object["img_path"], core_out_path=core_out_path, - grain_out_path=grain_out_path, - image_path=topostats_object["img_path"], + tracing_out_path=tracing_out_path, + disordered_tracing_config=disordered_tracing_config, + grainstats_df=grainstats_df, plotting_config=plotting_config, - dnatracing_config=dnatracing_config, - results_df=results_df, ) + topostats_object["disordered_traces"] = disordered_traces_data - # Add grain trace data and height profiles to topostats object - topostats_object["grain_trace_data"] = grain_trace_data - topostats_object["height_profiles"] = height_profiles + # Nodestats + nodestats, grainstats_df = run_nodestats( + image=topostats_object["image_flattened"], + disordered_tracing_data=topostats_object["disordered_traces"], + pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], + filename=topostats_object["filename"], + core_out_path=core_out_path, + tracing_out_path=tracing_out_path, + plotting_config=plotting_config, + nodestats_config=nodestats_config, + grainstats_df=grainstats_df, + ) + + # Ordered Tracing + ordered_tracing, grainstats_df, molstats_df = run_ordered_tracing( + image=topostats_object["image_flattened"], + disordered_tracing_data=topostats_object["disordered_traces"], + nodestats_data=nodestats, + filename=topostats_object["filename"], + basename=topostats_object["img_path"], + core_out_path=core_out_path, + tracing_out_path=tracing_out_path, + ordered_tracing_config=ordered_tracing_config, + plotting_config=plotting_config, + grainstats_df=grainstats_df, + ) + topostats_object["ordered_traces"] = ordered_tracing + topostats_object["nodestats"] = nodestats # looks weird but ordered adds an extra field + + # splining + splined_data, grainstats_df, molstats_df = run_splining( + image=topostats_object["image_flattened"], + ordered_tracing_data=topostats_object["ordered_traces"], + pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], + filename=topostats_object["filename"], + core_out_path=core_out_path, + plotting_config=plotting_config, + splining_config=splining_config, + grainstats_df=grainstats_df, + molstats_df=molstats_df, + ) + + # Add grain trace data to topostats object + topostats_object["splining"] = splined_data else: - results_df = create_empty_dataframe() + grainstats_df = create_empty_dataframe() + molstats_df = create_empty_dataframe() + disordered_tracing_stats = create_empty_dataframe() height_profiles = {} # Get image statistics @@ -677,7 +1104,7 @@ def process_scan( image_stats = image_statistics( image=image_for_image_stats, filename=topostats_object["filename"], - results_df=results_df, + results_df=grainstats_df, pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], ) @@ -686,10 +1113,25 @@ def process_scan( output_dir=core_out_path, filename=str(topostats_object["filename"]), topostats_object=topostats_object ) - return topostats_object["img_path"], results_df, image_stats, height_profiles + return ( + topostats_object["img_path"], + grainstats_df, + height_profiles, + image_stats, + disordered_tracing_stats, + molstats_df, + ) -def check_run_steps(filter_run: bool, grains_run: bool, grainstats_run: bool, dnatracing_run: bool) -> None: +def check_run_steps( # noqa: C901 + filter_run: bool, + grains_run: bool, + grainstats_run: bool, + disordered_tracing_run: bool, + nodestats_run: bool, + ordered_tracing_run: bool, + splining_run: bool, +) -> None: """ Check options for running steps (Filter, Grain, Grainstats and DNA tracing) are logically consistent. @@ -703,16 +1145,68 @@ def check_run_steps(filter_run: bool, grains_run: bool, grainstats_run: bool, dn Flag for running Grains. grainstats_run : bool Flag for running GrainStats. - dnatracing_run : bool + disordered_tracing_run : bool + Flag for running Disordered Tracing. + nodestats_run : bool + Flag for running NodeStats. + ordered_tracing_run : bool + Flag for running Ordered Tracing. + splining_run : bool Flag for running DNA Tracing. """ - if dnatracing_run: + LOGGER.debug(f"{filter_run=}") + LOGGER.debug(f"{grains_run=}") + LOGGER.debug(f"{grainstats_run=}") + LOGGER.debug(f"{disordered_tracing_run=}") + LOGGER.debug(f"{nodestats_run=}") + LOGGER.debug(f"{ordered_tracing_run=}") + LOGGER.debug(f"{splining_run=}") + if splining_run: + if ordered_tracing_run is False: + LOGGER.error("Splining enabled but Ordered Tracing disabled. Please check your configuration file.") + if nodestats_run is False: + LOGGER.error("Splining enabled but NodeStats disabled. Tracing will use the 'old' method.") + if disordered_tracing_run is False: + LOGGER.error("Splining enabled but Disordered Tracing disabled. Please check your configuration file.") + elif grainstats_run is False: + LOGGER.error("Splining enabled but Grainstats disabled. Please check your configuration file.") + elif grains_run is False: + LOGGER.error("Splining enabled but Grains disabled. Please check your configuration file.") + elif filter_run is False: + LOGGER.error("Splining enabled but Filters disabled. Please check your configuration file.") + else: + LOGGER.info("Configuration run options are consistent, processing can proceed.") + elif ordered_tracing_run: + if disordered_tracing_run is False: + LOGGER.error( + "Ordered Tracing enabled but Disordered Tracing disabled. Please check your configuration file." + ) + elif grainstats_run is False: + LOGGER.error("NodeStats enabled but Grainstats disabled. Please check your configuration file.") + elif grains_run is False: + LOGGER.error("NodeStats enabled but Grains disabled. Please check your configuration file.") + elif filter_run is False: + LOGGER.error("NodeStats enabled but Filters disabled. Please check your configuration file.") + else: + LOGGER.info("Configuration run options are consistent, processing can proceed.") + elif nodestats_run: + if disordered_tracing_run is False: + LOGGER.error("NodeStats enabled but Disordered Tracing disabled. Please check your configuration file.") + elif grainstats_run is False: + LOGGER.error("NodeStats enabled but Grainstats disabled. Please check your configuration file.") + elif grains_run is False: + LOGGER.error("NodeStats enabled but Grains disabled. Please check your configuration file.") + elif filter_run is False: + LOGGER.error("NodeStats enabled but Filters disabled. Please check your configuration file.") + else: + LOGGER.info("Configuration run options are consistent, processing can proceed.") + elif disordered_tracing_run: if grainstats_run is False: - LOGGER.error("DNA tracing enabled but Grainstats disabled. Please check your configuration file.") + LOGGER.error("Disordered Tracing enabled but Grainstats disabled. Please check your configuration file.") elif grains_run is False: - LOGGER.error("DNA tracing enabled but Grains disabled. Please check your configuration file.") + LOGGER.error("Disordered Tracing enabled but Grains disabled. Please check your configuration file.") elif filter_run is False: - LOGGER.error("DNA tracing enabled but Filters disabled. Please check your configuration file.") + LOGGER.error("Disordered Tracing enabled but Filters disabled. Please check your configuration file.") else: LOGGER.info("Configuration run options are consistent, processing can proceed.") elif grainstats_run: diff --git a/topostats/run_topostats.py b/topostats/run_topostats.py index 1d25998c03..ba59276906 100644 --- a/topostats/run_topostats.py +++ b/topostats/run_topostats.py @@ -83,7 +83,10 @@ def run_topostats(args: None = None) -> None: # noqa: C901 filter_run=config["filter"]["run"], grains_run=config["grains"]["run"], grainstats_run=config["grainstats"]["run"], - dnatracing_run=config["dnatracing"]["run"], + disordered_tracing_run=config["disordered_tracing"]["run"], + nodestats_run=config["nodestats"]["run"], + ordered_tracing_run=config["ordered_tracing"]["run"], + splining_run=config["splining"]["run"], ) # Ensures each image has all plotting options which are passed as **kwargs config["plotting"] = update_plotting_config(config["plotting"]) @@ -109,7 +112,10 @@ def run_topostats(args: None = None) -> None: # noqa: C901 filter_config=config["filter"], grains_config=config["grains"], grainstats_config=config["grainstats"], - dnatracing_config=config["dnatracing"], + disordered_tracing_config=config["disordered_tracing"], + nodestats_config=config["nodestats"], + ordered_tracing_config=config["ordered_tracing"], + splining_config=config["splining"], plotting_config=config["plotting"], output_dir=config["output_dir"], ) @@ -124,16 +130,27 @@ def run_topostats(args: None = None) -> None: # noqa: C901 with Pool(processes=config["cores"]) as pool: results = defaultdict() image_stats_all = defaultdict() + mols_results = defaultdict() + disordered_trace_results = defaultdict() height_profile_all = defaultdict() with tqdm( total=len(img_files), desc=f"Processing images from {config['base_dir']}, results are under {config['output_dir']}", ) as pbar: - for img, result, individual_image_stats_df, height_profiles in pool.imap_unordered( + for ( + img, + result, + height_profiles, + individual_image_stats_df, + disordered_trace_result, + mols_result, + ) in pool.imap_unordered( processing_function, scan_data_dict.values(), ): results[str(img)] = result + disordered_trace_results[str(img)] = disordered_trace_result + mols_results[str(img)] = mols_result pbar.update() # Add the dataframe to the results dict @@ -141,6 +158,7 @@ def run_topostats(args: None = None) -> None: # noqa: C901 # Combine all height profiles height_profile_all[str(img)] = height_profiles + # Display completion message for the image LOGGER.info(f"[{img.name}] Processing completed.") @@ -156,6 +174,17 @@ def run_topostats(args: None = None) -> None: # noqa: C901 LOGGER.error("No grains found in any images, consider adjusting your thresholds.") LOGGER.error(error) + try: + disordered_trace_results = pd.concat(disordered_trace_results.values()) + except ValueError as error: + LOGGER.error("No disordered traces found in any images, consider adjusting disordered tracing parameters.") + LOGGER.error(error) + + try: + mols_results = pd.concat(mols_results.values()) + except ValueError as error: + LOGGER.error("No mols found in any images, consider adjusting ordered tracing / splining parameters.") + LOGGER.error(error) # If requested save height profiles if config["grainstats"]["extract_height_profile"]: LOGGER.info(f"Saving all height profiles to {config['output_dir']}/height_profiles.json") @@ -205,11 +234,11 @@ def run_topostats(args: None = None) -> None: # noqa: C901 else: LOGGER.warning( "There are no results to plot, either...\n\n" - "* you have disabled grains/grainstats/dnatracing.\n" + "* you have disabled grains/grainstats etc.\n" "* no grains have been detected across all scans.\n" "* there have been errors.\n\n" "If you are not expecting to detect grains please consider disabling" - "grains/grainstats/dnatracing/plotting/summary_stats. If you are expecting to detect grains" + "grains/grainstats etc/plotting/summary_stats. If you are expecting to detect grains" " please check log-files for further information." ) else: @@ -218,7 +247,7 @@ def run_topostats(args: None = None) -> None: # noqa: C901 # Write statistics to CSV if there is data. if isinstance(results, pd.DataFrame) and not results.isna().values.all(): results.reset_index(inplace=True) - results.set_index(["image", "threshold", "molecule_number"], inplace=True) + results.set_index(["image", "threshold", "grain_number"], inplace=True) results.to_csv(config["output_dir"] / "all_statistics.csv", index=True) save_folder_grainstats(config["output_dir"], config["base_dir"], results) results.reset_index(inplace=True) # So we can access unique image names @@ -226,6 +255,24 @@ def run_topostats(args: None = None) -> None: # noqa: C901 else: images_processed = 0 LOGGER.warning("There are no grainstats or dnatracing statistics to write to CSV.") + + if isinstance(disordered_trace_results, pd.DataFrame) and not disordered_trace_results.isna().values.all(): + disordered_trace_results.reset_index(inplace=True) + disordered_trace_results.set_index(["image", "threshold", "grain_number"], inplace=True) + disordered_trace_results.to_csv(config["output_dir"] / "all_disordered_segment_statistics.csv", index=True) + save_folder_grainstats(config["output_dir"], config["base_dir"], mols_results) + disordered_trace_results.reset_index(inplace=True) # So we can access unique image names + else: + LOGGER.warning("There are no grainstats or disordered tracing statistics to write to CSV.") + + if isinstance(mols_results, pd.DataFrame) and not mols_results.isna().values.all(): + mols_results.reset_index(drop=True, inplace=True) + mols_results.set_index(["image", "threshold", "grain_number"], inplace=True) + mols_results.to_csv(config["output_dir"] / "all_mol_statistics.csv", index=True) + save_folder_grainstats(config["output_dir"], config["base_dir"], mols_results) + mols_results.reset_index(inplace=True) # So we can access unique image names + else: + LOGGER.warning("There are no grainstats or molecule tracing statistics to write to CSV.") # Write config to file config["plotting"].pop("plot_dict") write_yaml(config, output_dir=config["output_dir"]) diff --git a/topostats/summary_config.yaml b/topostats/summary_config.yaml index 34e13965a3..63b3ce858d 100644 --- a/topostats/summary_config.yaml +++ b/topostats/summary_config.yaml @@ -3,7 +3,7 @@ output_dir: ./output/summary_distributions csv_file: ./all_statistics.csv savefig_format: png var_to_label: null # Optional YAML file that maps variable names to labels, uses topostats/var_to_label.yaml if null -molecule_id: molecule_number +molecule_id: grain_number image_id: image # If both hist and kde are True they are plotted on the same graph, if you only want one you MUST set the other to False kde: True diff --git a/topostats/theme.py b/topostats/theme.py index 7a5b1f6f32..298ba3e174 100644 --- a/topostats/theme.py +++ b/topostats/theme.py @@ -59,8 +59,10 @@ def set_cmap(self, name: str) -> None: self.cmap = self.nanoscope() elif name.lower() == "gwyddion": self.cmap = self.gwyddion() - elif name.lower() == "blu": - self.cmap = self.blu() + elif name.lower() == "blue": + self.cmap = self.blue() + elif name.lower() == "blue_purple_green": + self.cmap = self.blue_purple_green() else: # Get one of the matplotlib colormaps self.cmap = mpl.colormaps[name] @@ -162,7 +164,7 @@ def gwyddion() -> LinearSegmentedColormap: return LinearSegmentedColormap.from_list("gwyddion", vals, N=256) @staticmethod - def blu() -> ListedColormap: + def blue() -> ListedColormap: """ Set RGBA colour map of just the colour blue. @@ -171,4 +173,13 @@ def blu() -> ListedColormap: ListedColormap The 'blu' colormap. """ - return ListedColormap([[32 / 256, 226 / 256, 205 / 256]], "blu", N=256) + return ListedColormap([[32 / 256, 226 / 256, 205 / 256]], "blue", N=256) + + @staticmethod + def blue_purple_green(): + """RGBA colour map of just the colour blue.""" + return ListedColormap( + [[0 / 256, 157 / 256, 229 / 256], [255 / 256, 100 / 256, 225 / 256], [0 / 256, 1, 139 / 256]], + "blue_purple_green", + N=3, + ) diff --git a/topostats/tracing/disordered_tracing.py b/topostats/tracing/disordered_tracing.py new file mode 100644 index 0000000000..51275a68aa --- /dev/null +++ b/topostats/tracing/disordered_tracing.py @@ -0,0 +1,793 @@ +"""Generates disordered traces (pruned skeletons) and metrics.""" + +from __future__ import annotations + +import logging +import warnings + +import numpy as np +import numpy.typing as npt +import pandas as pd +import skan +import skimage.measure as skimage_measure +from scipy import ndimage +from skimage import filters +from skimage.morphology import label + +from topostats.logs.logs import LOGGER_NAME +from topostats.tracing.pruning import prune_skeleton +from topostats.tracing.skeletonize import getSkeleton +from topostats.utils import convolve_skeleton + +LOGGER = logging.getLogger(LOGGER_NAME) + +# pylint: disable=too-many-positional-arguments + + +class disorderedTrace: # pylint: disable=too-many-instance-attributes + """ + Calculate disordered traces for a DNA molecule and calculates statistics from those traces. + + Parameters + ---------- + image : npt.NDArray + Cropped image, typically padded beyond the bounding box. + mask : npt.NDArray + Labelled mask for the grain, typically padded beyond the bounding box. + filename : str + Filename being processed. + pixel_to_nm_scaling : float + Pixel to nm scaling. + min_skeleton_size : int + Minimum skeleton size below which tracing statistics are not calculated. + mask_smoothing_params : dict + Dictionary of parameters to smooth the grain mask for better quality skeletonisation results. Contains + a gaussian 'sigma' and number of dilation iterations. + skeletonisation_params : dict + Skeletonisation Parameters. Method of skeletonisation to use 'topostats' is the original TopoStats + method. Three methods from scikit-image are available 'zhang', 'lee' and 'thin'. + pruning_params : dict + Dictionary of pruning parameters. Contains 'method', 'max_length', 'height_threshold', 'method_values' and + 'method_outlier'. + n_grain : int + Grain number being processed (only used in logging). + """ + + def __init__( # pylint: disable=too-many-arguments + self, + image: npt.NDArray, + mask: npt.NDArray, + filename: str, + pixel_to_nm_scaling: float, + min_skeleton_size: int = 10, + mask_smoothing_params: dict | None = None, + skeletonisation_params: dict | None = None, + pruning_params: dict | None = None, + n_grain: int = None, + ): + """ + Calculate disordered traces for a DNA molecule and calculates statistics from those traces. + + Parameters + ---------- + image : npt.NDArray + Cropped image, typically padded beyond the bounding box. + mask : npt.NDArray + Labelled mask for the grain, typically padded beyond the bounding box. + filename : str + Filename being processed. + pixel_to_nm_scaling : float + Pixel to nm scaling. + min_skeleton_size : int + Minimum skeleton size below which tracing statistics are not calculated. + mask_smoothing_params : dict + Dictionary of parameters to smooth the grain mask for better quality skeletonisation results. Contains + a gaussian 'sigma' and number of dilation iterations. + skeletonisation_params : dict + Skeletonisation Parameters. Method of skeletonisation to use 'topostats' is the original TopoStats + method. Three methods from scikit-image are available 'zhang', 'lee' and 'thin'. + pruning_params : dict + Dictionary of pruning parameters. Contains 'method', 'max_length', 'height_threshold', 'method_values' and + 'method_outlier'. + n_grain : int + Grain number being processed (only used in logging). + """ + self.image = image + self.mask = mask + self.filename = filename + self.pixel_to_nm_scaling = pixel_to_nm_scaling + self.min_skeleton_size = min_skeleton_size + self.mask_smoothing_params = mask_smoothing_params + self.skeletonisation_params = ( + skeletonisation_params if skeletonisation_params is not None else {"method": "zhang"} + ) + self.pruning_params = pruning_params if pruning_params is not None else {"method": "topostats"} + self.n_grain = n_grain + # Images + self.smoothed_mask = np.zeros_like(image) + self.skeleton = np.zeros_like(image) + self.pruned_skeleton = np.zeros_like(image) + # Trace + self.disordered_trace = None + + # suppresses scipy splining warnings + warnings.filterwarnings("ignore") + + LOGGER.debug(f"[{self.filename}] Performing Disordered Tracing") + + def trace_dna(self): + """Perform the DNA skeletonisation and cleaning pipeline.""" + self.smoothed_mask = self.smooth_mask(self.mask, **self.mask_smoothing_params) + self.skeleton = getSkeleton( + self.image, + self.smoothed_mask, + method=self.skeletonisation_params["method"], + height_bias=self.skeletonisation_params["height_bias"], + ).get_skeleton() + self.pruned_skeleton = prune_skeleton( + self.image, self.skeleton, self.pixel_to_nm_scaling, **self.pruning_params.copy() + ) + self.pruned_skeleton = self.remove_touching_edge(self.pruned_skeleton) + self.disordered_trace = np.argwhere(self.pruned_skeleton == 1) + + if self.disordered_trace is None: + LOGGER.warning(f"[{self.filename}] : Grain {self.n_grain} failed to Skeletonise.") + elif len(self.disordered_trace) < self.min_skeleton_size: + LOGGER.warning(f"[{self.filename}] : Grain {self.n_grain} skeleton < {self.min_skeleton_size}, skipping.") + self.disordered_trace = None + + def re_add_holes( + self, + orig_mask: npt.NDArray, + smoothed_mask: npt.NDArray, + holearea_min_max: tuple[float | int | None] = (2, None), + ) -> npt.NDArray: + """ + Restore holes in masks that were occluded by dilation. + + As Gaussian dilation smoothing methods can close holes in the original mask, this function obtains those holes + (based on the general background being the first due to padding) and adds them back into the smoothed mask. When + paired with ``smooth_mask``, this essentially just smooths the outer edge of the mask. + + Parameters + ---------- + orig_mask : npt.NDArray + Original mask. + smoothed_mask : npt.NDArray + Original mask but with inner and outer edged smoothed. The smoothing operation may have closed up important + holes in the mask. + holearea_min_max : tuple[float | int | None] + Tuple of minimum and maximum hole area (in nanometers) to replace from the original mask into the smoothed + mask. + + Returns + ------- + npt.NDArray + Smoothed mask with holes restored. + """ + # handle none's + if set(holearea_min_max) == {None}: + return smoothed_mask + if None in holearea_min_max: + none_index = holearea_min_max.index(None) + holearea_min_max[none_index] = 0 if none_index == 0 else np.inf + + # obtain px holesizes + holesize_min_px = holearea_min_max[0] / ((self.pixel_to_nm_scaling) ** 2) + holesize_max_px = holearea_min_max[1] / ((self.pixel_to_nm_scaling) ** 2) + + # obtain a hole mask + holes = 1 - orig_mask + holes = label(holes) + hole_sizes = [holes[holes == i].size for i in range(1, holes.max() + 1)] + holes[holes == 1] = 0 # set background to 0 assuming it is the first hole seen (from top left) + + # remove too small or too big holes from mask + for i, hole_size in enumerate(hole_sizes): + if hole_size < holesize_min_px or hole_size > holesize_max_px: # small holes may be fake are left out + holes[holes == i + 1] = 0 + holes[holes != 0] = 1 # set correct sixe holes to 1 + + # replace correct sized holes + return np.where(holes == 1, 0, smoothed_mask) + + @staticmethod + def remove_touching_edge(skeleton: npt.NDArray) -> npt.NDArray: + """ + Remove any skeleton points touching the border (to prevent errors later). + + Parameters + ---------- + skeleton : npt.NDArray + A binary array where touching clusters of 1's become 0's if touching the edge of the array. + + Returns + ------- + npt.NDArray + Skeleton without points touching the border. + """ + for edge in [skeleton[0, :-1], skeleton[:-1, -1], skeleton[-1, 1:], skeleton[1:, 0]]: + uniques = np.unique(edge) + for i in uniques: + skeleton[skeleton == i] = 0 + return skeleton + + def smooth_mask( + self, + grain: npt.NDArray, + dilation_iterations: int = 2, + gaussian_sigma: float | int = 2, + holearea_min_max: tuple[int | float | None] = (0, None), + ) -> npt.NDArray: + """ + Smooth a grain mask based on the lower number of binary pixels added from dilation or gaussian. + + This method ensures gaussian smoothing isn't too aggressive and covers / creates gaps in the mask. + + Parameters + ---------- + grain : npt.NDArray + Numpy array of the grain mask. + dilation_iterations : int + Number of times to dilate the grain to smooth it. Default is 2. + gaussian_sigma : float | None + Gaussian sigma value to smooth the grains after an Otsu threshold. If None, defaults to 2. + holearea_min_max : tuple[float | int | None] + Tuple of minimum and maximum hole area (in nanometers) to replace from the original mask into the smoothed + mask. + + Returns + ------- + npt.NDArray + Numpy array of smmoothed image. + """ + dilation = ndimage.binary_dilation(grain, iterations=dilation_iterations).astype(np.int32) + gauss = filters.gaussian(grain, sigma=gaussian_sigma) + gauss = np.where(gauss > filters.threshold_otsu(gauss) * 1.3, 1, 0) + gauss = gauss.astype(np.int32) + # Add hole to the smooth mask conditional on smallest pixel difference for dilation or the Gaussian smoothing. + if dilation.sum() > gauss.sum(): + LOGGER.debug(f"[{self.filename}] : smoothing done by gaussian {gaussian_sigma}") + return self.re_add_holes(grain, gauss, holearea_min_max) + LOGGER.debug(f"[{self.filename}] : smoothing done by dilation {dilation_iterations}") + return self.re_add_holes(grain, dilation, holearea_min_max) + + +def trace_image_disordered( # pylint: disable=too-many-arguments,too-many-locals + image: npt.NDArray, + grains_mask: npt.NDArray, + filename: str, + pixel_to_nm_scaling: float, + min_skeleton_size: int, + mask_smoothing_params: dict, + skeletonisation_params: dict, + pruning_params: dict, + pad_width: int = 1, +) -> dict: + """ + Processor function for tracing image. + + Parameters + ---------- + image : npt.NDArray + Full image as Numpy Array. + grains_mask : npt.NDArray + Full image as Grains that are labelled. + filename : str + File being processed. + pixel_to_nm_scaling : float + Pixel to nm scaling. + min_skeleton_size : int + Minimum size of grain in pixels after skeletonisation. + mask_smoothing_params : dict + Dictionary of parameters to smooth the grain mask for better quality skeletonisation results. Contains + a gaussian 'sigma' and number of dilation iterations. + skeletonisation_params : dict + Dictionary of options for skeletonisation, options are 'zhang' (scikit-image) / 'lee' (scikit-image) / 'thin' + (scikitimage) or 'topostats' (original TopoStats method). + pruning_params : dict + Dictionary of options for pruning. + pad_width : int + Padding to the cropped image mask. + + Returns + ------- + tuple[dict, dict] + Binary and integer labeled cropped and full-image masks from skeletonising and pruning the grains in the image. + """ + # Check both arrays are the same shape - should this be a test instead, why should this ever occur? + if image.shape != grains_mask.shape: + raise ValueError(f"Image shape ({image.shape}) and Mask shape ({grains_mask.shape}) should match.") + + cropped_images, cropped_masks, bboxs = prep_arrays(image, grains_mask, pad_width) + n_grains = len(cropped_images) + img_base = np.zeros_like(image) + disordered_trace_crop_data = {} + grainstats_additions = {} + disordered_tracing_stats = pd.DataFrame() + + # want to get each cropped image, use some anchor coords to match them onto the image, + # and compile all the grain images onto a single image + all_images = { + "smoothed_grain": img_base.copy(), + "skeleton": img_base.copy(), + "pruned_skeleton": img_base.copy(), + "branch_indexes": img_base.copy(), + "branch_types": img_base.copy(), + } + + LOGGER.info(f"[{filename}] : Calculating Disordered Tracing statistics for {n_grains} grains...") + + for cropped_image_index, cropped_image in cropped_images.items(): + try: + cropped_mask = cropped_masks[cropped_image_index] + disordered_trace_images = disordered_trace_grain( + cropped_image=cropped_image, + cropped_mask=cropped_mask, + pixel_to_nm_scaling=pixel_to_nm_scaling, + mask_smoothing_params=mask_smoothing_params, + skeletonisation_params=skeletonisation_params, + pruning_params=pruning_params, + filename=filename, + min_skeleton_size=min_skeleton_size, + n_grain=cropped_image_index, + ) + LOGGER.debug(f"[{filename}] : Disordered Traced grain {cropped_image_index + 1} of {n_grains}") + + # obtain segment stats + skan_skeleton = skan.Skeleton( + np.where(disordered_trace_images["pruned_skeleton"] == 1, cropped_image, 0), + spacing=pixel_to_nm_scaling, + ) + skan_df = skan.summarize(skan_skeleton) + skan_df = compile_skan_stats(skan_df, skan_skeleton, cropped_image, filename, cropped_image_index) + disordered_tracing_stats = pd.concat((disordered_tracing_stats, skan_df)) + + # obtain stats + conv_pruned_skeleton = convolve_skeleton(disordered_trace_images["pruned_skeleton"]) + grainstats_additions[cropped_image_index] = { + "image": filename, + "grain_number": cropped_image_index, + "grain_endpoints": np.int64((conv_pruned_skeleton == 2).sum()), + "grain_junctions": np.int64((conv_pruned_skeleton == 3).sum()), + "total_branch_lengths": skan_df["branch_distance"].sum() * 1e-9, + } + + # remap the cropped images back onto the original + for image_name, full_image in all_images.items(): + crop = disordered_trace_images[image_name] + bbox = bboxs[cropped_image_index] + full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width] + disordered_trace_crop_data[f"grain_{cropped_image_index}"] = disordered_trace_images + disordered_trace_crop_data[f"grain_{cropped_image_index}"]["bbox"] = bboxs[cropped_image_index] + + # when skel too small, pruned to 0's, skan -> ValueError -> skipped + except Exception as e: # pylint: disable=broad-exception-caught + LOGGER.error( # pylint: disable=logging-not-lazy + f"[{filename}] : Disordered tracing of grain" + + f"{cropped_image_index} failed. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + + # convert stats dict to dataframe + grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index") + + return disordered_trace_crop_data, grainstats_additions_df, all_images, disordered_tracing_stats + + +def compile_skan_stats( + skan_df: pd.DataFrame, skan_skeleton: skan.Skeleton, image: npt.NDArray, filename: str, grain_number: int +) -> pd.DataFrame: + """ + Obtain and add more stats to the resultant Skan dataframe. + + Parameters + ---------- + skan_df : pd.DataFrame + The statistics DataFrame produced by Skan's `summarize` function. + skan_skeleton : skan.Skeleton + The graphical representation of the skeleton produced by Skan. + image : npt.NDArray + The image the skeleton was produced from. + filename : str + Name of the file being processed. + grain_number : int + The number of the grain being processed. + + Returns + ------- + pd.DataFrame + A dataframe containing the filename, grain_number, branch-distance, branch-type, connected_segments, + mean-pixel-value, stdev-pixel-value, min-value, median-value, and mid-value. + """ + skan_df["image"] = filename + skan_df["grain_number"] = grain_number + skan_df["connected_segments"] = skan_df.apply(find_connections, axis=1, skan_df=skan_df) + skan_df["min_value"] = skan_df.apply(lambda x: segment_heights(x, skan_skeleton, image).min(), axis=1) + skan_df["median_value"] = skan_df.apply(lambda x: np.median(segment_heights(x, skan_skeleton, image)), axis=1) + skan_df["middle_value"] = skan_df.apply(segment_middles, skan_skeleton=skan_skeleton, image=image, axis=1) + + skan_df = skan_df.rename( + columns={ # remove with Skan new release + "branch-distance": "branch_distance", + "branch-type": "branch_type", + "mean-pixel-value": "mean_pixel_value", + "stdev-pixel-value": "stdev_pixel_value", + } + ) + + # remove unused skan columns + return skan_df[ + [ + "image", + "grain_number", + "branch_distance", + "branch_type", + "connected_segments", + "mean_pixel_value", + "stdev_pixel_value", + "min_value", + "median_value", + "middle_value", + ] + ] + + +def segment_heights(row: pd.Series, skan_skeleton: skan.Skeleton, image: npt.NDArray) -> npt.NDArray: + """ + Obtain an ordered list of heights from the skan defined skeleton segment. + + Parameters + ---------- + row : pd.Series + A row from the Skan summarize dataframe. + skan_skeleton : skan.Skeleton + The graphical representation of the skeleton produced by Skan. + image : npt.NDArray + The image the skeleton was produced from. + + Returns + ------- + npt.NDArray + Heights along the segment, naturally ordered by Skan. + """ + coords = skan_skeleton.path_coordinates(row.name) + return image[coords[:, 0], coords[:, 1]] + + +def segment_middles(row: pd.Series, skan_skeleton: skan.csr.Skeleton, image: npt.NDArray) -> float: + """ + Obtain the pixel value in the middle of the ordered segment. + + Parameters + ---------- + row : pd.Series + A row from the Skan summarize dataframe. + skan_skeleton : skan.csr.Skeleton + The graphical representation of the skeleton produced by Skan. + image : npt.NDArray + The image the skeleton was produced from. + + Returns + ------- + float + The single or mean pixel value corresponding to the middle coordinate(s) of the segment. + """ + heights = segment_heights(row, skan_skeleton, image) + middle_idx, middle_remainder = (len(heights) + 1) // 2 - 1, (len(heights) + 1) % 2 + return heights[[middle_idx, middle_idx + middle_remainder]].mean() + + +def find_connections(row: pd.Series, skan_df: pd.DataFrame) -> str: + """ + Compile the neighbouring branch indexes of the row. + + Parameters + ---------- + row : pd.Series + A row from the Skan summarize dataframe. + skan_df : pd.DataFrame + The statistics DataFrame produced by Skan's `summarize` function. + + Returns + ------- + str + A string representation of a list of matching row indices where the node src and dst + columns match that of the rows. + String is needed for csv compatibility since csvs can't hold lists. + """ + connections = skan_df[ + (skan_df["node-id-src"] == row["node-id-src"]) + | (skan_df["node-id-dst"] == row["node-id-dst"]) + | (skan_df["node-id-src"] == row["node-id-dst"]) + | (skan_df["node-id-dst"] == row["node-id-src"]) + ].index.tolist() + + # Remove the index of the current row itself from the list of connections + connections.remove(row.name) + return str(connections) + + +def prep_arrays( + image: npt.NDArray, labelled_grains_mask: npt.NDArray, pad_width: int +) -> tuple[dict[int, npt.NDArray], dict[int, npt.NDArray]]: + """ + Take an image and labelled mask and crops individual grains and original heights to a list. + + A second padding is made after cropping to ensure for "edge cases" where grains are close to bounding box edges that + they are traced correctly. This is accounted for when aligning traces to the whole image mask. + + Parameters + ---------- + image : npt.NDArray + Gaussian filtered image. Typically filtered_image.images["gaussian_filtered"]. + labelled_grains_mask : npt.NDArray + 2D Numpy array of labelled grain masks, with each mask being comprised solely of unique integer (not + zero). Typically this will be output from 'grains.directions[["labelled_region_02]'. + pad_width : int + Cells by which to pad cropped regions by. + + Returns + ------- + Tuple + Returns a tuple of two dictionaries, each consisting of cropped arrays. + """ + # Get bounding boxes for each grain + region_properties = skimage_measure.regionprops(labelled_grains_mask) + # Subset image and grains then zip them up + cropped_images = {} + cropped_masks = {} + + # for index, grain in enumerate(region_properties): + # cropped_image, cropped_bbox = crop_array(image, grain.bbox, pad_width) + + cropped_images = {index: crop_array(image, grain.bbox, pad_width) for index, grain in enumerate(region_properties)} + cropped_images = {index: np.pad(grain, pad_width=pad_width) for index, grain in cropped_images.items()} + cropped_masks = { + index: crop_array(labelled_grains_mask, grain.bbox, pad_width) for index, grain in enumerate(region_properties) + } + cropped_masks = {index: np.pad(grain, pad_width=pad_width) for index, grain in cropped_masks.items()} + cropped_masks = {index: np.where(grain == (index + 1), 1, 0) for index, grain in cropped_masks.items()} + # Get BBOX coords to remap crops to images + bboxs = [pad_bounding_box(image.shape, list(grain.bbox), pad_width=pad_width) for grain in region_properties] + + return (cropped_images, cropped_masks, bboxs) + + +def grain_anchor(array_shape: tuple, bounding_box: list, pad_width: int) -> list: + """ + Extract anchor (min_row, min_col) from labelled regions and align individual traces over the original image. + + Parameters + ---------- + array_shape : tuple + Shape of original array. + bounding_box : list + A list of region properties returned by 'skimage.measure.regionprops()'. + pad_width : int + Padding for image. + + Returns + ------- + list(Tuple) + A list of tuples of the min_row, min_col of each bounding box. + """ + bounding_coordinates = pad_bounding_box(array_shape, bounding_box, pad_width) + return (bounding_coordinates[0], bounding_coordinates[1]) + + +def disordered_trace_grain( # pylint: disable=too-many-arguments + cropped_image: npt.NDArray, + cropped_mask: npt.NDArray, + pixel_to_nm_scaling: float, + mask_smoothing_params: dict, + skeletonisation_params: dict, + pruning_params: dict, + filename: str = None, + min_skeleton_size: int = 10, + n_grain: int = None, +) -> dict: + """ + Trace an individual grain. + + Tracing involves multiple steps... + + 1. Skeletonisation + 2. Pruning of side branches (artefacts from skeletonisation). + 3. Ordering of the skeleton. + + Parameters + ---------- + cropped_image : npt.NDArray + Cropped array from the original image defined as the bounding box from the labelled mask. + cropped_mask : npt.NDArray + Cropped array from the labelled image defined as the bounding box from the labelled mask. This should have been + converted to a binary mask. + pixel_to_nm_scaling : float + Pixel to nm scaling. + mask_smoothing_params : dict + Dictionary of parameters to smooth the grain mask for better quality skeletonisation results. Contains + a gaussian 'sigma' and number of dilation iterations. + skeletonisation_params : dict + Dictionary of skeletonisation parameters, options are 'zhang' (scikit-image) / 'lee' (scikit-image) / 'thin' + (scikitimage) or 'topostats' (original TopoStats method). + pruning_params : dict + Dictionary of pruning parameters. + filename : str + File being processed. + min_skeleton_size : int + Minimum size of grain in pixels after skeletonisation. + n_grain : int + Grain number being processed. + + Returns + ------- + dict + Dictionary of the contour length, whether the image is circular or linear, the end-to-end distance and an array + of coordinates. + """ + disorderedtrace = disorderedTrace( + image=cropped_image, + mask=cropped_mask, + filename=filename, + pixel_to_nm_scaling=pixel_to_nm_scaling, + min_skeleton_size=min_skeleton_size, + mask_smoothing_params=mask_smoothing_params, + skeletonisation_params=skeletonisation_params, + pruning_params=pruning_params, + n_grain=n_grain, + ) + + disorderedtrace.trace_dna() + + return { + "original_image": cropped_image, + "original_grain": cropped_mask, + "smoothed_grain": disorderedtrace.smoothed_mask, + "skeleton": disorderedtrace.skeleton, + "pruned_skeleton": disorderedtrace.pruned_skeleton, + "branch_types": get_skan_image( + cropped_image, disorderedtrace.pruned_skeleton, "branch-type" + ), # change with Skan new release + "branch_indexes": get_skan_image( + cropped_image, disorderedtrace.pruned_skeleton, "node-id-src" + ), # change with Skan new release + } + + +def get_skan_image(original_image: npt.NDArray, pruned_skeleton: npt.NDArray, skan_column: str) -> npt.NDArray: + """ + Label each branch with it's Skan branch type label. + + Branch types (+1 compared to Skan docs) are defined as: + 1 = Endpoint-to-endpoint (isolated branch) + 2 = Junction-to-endpoint + 3 = Junction-to-junction + 4 = Isolated cycle + + Parameters + ---------- + original_image : npt.NDArray + Height image from which the pruned skeleton is derived from. + pruned_skeleton : npt.NDArray + Single pixel thick skeleton mask. + skan_column : str + A column from Skan's summarize function to colour the branch segments with. + + Returns + ------- + npt.NDArray + 2D array where the background is 0, and skeleton branches label as their Skan branch type. + """ + branch_field_image = np.zeros_like(original_image) + skeleton_image = np.where(pruned_skeleton == 1, original_image, 0) + + try: + skan_skeleton = skan.Skeleton(skeleton_image, spacing=1e-9, value_is_height=True) + res = skan.summarize(skan_skeleton) + for i, branch_field in enumerate(res[skan_column]): + path_coords = skan_skeleton.path_coordinates(i) + if skan_column == "node-id-src": + branch_field = i + branch_field_image[path_coords[:, 0], path_coords[:, 1]] = branch_field + 1 + except ValueError: # when no skeleton to skan + LOGGER.warning("Skeleton has been pruned out of existence.") + + return branch_field_image + + +def crop_array(array: npt.NDArray, bounding_box: tuple, pad_width: int = 0) -> npt.NDArray: + """ + Crop an array. + + Ideally we pad the array that is being cropped so that we have heights outside of the grains bounding box. However, + in some cases, if a grain is near the edge of the image scan this results in requesting indexes outside of the + existing image. In which case we get as much of the image padded as possible. + + Parameters + ---------- + array : npt.NDArray + 2D Numpy array to be cropped. + bounding_box : Tuple + Tuple of coordinates to crop, should be of form (min_row, min_col, max_row, max_col). + pad_width : int + Padding to apply to bounding box. + + Returns + ------- + npt.NDArray() + Cropped array. + """ + bounding_box = list(bounding_box) + bounding_box = pad_bounding_box(array.shape, bounding_box, pad_width) + return array[ + bounding_box[0] : bounding_box[2], + bounding_box[1] : bounding_box[3], + ] + + +def pad_bounding_box(array_shape: tuple, bounding_box: list, pad_width: int) -> list: + """ + Pad coordinates, if they extend beyond image boundaries stop at boundary. + + Parameters + ---------- + array_shape : tuple + Shape of original image. + bounding_box : list + List of coordinates 'min_row', 'min_col', 'max_row', 'max_col'. + pad_width : int + Cells to pad arrays by. + + Returns + ------- + list + List of padded coordinates. + """ + # Top Row : Make this the first column if too close + bounding_box[0] = 0 if bounding_box[0] - pad_width < 0 else bounding_box[0] - pad_width + # Left Column : Make this the first column if too close + bounding_box[1] = 0 if bounding_box[1] - pad_width < 0 else bounding_box[1] - pad_width + # Bottom Row : Make this the last row if too close + bounding_box[2] = array_shape[0] if bounding_box[2] + pad_width > array_shape[0] else bounding_box[2] + pad_width + # Right Column : Make this the last column if too close + bounding_box[3] = array_shape[1] if bounding_box[3] + pad_width > array_shape[1] else bounding_box[3] + pad_width + return bounding_box + + +# 2023-06-09 - Code that runs dnatracing in parallel across grains, left deliberately for use when we remodularise the +# entry-points/workflow. Will require that the gaussian filtered array is saved and passed in along with +# the labelled regions. @ns-rse +# +# +# if __name__ == "__main__": +# cropped_images, cropped_masks = prep_arrays(image, grains_mask, pad_width) +# n_grains = len(cropped_images) +# LOGGER.info(f"[{filename}] : Calculating statistics for {n_grains} grains.") +# # Process in parallel +# with Pool(processes=cores) as pool: +# results = {} +# with tqdm(total=n_grains) as pbar: +# x = 0 +# for result in pool.starmap( +# trace_grain, +# zip( +# cropped_images, +# cropped_masks, +# repeat(pixel_to_nm_scaling), +# repeat(filename), +# repeat(min_skeleton_size), +# repeat(skeletonisation_method), +# ), +# ): +# LOGGER.info(f"[{filename}] : Traced grain {x + 1} of {n_grains}") +# results[x] = result +# x += 1 +# pbar.update() +# try: +# results = pd.DataFrame.from_dict(results, orient="index") +# results.index.name = "molecule_number" +# except ValueError as error: +# LOGGER.error("No grains found in any images, consider adjusting your thresholds.") +# LOGGER.error(error) +# return results diff --git a/topostats/tracing/dnatracing.py b/topostats/tracing/dnatracing.py deleted file mode 100644 index 3c2ed2015b..0000000000 --- a/topostats/tracing/dnatracing.py +++ /dev/null @@ -1,1345 +0,0 @@ -"""Perform DNA Tracing.""" - -from __future__ import annotations - -from collections import OrderedDict -from functools import partial -from itertools import repeat -import logging -import math -from multiprocessing import Pool -import os -from pathlib import Path -from typing import Dict, List, Union, Tuple -import warnings - -import numpy as np -import numpy.typing as npt -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -from scipy import ndimage, spatial, interpolate as interp -from skimage import morphology -from skimage.filters import gaussian -import skimage.measure as skimage_measure -from tqdm import tqdm - -from topostats.logs.logs import LOGGER_NAME -from topostats.tracing.skeletonize import get_skeleton -from topostats.tracing.tracingfuncs import genTracingFuncs, getSkeleton, reorderTrace -from topostats.utils import bound_padded_coordinates_to_image - -LOGGER = logging.getLogger(LOGGER_NAME) - - -class dnaTrace: - """ - Calculates traces for a DNA molecule and calculates statistics from those traces. - - 2023-06-09 : This class has undergone some refactoring so that it works with a single grain. The `trace_grain()` - helper function runs the class and returns the expected statistics whilst the `trace_image()` function handles - processing all detected grains within an image. The original methods of skeletonisation are available along with - additional methods from scikit-image. - - Some bugs have been identified and corrected see commits for further details... - - 236750b2 - 2a79c4ff - - Parameters - ---------- - image : npt.NDArray - Cropped image, typically padded beyond the bounding box. - grain : npt.NDArray - Labelled mask for the grain, typically padded beyond the bounding box. - filename : str - Filename being processed. - pixel_to_nm_scaling : float - Pixel to nm scaling. - min_skeleton_size : int - Minimum skeleton size below which tracing statistics are not calculated. - convert_nm_to_m : bool - Convert nanometers to metres. - skeletonisation_method : str - Method of skeletonisation to use 'topostats' is the original TopoStats method. Three methods from - scikit-image are available 'zhang', 'lee' and 'thin'. - n_grain : int - Grain number being processed (only used in logging). - spline_step_size : float - Step size for spline evaluation in metres. - spline_linear_smoothing : float - Smoothness of linear splines. - spline_circular_smoothing : float - Smoothness of circular splines. - spline_quiet : bool - Suppresses scipy splining warnings. - spline_degree : int - Degree of the spline. - """ - - def __init__( - self, - image: npt.NDArray, - grain: npt.NDArray, - filename: str, - pixel_to_nm_scaling: float, - min_skeleton_size: int = 10, - convert_nm_to_m: bool = True, - skeletonisation_method: str = "topostats", - n_grain: int = None, - spline_step_size: float = 7e-9, - spline_linear_smoothing: float = 5.0, - spline_circular_smoothing: float = 0.0, - spline_quiet: bool = True, - spline_degree: int = 3, - ): - """ - Initialise the class. - - Parameters - ---------- - image : npt.NDArray - Cropped image, typically padded beyond the bounding box. - grain : npt.NDArray - Labelled mask for the grain, typically padded beyond the bounding box. - filename : str - Filename being processed. - pixel_to_nm_scaling : float - Pixel to nm scaling. - min_skeleton_size : int - Minimum skeleton size below which tracing statistics are not calculated. - convert_nm_to_m : bool - Convert nanometers to metres. - skeletonisation_method : str - Method of skeletonisation to use 'topostats' is the original TopoStats method. Three methods from - scikit-image are available 'zhang', 'lee' and 'thin'. - n_grain : int - Grain number being processed (only used in logging). - spline_step_size : float - Step size for spline evaluation in metres. - spline_linear_smoothing : float - Smoothness of linear splines. - spline_circular_smoothing : float - Smoothness of circular splines. - spline_quiet : bool - Suppresses scipy splining warnings. - spline_degree : int - Degree of the spline. - """ - self.image = image * 1e-9 if convert_nm_to_m else image - self.grain = grain - self.filename = filename - self.pixel_to_nm_scaling = pixel_to_nm_scaling * 1e-9 if convert_nm_to_m else pixel_to_nm_scaling - self.min_skeleton_size = min_skeleton_size - self.skeletonisation_method = skeletonisation_method - self.n_grain = n_grain - self.number_of_rows = self.image.shape[0] - self.number_of_columns = self.image.shape[1] - self.sigma = 0.7 / (self.pixel_to_nm_scaling * 1e9) - - self.gauss_image = None - self.grain = grain - self.disordered_trace = None - self.ordered_trace = None - self.fitted_trace = None - self.splined_trace = None - self.contour_length = np.nan - self.end_to_end_distance = np.nan - self.mol_is_circular = np.nan - self.curvature = np.nan - - # Splining parameters - self.spline_step_size: float = spline_step_size - self.spline_linear_smoothing: float = spline_linear_smoothing - self.spline_circular_smoothing: float = spline_circular_smoothing - self.spline_quiet: bool = spline_quiet - self.spline_degree: int = spline_degree - - self.neighbours = 5 # The number of neighbours used for the curvature measurement - - self.ordered_trace_heights = None - self.ordered_trace_cumulative_distances = None - - # suppresses scipy splining warnings - warnings.filterwarnings("ignore") - - LOGGER.debug(f"[{self.filename}] Performing DNA Tracing") - - def trace_dna(self): - """Perform DNA tracing.""" - self.gaussian_filter() - self.get_disordered_trace() - if self.disordered_trace is None: - LOGGER.info(f"[{self.filename}] : Grain failed to Skeletonise") - elif len(self.disordered_trace) >= self.min_skeleton_size: - self.linear_or_circular(self.disordered_trace) - self.get_ordered_traces() - self.get_ordered_trace_heights() - self.get_ordered_trace_cumulative_distances() - self.linear_or_circular(self.ordered_trace) - self.get_fitted_traces() - self.get_splined_traces() - # self.find_curvature() - # self.saveCurvature() - self.measure_contour_length() - self.measure_end_to_end_distance() - else: - LOGGER.info(f"[{self.filename}] [{self.n_grain}] : Grain skeleton pixels < {self.min_skeleton_size}") - - def gaussian_filter(self, **kwargs) -> np.array: - """ - Apply Gaussian filter. - - Parameters - ---------- - **kwargs - Arguments passed to 'skimage.filter.gaussian(**kwargs)'. - """ - self.gauss_image = gaussian(self.image, sigma=self.sigma, **kwargs) - LOGGER.info(f"[{self.filename}] [{self.n_grain}] : Gaussian filter applied.") - - def get_ordered_trace_heights(self) -> None: - """ - Derive the pixel heights from the ordered trace `self.ordered_trace` list. - - Gets the heights of each pixel in the ordered trace from the gaussian filtered image. The pixel coordinates - for the ordered trace are stored in the ordered trace list as part of the class. - """ - self.ordered_trace_heights = np.array(self.gauss_image[self.ordered_trace[:, 0], self.ordered_trace[:, 1]]) - - def get_ordered_trace_cumulative_distances(self) -> None: - """Calculate the cumulative distances of each pixel in the `self.ordered_trace` list.""" - - # Get the cumulative distances of each pixel in the ordered trace from the gaussian filtered image - # the pixel coordinates are stored in the ordered trace list. - self.ordered_trace_cumulative_distances = self.coord_dist( - coordinates=self.ordered_trace, px_to_nm=self.pixel_to_nm_scaling - ) - - @staticmethod - def coord_dist(coordinates: npt.NDArray, px_to_nm: float) -> npt.NDArray: - """ - Calculate the cumulative real distances between each pixel in a trace. - - Take a Nx2 numpy array of (grid adjacent) coordinates and produce a list of cumulative distances in - nanometres, travelling from pixel to pixel. 1D example: coordinates: [0, 0], [0, 1], [1, 1], [2, 2] cumulative - distances: [0, 1, 2, 3.4142]. Counts diagonal connections as 1.4142 distance. Converts distances from - pixels to nanometres using px_to_nm scaling factor. - Note that the pixels have to be adjacent. - - Parameters - ---------- - coordinates : npt.NDArray - A Nx2 integer array of coordinates of the pixels of a trace from a binary trace image. - px_to_nm : float - Pixel to nanometre scaling factor to allow for real length measurements of distances rather - than pixels. - - Returns - ------- - npt.NDArray - Numpy array of length N containing the cumulative sum of distances (0 at the first entry, - full molecule length at the last entry). - """ - - # Shift the array by one coordinate so the end is at the start and the second to last is at the end - # this allows for the calculation of the distance between each pixel by subtracting the shifted array - # from the original array - rolled_coords = np.roll(coordinates, 1, axis=0) - - # Calculate the distance between each pixel in the trace - pixel_diffs = coordinates - rolled_coords - pixel_distances = np.linalg.norm(pixel_diffs, axis=1) - - # Set the first distance to zero since we don't want to count the distance from the last pixel to the first - pixel_distances[0] = 0 - - # Calculate the cumulative sum of the distances - cumulative_distances = np.cumsum(pixel_distances) - - # Convert the cumulative distances from pixels to nanometres - cumulative_distances_nm = cumulative_distances * px_to_nm - - return cumulative_distances_nm - - def get_disordered_trace(self) -> None: - """ - Create a skeleton for each of the grains in the image. - - Uses my own skeletonisation function from tracingfuncs module. I (Joe) will eventually get round to editing this - function to try to reduce the branching and to try to better trace from looped molecules. - """ - smoothed_grain = ndimage.binary_dilation(self.grain, iterations=1).astype(self.grain.dtype) - - sigma = 0.01 / (self.pixel_to_nm_scaling * 1e9) - very_smoothed_grain = ndimage.gaussian_filter(smoothed_grain, sigma) - - LOGGER.info(f"[{self.filename}] [{self.n_grain}] : Skeletonising using {self.skeletonisation_method} method.") - try: - if self.skeletonisation_method == "topostats": - dna_skeleton = getSkeleton( - self.gauss_image, - smoothed_grain, - self.number_of_columns, - self.number_of_rows, - self.pixel_to_nm_scaling, - ) - self.disordered_trace = dna_skeleton.output_skeleton - elif self.skeletonisation_method in ["lee", "zhang", "thin"]: - self.disordered_trace = np.argwhere( - get_skeleton(smoothed_grain, method=self.skeletonisation_method) == True - ) - else: - raise ValueError - except IndexError as e: - # Some gwyddion grains touch image border causing IndexError - # These grains are deleted - LOGGER.info(f"[{self.filename}] [{self.n_grain}] : Grain failed to skeletonise.") - # raise e - - def linear_or_circular(self, traces) -> None: - """ - Determine whether molecule is circular or linear based on the local environment of each pixel from the trace. - - This function is sensitive to branches from the skeleton so might need to implement a function to remove them. - - Parameters - ---------- - traces : npt.NDArray - The array of coordinates to be assessed. - """ - - points_with_one_neighbour = 0 - fitted_trace_list = traces.tolist() - - # For loop determines how many neighbours a point has - if only one it is an end - for x, y in fitted_trace_list: - if genTracingFuncs.countNeighbours(x, y, fitted_trace_list) == 1: - points_with_one_neighbour += 1 - else: - pass - - if points_with_one_neighbour == 0: - self.mol_is_circular = True - else: - self.mol_is_circular = False - - def get_ordered_traces(self): - """Order a trace.""" - if self.mol_is_circular: - self.ordered_trace, trace_completed = reorderTrace.circularTrace(self.disordered_trace) - - if not trace_completed: - self.mol_is_circular = False - try: - self.ordered_trace = reorderTrace.linearTrace(self.ordered_trace.tolist()) - except UnboundLocalError: - pass - - elif not self.mol_is_circular: - self.ordered_trace = reorderTrace.linearTrace(self.disordered_trace.tolist()) - - def get_fitted_traces(self): - """Create trace coordinates that are adjusted to lie along the highest points of each traced molecule.""" - - individual_skeleton = self.ordered_trace - # This indexes a 3 nm height profile perpendicular to DNA backbone - # note that this is a hard coded parameter - index_width = int(3e-9 / (self.pixel_to_nm_scaling)) - if index_width < 2: - index_width = 2 - - for coord_num, trace_coordinate in enumerate(individual_skeleton): - height_values = None - - # Ensure that padding will not exceed the image boundaries - trace_coordinate = bound_padded_coordinates_to_image( - coordinates=trace_coordinate, - padding=index_width, - image_shape=(self.number_of_rows, self.number_of_columns), - ) - - # calculate vector to n - 2 coordinate in trace - if self.mol_is_circular: - nearest_point = individual_skeleton[coord_num - 2] - vector = np.subtract(nearest_point, trace_coordinate) - vector_angle = math.degrees(math.atan2(vector[1], vector[0])) - else: - try: - nearest_point = individual_skeleton[coord_num + 2] - except IndexError: - nearest_point = individual_skeleton[coord_num - 2] - vector = np.subtract(nearest_point, trace_coordinate) - vector_angle = math.degrees(math.atan2(vector[1], vector[0])) - - if vector_angle < 0: - vector_angle += 180 - - # if angle is closest to 45 degrees - if 67.5 > vector_angle >= 22.5: - perp_direction = "negative diaganol" - # positive diagonal (change in x and y) - # Take height values at the inverse of the positive diaganol - # (i.e. the negative diaganol) - y_coords = np.arange(trace_coordinate[1] - index_width, trace_coordinate[1] + index_width)[::-1] - x_coords = np.arange(trace_coordinate[0] - index_width, trace_coordinate[0] + index_width) - - # if angle is closest to 135 degrees - elif 157.5 >= vector_angle >= 112.5: - perp_direction = "positive diaganol" - y_coords = np.arange(trace_coordinate[1] - index_width, trace_coordinate[1] + index_width) - x_coords = np.arange(trace_coordinate[0] - index_width, trace_coordinate[0] + index_width) - - # if angle is closest to 90 degrees - if 112.5 > vector_angle >= 67.5: - perp_direction = "horizontal" - x_coords = np.arange(trace_coordinate[0] - index_width, trace_coordinate[0] + index_width) - y_coords = np.full(len(x_coords), trace_coordinate[1]) - - elif 22.5 > vector_angle: # if angle is closest to 0 degrees - perp_direction = "vertical" - y_coords = np.arange(trace_coordinate[1] - index_width, trace_coordinate[1] + index_width) - x_coords = np.full(len(y_coords), trace_coordinate[0]) - - elif vector_angle >= 157.5: # if angle is closest to 180 degrees - perp_direction = "vertical" - y_coords = np.arange(trace_coordinate[1] - index_width, trace_coordinate[1] + index_width) - x_coords = np.full(len(y_coords), trace_coordinate[0]) - - # Use the perp array to index the gaussian filtered image - perp_array = np.column_stack((x_coords, y_coords)) - try: - height_values = self.gauss_image[perp_array[:, 0], perp_array[:, 1]] - except IndexError: - perp_array[:, 0] = np.where( - perp_array[:, 0] > self.gauss_image.shape[0], self.gauss_image.shape[0], perp_array[:, 0] - ) - perp_array[:, 1] = np.where( - perp_array[:, 1] > self.gauss_image.shape[1], self.gauss_image.shape[1], perp_array[:, 1] - ) - height_values = self.gauss_image[perp_array[:, 1], perp_array[:, 0]] - - # Grab x,y coordinates for highest point - # fine_coords = np.column_stack((fine_x_coords, fine_y_coords)) - sorted_array = perp_array[np.argsort(height_values)] - highest_point = sorted_array[-1] - - try: - # could use np.append() here - fitted_coordinate_array = np.vstack((fitted_coordinate_array, highest_point)) - except UnboundLocalError: - fitted_coordinate_array = highest_point - - self.fitted_trace = fitted_coordinate_array - del fitted_coordinate_array # cleaned up by python anyway? - - @staticmethod - # Perhaps we need a module for array functions? - def remove_duplicate_consecutive_tuples(tuple_list: list[tuple | npt.NDArray]) -> list[tuple]: - """ - Remove duplicate consecutive tuples from a list. - - Parameters - ---------- - tuple_list : list[tuple | npt.NDArray] - List of tuples or numpy ndarrays to remove consecutive duplicates from. - - Returns - ------- - list[Tuple] - List of tuples with consecutive duplicates removed. - - Examples - -------- - For the list of tuples [(1, 2), (1, 2), (1, 2), (2, 3), (2, 3), (3, 4)], this function will return - [(1, 2), (2, 3), (3, 4)] - """ - - duplicates_removed = [] - for index, tup in enumerate(tuple_list): - if index == 0 or not np.array_equal(tuple_list[index - 1], tup): - duplicates_removed.append(tup) - return np.array(duplicates_removed) - - def get_splined_traces( - self, - ) -> None: - """ - Get a splined version of the fitted trace - useful for finding the radius of gyration etc. - - This function actually calculates the average of several splines which is important for getting a good fit on - the lower resolution data. - """ - - # Fitted traces are Nx2 numpy arrays of coordinates - # All self references are here for easy turning into static method if wanted, also type hints and short documentation - fitted_trace: np.ndarray = self.fitted_trace # boolean 2d numpy array of fitted traces to spline - step_size_m: float = self.spline_step_size # the step size for the splines to skip pixels in the fitted trace - pixel_to_nm_scaling: float = self.pixel_to_nm_scaling # pixel to nanometre scaling factor for the image - mol_is_circular: bool = self.mol_is_circular # whether or not the molecule is classed as circular - n_grain: int = self.n_grain # the grain index (for logging purposes) - - # Calculate the step size in pixels from the step size in metres. - # Should always be at least 1. - # Note that step_size_m is in m and pixel_to_nm_scaling is in m because of the legacy code which seems to almost always have - # pixel_to_nm_scaling be set in metres using the flag convert_nm_to_m. No idea why this is the case. - step_size_px = max(int(step_size_m / pixel_to_nm_scaling), 1) - - # Splines will be totalled and then divived by number of splines to calculate the average spline - spline_sum = None - - # Get the length of the fitted trace - fitted_trace_length = fitted_trace.shape[0] - - # If the fitted trace is less than the degree plus one, then there is no - # point in trying to spline it, just return the fitted trace - if fitted_trace_length < self.spline_degree + 1: - LOGGER.warning( - f"Fitted trace for grain {n_grain} too small ({fitted_trace_length}), returning fitted trace" - ) - self.splined_trace = fitted_trace - return - - # There cannot be fewer than degree + 1 points in the spline - # Decrease the step size to ensure more than this number of points - while fitted_trace_length / step_size_px < self.spline_degree + 1: - # Step size cannot be less than 1 - if step_size_px <= 1: - step_size_px = 1 - break - step_size_px = -1 - - # Set smoothness and periodicity appropriately for linear / circular molecules. - spline_smoothness, spline_periodicity = ( - (self.spline_circular_smoothing, 2) if mol_is_circular else (self.spline_linear_smoothing, 0) - ) - - # Create an array of evenly spaced points between 0 and 1 for the splines to be evaluated at. - # This is needed to ensure that the splines are all the same length as the number of points - # in the spline is controlled by the ev_array variable. - ev_array = np.linspace(0, 1, fitted_trace_length * step_size_px) - - # Find as many splines as there are steps in step size, this allows for a better spline to be obtained - # by averaging the splines. Think of this like weaving a lot of splines together along the course of - # the trace. Example spline coordinate indexes: [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], where spline - # 1 takes every 4th coordinate, starting at position 0, then spline 2 takes every 4th coordinate - # starting at position 1, etc... - for i in range(step_size_px): - # Sample the fitted trace at every step_size_px pixels - sampled = [fitted_trace[j, :] for j in range(i, fitted_trace_length, step_size_px)] - - # Scipy.splprep cannot handle duplicate consecutive x, y tuples, so remove them. - # Get rid of any consecutive duplicates in the sampled coordinates - sampled = self.remove_duplicate_consecutive_tuples(tuple_list=sampled) - - x_sampled = sampled[:, 0] - y_sampled = sampled[:, 1] - - # Use scipy's B-spline functions - # tck is a tuple, (t,c,k) containing the vector of knots, the B-spline coefficients - # and the degree of the spline. - # s is the smoothing factor, per is the periodicity, k is the degree of the spline - tck, _ = interp.splprep( - [x_sampled, y_sampled], - s=spline_smoothness, - per=spline_periodicity, - quiet=self.spline_quiet, - k=self.spline_degree, - ) - # splev returns a tuple (x_coords ,y_coords) containing the smoothed coordinates of the - # spline, constructed from the B-spline coefficients and knots. The number of points in - # the spline is controlled by the ev_array variable. - # ev_array is an array of evenly spaced points between 0 and 1. - # This is to ensure that the splines are all the same length. - # Tck simply provides the coefficients for the spline. - out = interp.splev(ev_array, tck) - splined_trace = np.column_stack((out[0], out[1])) - - # Add the splined trace to the spline_sum array for averaging later - if spline_sum is None: - spline_sum = np.array(splined_trace) - else: - spline_sum = np.add(spline_sum, splined_trace) - - # Find the average spline between the set of splines - # This is an attempt to find a better spline by averaging our candidates - spline_average = np.divide(spline_sum, [step_size_px, step_size_px]) - - self.splined_trace = spline_average - - def show_traces(self): - """Plot traces.""" - plt.pcolormesh(self.gauss_image, vmax=-3e-9, vmin=3e-9) - plt.colorbar() - plt.plot(self.ordered_trace[:, 0], self.ordered_trace[:, 1], markersize=1) - plt.plot(self.fitted_trace[:, 0], self.fitted_trace[:, 1], markersize=1) - plt.plot(self.splined_trace[:, 0], self.splined_trace[:, 1], markersize=1) - - plt.show() - plt.close() - - def saveTraceFigures( - self, - filename: str | Path, - channel_name: str, - vmaxval: float | int, - vminval: float | int, - output_dir: str | Path = None, - ) -> None: - """ - Save the traces. - - Parameters - ---------- - filename : str | Path - Filename being processed. - channel_name : str - Channel. - vmaxval : float | int - Maximum value for height. - vminval : float | int - Minimum value for height. - output_dir : str | Path - Output directory. - """ - if output_dir: - filename = self._checkForSaveDirectory(filename, output_dir) - - # save_file = filename[:-4] - - vmaxval = 20e-9 - vminval = -10e-9 - - plt.pcolormesh(self.image, vmax=vmaxval, vmin=vminval) - plt.colorbar() - # plt.savefig("%s_%s_originalImage.png" % (save_file, channel_name)) - plt.savefig(output_dir / filename / f"{channel_name}_original.png") - plt.close() - - plt.pcolormesh(self.image, vmax=vmaxval, vmin=vminval) - plt.colorbar() - # disordered_trace_list = self.ordered_trace[dna_num].tolist() - # less_dense_trace = np.array([disordered_trace_list[i] for i in range(0,len(disordered_trace_list),5)]) - plt.plot(self.splined_trace[:, 0], self.splined_trace[:, 1], color="c", linewidth=1.0) - if self.mol_is_circular: - starting_point = 0 - else: - starting_point = self.neighbours - length = len(self.curvature) - plt.plot( - self.splined_trace[starting_point, 0], - self.splined_trace[starting_point, 1], - color="#D55E00", - markersize=3.0, - marker=5, - ) - plt.plot( - self.splined_trace[starting_point + int(length / 6), 0], - self.splined_trace[starting_point + int(length / 6), 1], - color="#E69F00", - markersize=3.0, - marker=5, - ) - plt.plot( - self.splined_trace[starting_point + int(length / 6 * 2), 0], - self.splined_trace[starting_point + int(length / 6 * 2), 1], - color="#F0E442", - markersize=3.0, - marker=5, - ) - plt.plot( - self.splined_trace[starting_point + int(length / 6 * 3), 0], - self.splined_trace[starting_point + int(length / 6 * 3), 1], - color="#009E74", - markersize=3.0, - marker=5, - ) - plt.plot( - self.splined_trace[starting_point + int(length / 6 * 4), 0], - self.splined_trace[starting_point + int(length / 6 * 4), 1], - color="#0071B2", - markersize=3.0, - marker=5, - ) - plt.plot( - self.splined_trace[starting_point + int(length / 6 * 5), 0], - self.splined_trace[starting_point + int(length / 6 * 5), 1], - color="#CC79A7", - markersize=3.0, - marker=5, - ) - plt.savefig(f"{save_file}_{channel_name}_splinedtrace_with_markers.png") - plt.close() - - plt.pcolormesh(self.image, vmax=vmaxval, vmin=vminval) - plt.colorbar() - plt.plot(self.splined_trace[:, 0], self.splined_trace[:, 1], color="c", linewidth=1.0) - # plt.savefig("%s_%s_splinedtrace.png" % (save_file, channel_name)) - plt.savefig(output_dir / filename / f"{channel_name}_splinedtrace.png") - LOGGER.info(f"Splined Trace image saved to : {str(output_dir / filename / f'{channel_name}_splinedtrace.png')}") - plt.close() - - # plt.pcolormesh(self.image, vmax=vmaxval, vmin=vminval) - # plt.colorbar() - # LOOP REMOVED - # for dna_num in sorted(self.disordered_trace.keys()): - # disordered_trace_list = self.disordered_trace[dna_num].tolist() - # less_dense_trace = np.array([disordered_trace_list[i] for i in range(0,len(disordered_trace_list),5)]) - plt.plot( - self.disordered_trace[:, 0], - self.disordered_trace[:, 1], - "o", - markersize=1.0, - color="c", - ) - # plt.savefig("%s_%s_disorderedtrace.png" % (save_file, channel_name)) - # plt.savefig(output_dir / filename / f"{channel_name}_disordered_trace.png") - plt.savefig(output_dir / f"{filename}.png") - plt.close() - LOGGER.info( - f"Disordered trace image saved to : {str(output_dir / filename / f'{channel_name}_disordered_trace.png')}" - ) - - # plt.pcolormesh(self.image, vmax=vmaxval, vmin=vminval) - # plt.colorbar() - # for dna_num in sorted(self.grain.keys()): - # grain_plt = np.argwhere(self.grain[dna_num] == 1) - # plt.plot(grain_plt[:, 0], grain_plt[:, 1], "o", markersize=2, color="c") - # plt.savefig("%s_%s_grains.png" % (save_file, channel_name)) - # plt.savefig(output_dir / filename / f"{channel_name}_grains.png") - # plt.savefig(output_dir / f"{filename}_grain.png") - # plt.close() - LOGGER.info(f"Grains image saved to : {str(output_dir / filename / f'{channel_name}_grains.png')}") - - # FIXME : Replace with Path() (.mkdir(parent=True, exists=True) negate need to handle errors.) - def _checkForSaveDirectory(self, filename: str, new_output_dir: str) -> str: - """ - Create output directory and updates filename to account for this. - - Parameters - ---------- - filename : str - Filename. - new_output_dir : str - Target directory. - - Returns - ------- - str - Updated output directory. - """ - - split_directory_path = os.path.split(filename) - - try: - os.mkdir(os.path.join(split_directory_path[0], new_output_dir)) - except OSError: # OSError happens if the directory already exists - pass - - updated_filename = os.path.join(split_directory_path[0], new_output_dir, split_directory_path[1]) - - return updated_filename - - def find_curvature(self): - """Calculate curvature of the molecule.""" - curve = [] - contour = 0 - coordinates = np.zeros([2, self.neighbours * 2 + 1]) - for i, (x, y) in enumerate(self.splined_trace): - # Extracts the coordinates for the required number of points and puts them in an array - if self.mol_is_circular or (self.neighbours < i < len(self.splined_trace) - self.neighbours): - for j in range(self.neighbours * 2 + 1): - coordinates[0][j] = self.splined_trace[i - j][0] - coordinates[1][j] = self.splined_trace[i - j][1] - - # Calculates the angles for the tangent lines to the left and the right of the point - theta1 = math.atan( - (coordinates[1][self.neighbours] - coordinates[1][0]) - / (coordinates[0][self.neighbours] - coordinates[0][0]) - ) - theta2 = math.atan( - (coordinates[1][-1] - coordinates[1][self.neighbours]) - / (coordinates[0][-1] - coordinates[0][self.neighbours]) - ) - - left = coordinates[:, : self.neighbours + 1] - right = coordinates[:, -(self.neighbours + 1) :] - - xa = np.mean(left[0]) - ya = np.mean(left[1]) - - xb = np.mean(right[0]) - yb = np.mean(right[1]) - - # Calculates the curvature using the change in angle divided by the distance - dist = math.hypot((xb - xa), (yb - ya)) - dist_real = dist * self.pixel_to_nm_scaling - curve.append([i, contour, (theta2 - theta1) / dist_real]) - - contour = contour + math.hypot( - (coordinates[0][self.neighbours] - coordinates[0][self.neighbours - 1]), - (coordinates[1][self.neighbours] - coordinates[1][self.neighbours - 1]), - ) - self.curvature = curve - - def saveCurvature(self) -> None: - """Save curvature statistics.""" - # FIXME : Iterate directly over self.splined_trace.values() or self.splined_trace.items() - # roc_array = np.zeros(shape=(1, 3)) - for i, [n, contour, c] in enumerate(self.curvature): - try: - roc_array = np.append(roc_array, np.array([[i, contour, c]]), axis=0) - # oc_array.append([dna_num, i, contour, c]) - except NameError: - roc_array = np.array([[i, contour, c]]) - # roc_array = np.vstack((roc_array, np.array([dna_num, i, c]))) - # roc_array = np.delete(roc_array, 0, 0) - roc_stats = pd.DataFrame(roc_array) - - if not os.path.exists(os.path.join(os.path.dirname(self.filename), "Curvature")): - os.mkdir(os.path.join(os.path.dirname(self.filename), "Curvature")) - directory = os.path.join(os.path.dirname(self.filename), "Curvature") - savename = os.path.join(directory, os.path.basename(self.filename)[:-4]) - roc_stats.to_json(savename + ".json") - roc_stats.to_csv(savename + ".csv") - - def plotCurvature(self, dna_num: int) -> None: - """ - Plot the curvature of the chosen molecule as a function of the contour length (in metres). - - Parameters - ---------- - dna_num : int - Molecule to plot, used for indexing. - """ - - curvature = np.array(self.curvature[dna_num]) - length = len(curvature) - # FIXME : Replace with Path() - if not os.path.exists(os.path.join(os.path.dirname(self.filename), "Curvature")): - os.mkdir(os.path.join(os.path.dirname(self.filename), "Curvature")) - directory = os.path.join(os.path.dirname(self.filename), "Curvature") - savename = os.path.join(directory, os.path.basename(self.filename)[:-4]) - - plt.figure() - sns.lineplot(curvature[:, 1] * self.pixel_to_nm_scaling, curvature[:, 2], color="k") - plt.ylim(-1e9, 1e9) - plt.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) - plt.axvline(curvature[0][1], color="#D55E00") - plt.axvline(curvature[int(length / 6)][1] * self.pixel_to_nm_scaling, color="#E69F00") - plt.axvline(curvature[int(length / 6 * 2)][1] * self.pixel_to_nm_scaling, color="#F0E442") - plt.axvline(curvature[int(length / 6 * 3)][1] * self.pixel_to_nm_scaling, color="#009E74") - plt.axvline(curvature[int(length / 6 * 4)][1] * self.pixel_to_nm_scaling, color="#0071B2") - plt.axvline(curvature[int(length / 6 * 5)][1] * self.pixel_to_nm_scaling, color="#CC79A7") - plt.savefig(f"{savename}_{dna_num}_curvature.png") - plt.close() - - def measure_contour_length(self) -> None: - """ - Contour lengthof the splined trace taking into account whether the molecule is circular or linear. - - Contour length units are nm. - """ - if self.mol_is_circular: - for num, i in enumerate(self.splined_trace): - x1 = self.splined_trace[num - 1, 0] - y1 = self.splined_trace[num - 1, 1] - x2 = self.splined_trace[num, 0] - y2 = self.splined_trace[num, 1] - - try: - hypotenuse_array.append(math.hypot((x1 - x2), (y1 - y2))) - except NameError: - hypotenuse_array = [math.hypot((x1 - x2), (y1 - y2))] - - self.contour_length = np.sum(np.array(hypotenuse_array)) * self.pixel_to_nm_scaling - del hypotenuse_array - - else: - for num, i in enumerate(self.splined_trace): - try: - x1 = self.splined_trace[num, 0] - y1 = self.splined_trace[num, 1] - x2 = self.splined_trace[num + 1, 0] - y2 = self.splined_trace[num + 1, 1] - - try: - hypotenuse_array.append(math.hypot((x1 - x2), (y1 - y2))) - except NameError: - hypotenuse_array = [math.hypot((x1 - x2), (y1 - y2))] - except IndexError: # IndexError happens at last point in array - self.contour_length = np.sum(np.array(hypotenuse_array)) * self.pixel_to_nm_scaling - del hypotenuse_array - break - - def measure_end_to_end_distance(self): - """ - Calculate the Euclidean distance between the start and end of linear molecules. - - The hypotenuse is calculated between the start ([0,0], [0,1]) and end ([-1,0], [-1,1]) of linear - molecules. If the molecule is circular then the distance is set to zero (0). - """ - if self.mol_is_circular: - self.end_to_end_distance = 0 - else: - x1 = self.splined_trace[0, 0] - y1 = self.splined_trace[0, 1] - x2 = self.splined_trace[-1, 0] - y2 = self.splined_trace[-1, 1] - self.end_to_end_distance = math.hypot((x1 - x2), (y1 - y2)) * self.pixel_to_nm_scaling - - -def trace_image( - image: npt.NDArray, - grains_mask: npt.NDArray, - filename: str, - pixel_to_nm_scaling: float, - min_skeleton_size: int, - skeletonisation_method: str, - spline_step_size: float = 7e-9, - spline_linear_smoothing: float = 5.0, - spline_circular_smoothing: float = 0.0, - pad_width: int = 1, - cores: int = 1, -) -> dict: - """ - Processor function for tracing image. - - Parameters - ---------- - image : npt.NDArray - Full image as Numpy Array. - grains_mask : npt.NDArray - Full image as Grains that are labelled. - filename : str - File being processed. - pixel_to_nm_scaling : float - Pixel to nm scaling. - min_skeleton_size : int - Minimum size of grain in pixels after skeletonisation. - skeletonisation_method : str - Method of skeletonisation, options are 'zhang' (scikit-image) / 'lee' (scikit-image) / 'thin' (scikitimage) or - 'topostats' (original TopoStats method). - spline_step_size : float - Step size for spline evaluation in metres. - spline_linear_smoothing : float - Smoothness of linear splines. - spline_circular_smoothing : float - Smoothness of circular splines. - pad_width : int - Number of cells to pad arrays by, required to handle instances where grains touch the bounding box edges. - cores : int - Number of cores to process with. - - Returns - ------- - dict - Statistics from skeletonising and tracing the grains in the image. - """ - # Check both arrays are the same shape - if image.shape != grains_mask.shape: - raise ValueError(f"Image shape ({image.shape}) and Mask shape ({grains_mask.shape}) should match.") - - cropped_images, cropped_masks = prep_arrays(image, grains_mask, pad_width) - region_properties = skimage_measure.regionprops(grains_mask) - grain_anchors = [grain_anchor(image.shape, list(grain.bbox), pad_width) for grain in region_properties] - n_grains = len(cropped_images) - LOGGER.info(f"[{filename}] : Calculating statistics for {n_grains} grains.") - results = {} - ordered_traces = {} - splined_traces = {} - all_ordered_trace_heights = {} - all_ordered_trace_cumulative_distances = {} - for cropped_image_index, cropped_image in cropped_images.items(): - cropped_mask = cropped_masks[cropped_image_index] - - result = trace_grain( - cropped_image, - cropped_mask, - pixel_to_nm_scaling, - filename, - min_skeleton_size, - skeletonisation_method, - spline_step_size, - spline_linear_smoothing, - spline_circular_smoothing, - cropped_image_index, - ) - LOGGER.info(f"[{filename}] : Traced grain {cropped_image_index + 1} of {n_grains}") - ordered_traces[cropped_image_index] = result.pop("ordered_trace") - splined_traces[cropped_image_index] = result.pop("splined_trace") - all_ordered_trace_heights[cropped_image_index] = result.pop("ordered_trace_heights") - all_ordered_trace_cumulative_distances[cropped_image_index] = result.pop("ordered_trace_cumulative_distances") - results[cropped_image_index] = result - try: - results = pd.DataFrame.from_dict(results, orient="index") - results.index.name = "molecule_number" - image_trace = trace_mask(grain_anchors, ordered_traces, image.shape, pad_width) - rounded_splined_traces = round_splined_traces(splined_traces=splined_traces) - image_spline_trace = trace_mask(grain_anchors, rounded_splined_traces, image.shape, pad_width) - except ValueError as error: - LOGGER.error("No grains found in any images, consider adjusting your thresholds.") - LOGGER.error(error) - return { - "statistics": results, - "all_ordered_traces": ordered_traces, - "all_splined_traces": splined_traces, - "all_cropped_images": cropped_images, - "image_ordered_trace": image_trace, - "image_spline_trace": image_spline_trace, - "all_ordered_trace_heights": all_ordered_trace_heights, - "all_ordered_trace_cumulative_distances": all_ordered_trace_cumulative_distances, - } - - -def round_splined_traces(splined_traces: dict) -> dict: - """ - Round a Dict of floating point coordinates to integer floating point coordinates. - - Parameters - ---------- - splined_traces : dict - Floating point coordinates to be rounded. - - Returns - ------- - dict - Dictionary of rounded integer coordinates. - """ - rounded_splined_traces = {} - for grain_number, splined_trace in splined_traces.items(): - if splined_trace is not None: - rounded_splined_traces[grain_number] = np.round(splined_trace).astype(int) - else: - rounded_splined_traces[grain_number] = None - - return rounded_splined_traces - - -def trim_array(array: npt.NDArray, pad_width: int) -> npt.NDArray: - """ - Trim an array by the specified pad_width. - - Removes a border from an array. Typically this is the second padding that is added to the image/masks for edge cases - that are near image borders and means traces will be correctly aligned as a mask for the original image. - - Parameters - ---------- - array : npt.NDArray - Numpy array to be trimmed. - pad_width : int - Padding to be removed. - - Returns - ------- - npt.NDArray - Trimmed array. - """ - return array[pad_width:-pad_width, pad_width:-pad_width] - - -def adjust_coordinates(coordinates: npt.NDArray, pad_width: int) -> npt.NDArray: - """ - Adjust coordinates of a trace by the pad_width. - - A second padding is made to allow for grains that are "edge cases" and close to the bounding box edge. This adds the - pad_width to the cropped grain array. In order to realign the trace with the original image we need to remove this - padding so that when the coordinates are combined with the "grain_anchor", which isn't padded twice, the - coordinates correctly align with the original image. - - Parameters - ---------- - coordinates : npt.NDArray - An array of trace coordinates (typically ordered). - pad_width : int - The amount of padding used. - - Returns - ------- - npt.NDArray - Array of trace coordinates adjusted for secondary padding. - """ - return coordinates - pad_width - - -def trace_mask( - grain_anchors: list[npt.NDArray], ordered_traces: dict[str, npt.NDArray], image_shape: tuple, pad_width: int -) -> npt.NDArray: - """ - Place the traced skeletons into an array of the original image for plotting/overlaying. - - Adjusts the coordinates back to the original position based on each grains anchor coordinates of the padded - bounding box. Adjustments are made for the secondary padding that is made. - - Parameters - ---------- - grain_anchors : List[npt.NDArray] - List of grain anchors for the padded bounding box. - ordered_traces : Dict[npt.NDArray] - Coordinates for each grain trace. - Dict of coordinates for each grains trace. - image_shape : tuple - Shape of original image. - pad_width : int - The amount of padding used on the image. - - Returns - ------- - npt.NDArray - Mask of traces for all grains that can be overlaid on original image. - """ - image = np.zeros(image_shape) - for grain_number, (grain_anchor, ordered_trace) in enumerate(zip(grain_anchors, ordered_traces.values())): - # Don't always have an ordered_trace for a given grain_anchor if for example the trace was too small - if ordered_trace is not None: - ordered_trace = adjust_coordinates(ordered_trace, pad_width) - # If any of the values in ordered_trace added to their respective grain_anchor are greater than the image - # shape, then the trace is outside the image and should be skipped. - if ( - np.max(ordered_trace[:, 0]) + grain_anchor[0] > image_shape[0] - or np.max(ordered_trace[:, 1]) + grain_anchor[1] > image_shape[1] - ): - LOGGER.info(f"Grain {grain_number} has a trace that breaches the image bounds. Skipping.") - continue - ordered_trace[:, 0] = ordered_trace[:, 0] + grain_anchor[0] - ordered_trace[:, 1] = ordered_trace[:, 1] + grain_anchor[1] - image[ordered_trace[:, 0], ordered_trace[:, 1]] = 1 - - return image - - -def prep_arrays( - image: npt.NDArray, labelled_grains_mask: npt.NDArray, pad_width: int -) -> tuple[dict[int, npt.NDArray], dict[int, npt.NDArray]]: - """ - Take an image and labelled mask and crops individual grains and original heights to a list. - - A second padding is made after cropping to ensure for "edge cases" where grains are close to bounding box edges that - they are traced correctly. This is accounted for when aligning traces to the whole image mask. - - Parameters - ---------- - image : npt.NDArray - Gaussian filtered image. Typically filtered_image.images["gaussian_filtered"]. - labelled_grains_mask : npt.NDArray - 2D Numpy array of labelled grain masks, with each mask being comprised solely of unique integer (not - zero). Typically this will be output from 'grains.directions[["labelled_region_02]'. - pad_width : int - Cells by which to pad cropped regions by. - - Returns - ------- - Tuple - Returns a tuple of two dictionaries, each consisting of cropped arrays. - """ - # Get bounding boxes for each grain - region_properties = skimage_measure.regionprops(labelled_grains_mask) - # Subset image and grains then zip them up - cropped_images = {index: crop_array(image, grain.bbox, pad_width) for index, grain in enumerate(region_properties)} - cropped_images = {index: np.pad(grain, pad_width=pad_width) for index, grain in cropped_images.items()} - cropped_masks = { - index: crop_array(labelled_grains_mask, grain.bbox, pad_width) for index, grain in enumerate(region_properties) - } - cropped_masks = {index: np.pad(grain, pad_width=pad_width) for index, grain in cropped_masks.items()} - # Flip every labelled region to be 1 instead of its label - cropped_masks = {index: np.where(grain == 0, 0, 1) for index, grain in cropped_masks.items()} - return (cropped_images, cropped_masks) - - -def grain_anchor(array_shape: tuple, bounding_box: list, pad_width: int) -> list: - """ - Extract anchor (min_row, min_col) from labelled regions and align individual traces over the original image. - - Parameters - ---------- - array_shape : tuple - Shape of original array. - bounding_box : list - A list of region properties returned by 'skimage.measure.regionprops()'. - pad_width : int - Padding for image. - - Returns - ------- - list(Tuple) - A list of tuples of the min_row, min_col of each bounding box. - """ - bounding_coordinates = pad_bounding_box(array_shape, bounding_box, pad_width) - return (bounding_coordinates[0], bounding_coordinates[1]) - - -def trace_grain( - cropped_image: npt.NDArray, - cropped_mask: npt.NDArray, - pixel_to_nm_scaling: float, - filename: str = None, - min_skeleton_size: int = 10, - skeletonisation_method: str = "topostats", - spline_step_size: float = 7e-9, - spline_linear_smoothing: float = 5.0, - spline_circular_smoothing: float = 0.0, - n_grain: int = None, -) -> dict: - """ - Trace an individual grain. - - Tracing involves multiple steps... - - 1. Skeletonisation - 2. Pruning of side branch artefacts from skeletonisation. - 3. Ordering of the skeleton. - 4. Determination of molecule shape. - 5. Jiggling/Fitting - 6. Splining to improve resolution of image. - - Parameters - ---------- - cropped_image : npt.NDArray - Cropped array from the original image defined as the bounding box from the labelled mask. - cropped_mask : npt.NDArray - Cropped array from the labelled image defined as the bounding box from the labelled mask. This should have been - converted to a binary mask. - pixel_to_nm_scaling : float - Pixel to nm scaling. - filename : str - File being processed. - min_skeleton_size : int - Minimum size of grain in pixels after skeletonisation. - skeletonisation_method : str - Method of skeletonisation, options are 'zhang' (scikit-image) / 'lee' (scikit-image) / 'thin' (scikitimage) or - 'topostats' (original TopoStats method). - spline_step_size : float - Step size for spline evaluation in metres. - spline_linear_smoothing : float - Smoothness of linear splines. - spline_circular_smoothing : float - Smoothness of circular splines. - n_grain : int - Grain number being processed. - - Returns - ------- - dict - Dictionary of the contour length, whether the image is circular or linear, the end-to-end distance and an array - of coordinates. - """ - dnatrace = dnaTrace( - image=cropped_image, - grain=cropped_mask, - filename=filename, - pixel_to_nm_scaling=pixel_to_nm_scaling, - min_skeleton_size=min_skeleton_size, - skeletonisation_method=skeletonisation_method, - spline_step_size=spline_step_size, - spline_linear_smoothing=spline_linear_smoothing, - spline_circular_smoothing=spline_circular_smoothing, - n_grain=n_grain, - ) - dnatrace.trace_dna() - return { - "image": dnatrace.filename, - "contour_length": dnatrace.contour_length, - "circular": dnatrace.mol_is_circular, - "end_to_end_distance": dnatrace.end_to_end_distance, - "ordered_trace": dnatrace.ordered_trace, - "splined_trace": dnatrace.splined_trace, - "ordered_trace_heights": dnatrace.ordered_trace_heights, - "ordered_trace_cumulative_distances": dnatrace.ordered_trace_cumulative_distances, - } - - -def crop_array(array: npt.NDArray, bounding_box: tuple, pad_width: int = 0) -> npt.NDArray: - """ - Crop an array. - - Ideally we pad the array that is being cropped so that we have heights outside of the grains bounding box. However, - in some cases, if an grain is near the edge of the image scan this results in requesting indexes outside of the - existing image. In which case we get as much of the image padded as possible. - - Parameters - ---------- - array : npt.NDArray - 2D Numpy array to be cropped. - bounding_box : Tuple - Tuple of coordinates to crop, should be of form (min_row, min_col, max_row, max_col). - pad_width : int - Padding to apply to bounding box. - - Returns - ------- - npt.NDArray() - Cropped array. - """ - bounding_box = list(bounding_box) - bounding_box = pad_bounding_box(array.shape, bounding_box, pad_width) - return array[ - bounding_box[0] : bounding_box[2], - bounding_box[1] : bounding_box[3], - ] - - -def pad_bounding_box(array_shape: tuple, bounding_box: list, pad_width: int) -> list: - """ - Pad coordinates, if they extend beyond image boundaries stop at boundary. - - Parameters - ---------- - array_shape : tuple - Shape of original image. - bounding_box : list - List of coordinates 'min_row', 'min_col', 'max_row', 'max_col'. - pad_width : int - Cells to pad arrays by. - - Returns - ------- - list - List of padded coordinates. - """ - # Top Row : Make this the first column if too close - bounding_box[0] = 0 if bounding_box[0] - pad_width < 0 else bounding_box[0] - pad_width - # Left Column : Make this the first column if too close - bounding_box[1] = 0 if bounding_box[1] - pad_width < 0 else bounding_box[1] - pad_width - # Bottom Row : Make this the last row if too close - bounding_box[2] = array_shape[0] if bounding_box[2] + pad_width > array_shape[0] else bounding_box[2] + pad_width - # Right Column : Make this the last column if too close - bounding_box[3] = array_shape[1] if bounding_box[3] + pad_width > array_shape[1] else bounding_box[3] + pad_width - return bounding_box - - -# 2023-06-09 - Code that runs dnatracing in parallel across grains, left deliberately for use when we remodularise the -# entry-points/workflow. Will require that the gaussian filtered array is saved and passed in along with -# the labelled regions. @ns-rse -# -# -# if __name__ == "__main__": -# cropped_images, cropped_masks = prep_arrays(image, grains_mask, pad_width) -# n_grains = len(cropped_images) -# LOGGER.info(f"[{filename}] : Calculating statistics for {n_grains} grains.") -# # Process in parallel -# with Pool(processes=cores) as pool: -# results = {} -# with tqdm(total=n_grains) as pbar: -# x = 0 -# for result in pool.starmap( -# trace_grain, -# zip( -# cropped_images, -# cropped_masks, -# repeat(pixel_to_nm_scaling), -# repeat(filename), -# repeat(min_skeleton_size), -# repeat(skeletonisation_method), -# ), -# ): -# LOGGER.info(f"[{filename}] : Traced grain {x + 1} of {n_grains}") -# results[x] = result -# x += 1 -# pbar.update() -# try: -# results = pd.DataFrame.from_dict(results, orient="index") -# results.index.name = "molecule_number" -# except ValueError as error: -# LOGGER.error("No grains found in any images, consider adjusting your thresholds.") -# LOGGER.error(error) -# return results diff --git a/topostats/tracing/nodestats.py b/topostats/tracing/nodestats.py new file mode 100644 index 0000000000..2957938069 --- /dev/null +++ b/topostats/tracing/nodestats.py @@ -0,0 +1,1932 @@ +"""Perform Crossing Region Processing and Analysis.""" + +from __future__ import annotations + +import logging +from itertools import combinations +from typing import TypedDict + +import networkx as nx +import numpy as np +import numpy.typing as npt +import pandas as pd +from scipy.ndimage import binary_dilation +from scipy.signal import argrelextrema +from skimage.morphology import label + +from topostats.logs.logs import LOGGER_NAME +from topostats.measure.geometry import ( + calculate_shortest_branch_distances, + connect_best_matches, + find_branches_for_nodes, +) +from topostats.tracing.pruning import prune_skeleton +from topostats.tracing.skeletonize import getSkeleton +from topostats.tracing.tracingfuncs import order_branch, order_branch_from_start +from topostats.utils import ResolutionError, convolve_skeleton + +LOGGER = logging.getLogger(LOGGER_NAME) + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-branches +# pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-lines +# pylint: disable=too-many-locals +# pylint: disable=too-many-nested-blocks +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-statements + + +class NodeDict(TypedDict): + """Dictionary containing the node information.""" + + error: bool + pixel_to_nm_scaling: np.float64 + branch_stats: dict[int, MatchedBranch] | None + node_coords: npt.NDArray[np.int32] | None + confidence: np.float64 | None + + +class MatchedBranch(TypedDict): + """ + Dictionary containing the matched branches. + + matched_branches: dict[int, dict[str, npt.NDArray[np.number]]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "ordered_coords" : npt.NDArray[np.int32]. The ordered coordinates of the branch. + - "heights" : npt.NDArray[np.number]. Heights of the branch coordinates. + - "distances" : npt.NDArray[np.number]. Distances of the branch coordinates. + - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branch. + - "angles" : np.float64. The initial direction angle of the branch, added in later steps. + """ + + ordered_coords: npt.NDArray[np.int32] + heights: npt.NDArray[np.number] + distances: npt.NDArray[np.number] + fwhm: dict[str, np.float64 | tuple[np.float64]] + angles: np.float64 | None + + +class ImageDict(TypedDict): + """Dictionary containing the image information.""" + + nodes: dict[str, dict[str, npt.NDArray[np.int32]]] + grain: dict[str, npt.NDArray[np.int32] | dict[str, npt.NDArray[np.int32]]] + + +class nodeStats: + """ + Class containing methods to find and analyse the nodes/crossings within a grain. + + Parameters + ---------- + filename : str + The name of the file being processed. For logging purposes. + image : npt.npt.NDArray + The array of pixels. + mask : npt.npt.NDArray + The binary segmentation mask. + smoothed_mask : npt.NDArray + A smoothed version of the bianary segmentation mask. + skeleton : npt.NDArray + A binary single-pixel wide mask of objects in the 'image'. + pixel_to_nm_scaling : np.float32 + The pixel to nm scaling factor. + n_grain : int + The grain number. + node_joining_length : float + The length over which to join skeletal intersections to be counted as one crossing. + node_joining_length : float + The distance over which to join nearby odd-branched nodes. + node_extend_dist : float + The distance under which to join odd-branched node regions. + branch_pairing_length : float + The length from the crossing point to pair and trace, obtaining FWHM's. + pair_odd_branches : bool + Whether to try and pair odd-branched nodes. + """ + + def __init__( + self, + filename: str, + image: npt.NDArray, + mask: npt.NDArray, + smoothed_mask: npt.NDArray, + skeleton: npt.NDArray, + pixel_to_nm_scaling: np.float64, + n_grain: int, + node_joining_length: float, + node_extend_dist: float, + branch_pairing_length: float, + pair_odd_branches: bool, + ) -> None: + """ + Initialise the nodeStats class. + + Parameters + ---------- + filename : str + The name of the file being processed. For logging purposes. + image : npt.NDArray + The array of pixels. + mask : npt.NDArray + The binary segmentation mask. + smoothed_mask : npt.NDArray + A smoothed version of the bianary segmentation mask. + skeleton : npt.NDArray + A binary single-pixel wide mask of objects in the 'image'. + pixel_to_nm_scaling : float + The pixel to nm scaling factor. + n_grain : int + The grain number. + node_joining_length : float + The length over which to join skeletal intersections to be counted as one crossing. + node_joining_length : float + The distance over which to join nearby odd-branched nodes. + node_extend_dist : float + The distance under which to join odd-branched node regions. + branch_pairing_length : float + The length from the crossing point to pair and trace, obtaining FWHM's. + pair_odd_branches : bool + Whether to try and pair odd-branched nodes. + """ + self.filename = filename + self.image = image + self.mask = mask + self.smoothed_mask = smoothed_mask # only used to average traces + self.skeleton = skeleton + self.pixel_to_nm_scaling = pixel_to_nm_scaling + self.n_grain = n_grain + self.node_joining_length = node_joining_length + self.node_extend_dist = node_extend_dist / self.pixel_to_nm_scaling + self.branch_pairing_length = branch_pairing_length + self.pair_odd_branches = pair_odd_branches + + self.conv_skelly = np.zeros_like(self.skeleton) + self.connected_nodes = np.zeros_like(self.skeleton) + self.all_connected_nodes = np.zeros_like(self.skeleton) + self.whole_skel_graph: nx.classes.graph.Graph | None = None + self.node_centre_mask = np.zeros_like(self.skeleton) + + self.metrics = { + "num_crossings": 0, + "avg_crossing_confidence": None, + "min_crossing_confidence": None, + } + + self.node_dicts: dict[str, NodeDict] = {} + self.image_dict: ImageDict = { + "nodes": {}, + "grain": { + "grain_image": self.image, + "grain_mask": self.mask, + "grain_skeleton": self.skeleton, + }, + } + + self.full_dict = {} + self.mol_coords = {} + self.visuals = {} + self.all_visuals_img = None + + def get_node_stats(self) -> tuple: + """ + Run the workflow to obtain the node statistics. + + Returns + ------- + dict + Key structure: + |-> + |-> 'error' + └-> 'node_coords' + └-> 'branch_stats' + └-> + |-> 'ordered_coords' + |-> 'heights' + |-> 'gaussian_fit' + |-> 'fwhm' + └-> 'angles' + dict + Key structure: 'nodes' + + |-> 'node_area_skeleton' + |-> 'node_branch_mask' + └-> 'node_avg_mask + 'grain' + |-> 'grain_image' + |-> 'grain_mask' + └-> 'grain_skeleton' + """ + LOGGER.debug(f"Node Stats - Processing Grain: {self.n_grain}") + self.conv_skelly = convolve_skeleton(self.skeleton) + if len(self.conv_skelly[self.conv_skelly == 3]) != 0: # check if any nodes + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.") + # convolve to see crossing and end points + # self.conv_skelly = self.tidy_branches(self.conv_skelly, self.image) + # reset skeleton var as tidy branches may have modified it + self.skeleton = np.where(self.conv_skelly != 0, 1, 0) + self.image_dict["grain"]["grain_skeleton"] = self.skeleton + # get graph of skeleton + self.whole_skel_graph = self.skeleton_image_to_graph(self.skeleton) + # connect the close nodes + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.") + self.connected_nodes = self.connect_close_nodes(self.conv_skelly, node_width=self.node_joining_length) + # connect the odd-branch nodes + self.connected_nodes = self.connect_extended_nodes_nearest( + self.connected_nodes, node_extend_dist=self.node_extend_dist + ) + # obtain a mask of node centers and their count + self.node_centre_mask = self.highlight_node_centres(self.connected_nodes) + # Begin the hefty crossing analysis + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.") + self.analyse_nodes(max_branch_length=self.branch_pairing_length) + self.compile_metrics() + else: + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.") + return self.node_dicts, self.image_dict + # self.all_visuals_img = dnaTrace.concat_images_in_dict(self.image.shape, self.visuals) + + @staticmethod + def skeleton_image_to_graph(skeleton: npt.NDArray) -> nx.classes.graph.Graph: + """ + Convert a skeletonised mask into a Graph representation. + + Graphs conserve the coordinates via the node label. + + Parameters + ---------- + skeleton : npt.NDArray + A binary single-pixel wide mask, or result from conv_skelly(). + + Returns + ------- + nx.classes.graph.Graph + A networkX graph connecting the pixels in the skeleton to their neighbours. + """ + skeImPos = np.argwhere(skeleton).T + g = nx.Graph() + neigh = np.array([[0, 1], [0, -1], [1, 0], [-1, 0], [1, 1], [1, -1], [-1, 1], [-1, -1]]) + + for idx in range(skeImPos[0].shape[0]): + for neighIdx in range(neigh.shape[0]): + curNeighPos = skeImPos[:, idx] + neigh[neighIdx] + if np.any(curNeighPos < 0) or np.any(curNeighPos >= skeleton.shape): + continue + if skeleton[curNeighPos[0], curNeighPos[1]] > 0: + idx_coord = skeImPos[0, idx], skeImPos[1, idx] + curNeigh_coord = curNeighPos[0], curNeighPos[1] + # assign lower weight to nodes if not a binary image + if skeleton[idx_coord] == 3 and skeleton[curNeigh_coord] == 3: + weight = 0 + else: + weight = 1 + g.add_edge(idx_coord, curNeigh_coord, weight=weight) + g.graph["physicalPos"] = skeImPos.T + return g + + @staticmethod + def graph_to_skeleton_image(g: nx.Graph, im_shape: tuple[int]) -> npt.NDArray: + """ + Convert the skeleton graph back to a binary image. + + Parameters + ---------- + g : nx.Graph + Graph with coordinates as node labels. + im_shape : tuple[int] + The shape of the image to dump. + + Returns + ------- + npt.NDArray + Skeleton binary image from the graph representation. + """ + im = np.zeros(im_shape) + for node in g: + im[node] = 1 + + return im + + def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> npt.NDArray: + """ + Wrangle distant connected nodes back towards the main cluster. + + Works by filling and reskeletonising soely the node areas. + + Parameters + ---------- + connect_node_mask : npt.NDArray + The connected node mask - a skeleton where node regions = 3, endpoints = 2, and skeleton = 1. + image : npt.NDArray + The intensity image. + + Returns + ------- + npt.NDArray + The wrangled connected_node_mask. + """ + new_skeleton = np.where(connect_node_mask != 0, 1, 0) + labeled_nodes = label(np.where(connect_node_mask == 3, 1, 0)) + for node_num in range(1, labeled_nodes.max() + 1): + solo_node = np.where(labeled_nodes == node_num, 1, 0) + coords = np.argwhere(solo_node == 1) + node_centre = coords.mean(axis=0).astype(np.int32) + node_wid = coords[:, 0].max() - coords[:, 0].min() + 2 # +2 so always 2 by default + node_len = coords[:, 1].max() - coords[:, 1].min() + 2 # +2 so always 2 by default + overflow = int(10 / self.pixel_to_nm_scaling) if int(10 / self.pixel_to_nm_scaling) != 0 else 1 + # grain mask fill + new_skeleton[ + node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow, + node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow, + ] = self.mask[ + node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + overflow, + node_centre[1] - node_len // 2 - overflow : node_centre[1] + node_len // 2 + overflow, + ] + # remove any artifacts of the grain caught in the overflow areas + new_skeleton = self.keep_biggest_object(new_skeleton) + # Re-skeletonise + new_skeleton = getSkeleton(image, new_skeleton, method="topostats", height_bias=0.6).get_skeleton() + # new_skeleton = pruneSkeleton(image, new_skeleton).prune_skeleton( + # {"method": "topostats", "max_length": -1} + # ) + new_skeleton = prune_skeleton( + image, new_skeleton, self.pixel_to_nm_scaling, **{"method": "topostats", "max_length": -1} + ) + # cleanup around nibs + new_skeleton = getSkeleton(image, new_skeleton, method="zhang").get_skeleton() + # might also need to remove segments that have squares connected + + return convolve_skeleton(new_skeleton) + + @staticmethod + def keep_biggest_object(mask: npt.NDArray) -> npt.NDArray: + """ + Retain the largest object in a binary mask. + + Parameters + ---------- + mask : npt.NDArray + Binary mask. + + Returns + ------- + npt.NDArray + A binary mask with only one object. + """ + labelled_mask = label(mask) + idxs, counts = np.unique(mask, return_counts=True) + try: + max_idx = idxs[np.argmax(counts[1:]) + 1] + return np.where(labelled_mask == max_idx, 1, 0) + except ValueError as e: + LOGGER.debug(f"{e}: mask is empty.") + return mask + + def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray: + """ + Connect nodes within the 'node_width' boundary distance. + + This labels them as part of the same node. + + Parameters + ---------- + conv_skelly : npt.NDArray + A labeled skeleton image with skeleton = 1, endpoints = 2, crossing points =3. + node_width : float + The width of the dna in the grain, used to connect close nodes. + + Returns + ------- + np.ndarray + The skeleton (label=1) with close nodes connected (label=3). + """ + self.connected_nodes = conv_skelly.copy() + nodeless = conv_skelly.copy() + nodeless[(nodeless == 3) | (nodeless == 2)] = 0 # remove node & termini points + nodeless_labels = label(nodeless) + for i in range(1, nodeless_labels.max() + 1): + if nodeless[nodeless_labels == i].size < (node_width / self.pixel_to_nm_scaling): + # maybe also need to select based on height? and also ensure small branches classified + self.connected_nodes[nodeless_labels == i] = 3 + + return self.connected_nodes + + def highlight_node_centres(self, mask: npt.NDArray) -> npt.NDArray: + """ + Calculate the node centres based on height and re-plot on the mask. + + Parameters + ---------- + mask : npt.NDArray + 2D array with background = 0, skeleton = 1, endpoints = 2, node_centres = 3. + + Returns + ------- + npt.NDArray + 2D array with the highest node coordinate for each node labeled as 3. + """ + small_node_mask = mask.copy() + small_node_mask[mask == 3] = 1 # remap nodes to skeleton + big_nodes = mask.copy() + big_nodes = np.where(mask == 3, 1, 0) # remove non-nodes & set nodes to 1 + big_node_mask = label(big_nodes) + + for i in np.delete(np.unique(big_node_mask), 0): # get node indices + centre = np.unravel_index((self.image * (big_node_mask == i).astype(int)).argmax(), self.image.shape) + small_node_mask[centre] = 3 + + return small_node_mask + + def connect_extended_nodes_nearest( + self, connected_nodes: npt.NDArray, node_extend_dist: float = -1 + ) -> npt.NDArray[np.int32]: + """ + Extend the odd branched nodes to other odd branched nodes within the 'extend_dist' threshold. + + Parameters + ---------- + connected_nodes : npt.NDArray + A 2D array representing the network with background = 0, skeleton = 1, endpoints = 2, + node_centres = 3. + node_extend_dist : int | float, optional + The distance over which to connect odd-branched nodes, by default -1 for no-limit. + + Returns + ------- + npt.NDArray[np.int32] + Connected nodes array with odd-branched nodes connected. + """ + just_nodes = np.where(connected_nodes == 3, 1, 0) # remove branches & termini points + labelled_nodes = label(just_nodes) + + just_branches = np.where(connected_nodes == 1, 1, 0) # remove node & termini points + just_branches[connected_nodes == 1] = labelled_nodes.max() + 1 + labelled_branches = label(just_branches) + + nodes_with_branch_starting_coords = find_branches_for_nodes( + network_array_representation=connected_nodes, + labelled_nodes=labelled_nodes, + labelled_branches=labelled_branches, + ) + + # If there is only one node, then there is no need to connect the nodes since there is nothing to + # connect it to. Return the original connected_nodes instead. + if len(nodes_with_branch_starting_coords) <= 1: + self.connected_nodes = connected_nodes + return self.connected_nodes + + assert self.whole_skel_graph is not None, "Whole skeleton graph is not defined." # for type safety + shortest_node_dists, shortest_dists_branch_idxs, _shortest_dist_coords = calculate_shortest_branch_distances( + nodes_with_branch_starting_coords=nodes_with_branch_starting_coords, + whole_skeleton_graph=self.whole_skel_graph, + ) + + # Matches is an Nx2 numpy array of indexes of the best matching nodes. + # Eg: np.array([[1, 0], [2, 3]]) means that the best matching nodes are + # node 1 and node 0, and node 2 and node 3. + matches: npt.NDArray[np.int32] = self.best_matches(shortest_node_dists, max_weight_matching=False) + + # Connect the nodes by their best matches, using the shortest distances between their branch starts. + connected_nodes = connect_best_matches( + network_array_representation=connected_nodes, + whole_skeleton_graph=self.whole_skel_graph, + match_indexes=matches, + shortest_distances_between_nodes=shortest_node_dists, + shortest_distances_branch_indexes=shortest_dists_branch_idxs, + emanating_branch_starts_by_node=nodes_with_branch_starting_coords, + extend_distance=node_extend_dist, + ) + + self.connected_nodes = connected_nodes + return self.connected_nodes + + @staticmethod + def find_branch_starts(reduced_node_image: npt.NDArray) -> npt.NDArray: + """ + Find the coordinates where the branches connect to the node region through binary dilation of the node. + + Parameters + ---------- + reduced_node_image : npt.NDArray + A 2D numpy array containing a single node region (=3) and its connected branches (=1). + + Returns + ------- + npt.NDArray + Coordinate array of pixels next to crossing points (=3 in input). + """ + node = np.where(reduced_node_image == 3, 1, 0) + nodeless = np.where(reduced_node_image == 1, 1, 0) + thick_node = binary_dilation(node, structure=np.ones((3, 3))) + + return np.argwhere(thick_node * nodeless == 1) + + # pylint: disable=too-many-locals + def analyse_nodes(self, max_branch_length: float = 20) -> None: + """ + Obtain the main analyses for the nodes of a single molecule along the 'max_branch_length'(nm) from the node. + + Parameters + ---------- + max_branch_length : float + The side length of the box around the node to analyse (in nm). + """ + # Get coordinates of nodes + # This is a numpy array of coords, shape Nx2 + assert self.node_centre_mask is not None, "Node centre mask is not defined." + node_coords: npt.NDArray[np.int32] = np.argwhere(self.node_centre_mask.copy() == 3) + + # Check whether average trace resides inside the grain mask + # Checks if we dilate the skeleton once or twice, then all the pixels should fit in the grain mask + dilate = binary_dilation(self.skeleton, iterations=2) + # This flag determines whether to use average of 3 traces in calculation of FWHM + average_trace_advised = dilate[self.smoothed_mask == 1].sum() == dilate.sum() + LOGGER.debug(f"[{self.filename}] : Branch height traces will be averaged: {average_trace_advised}") + + # Iterate over the nodes and analyse the branches + matched_branches = None + branch_image = None + avg_image = np.zeros_like(self.image) + real_node_count = 0 + for node_no, (node_x, node_y) in enumerate(node_coords): + unmatched_branches = {} + error = False + + # Get branches relevant to the node + max_length_px = max_branch_length / (self.pixel_to_nm_scaling * 1) + reduced_node_area: npt.NDArray[np.int32] = nodeStats.only_centre_branches( + self.connected_nodes, np.array([node_x, node_y]) + ) + # Reduced skel graph is a networkx graph of the reduced node area. + reduced_skel_graph: nx.classes.graph.Graph = nodeStats.skeleton_image_to_graph(reduced_node_area) + + # Binarise the reduced node area + branch_mask = reduced_node_area.copy() + branch_mask[branch_mask == 3] = 0 + branch_mask[branch_mask == 2] = 1 + node_coords = np.argwhere(reduced_node_area == 3) + + # Find the starting coordinates of any branches connected to the node + branch_start_coords = self.find_branch_starts(reduced_node_area) + + # Stop processing if nib (node has 2 branches) + if branch_start_coords.shape[0] <= 2: + LOGGER.debug( + f"node {node_no} has only two branches - skipped & nodes removed.{len(node_coords)}" + "pixels in nib node." + ) + else: + try: + real_node_count += 1 + LOGGER.debug(f"Node: {real_node_count}") + + # Analyse the node branches + ( + pairs, + matched_branches, + ordered_branches, + masked_image, + branch_under_over_order, + confidence, + singlet_branch_vectors, + ) = nodeStats.analyse_node_branches( + p_to_nm=self.pixel_to_nm_scaling, + reduced_node_area=reduced_node_area, + branch_start_coords=branch_start_coords, + max_length_px=max_length_px, + reduced_skeleton_graph=reduced_skel_graph, + image=self.image, + average_trace_advised=average_trace_advised, + node_coord=(node_x, node_y), + pair_odd_branches=self.pair_odd_branches, + filename=self.filename, + resolution_threshold=np.float64(1000 / 512), + ) + + # Add the analysed branches to the labelled image + branch_image, avg_image = nodeStats.add_branches_to_labelled_image( + branch_under_over_order=branch_under_over_order, + matched_branches=matched_branches, + masked_image=masked_image, + branch_start_coords=branch_start_coords, + ordered_branches=ordered_branches, + pairs=pairs, + average_trace_advised=average_trace_advised, + image_shape=(self.image.shape[0], self.image.shape[1]), + ) + + # Calculate crossing angles of unpaired branches and add to stats dict + nodestats_calc_singlet_angles_result = nodeStats.calc_angles(np.asarray(singlet_branch_vectors)) + angles_between_singlet_branch_vectors: npt.NDArray[np.float64] = ( + nodestats_calc_singlet_angles_result[0] + ) + + for branch_index, angle in enumerate(angles_between_singlet_branch_vectors): + unmatched_branches[branch_index] = {"angles": angle} + + # Get the vector of each branch based on ordered_coords. Ordered_coords is only the first N nm + # of the branch so this is just a general vibe on what direction a branch is going. + if len(branch_start_coords) % 2 == 0 or self.pair_odd_branches: + vectors: list[npt.NDArray[np.float64]] = [] + for _, values in matched_branches.items(): + vectors.append(nodeStats.get_vector(values["ordered_coords"], np.array([node_x, node_y]))) + # Calculate angles between the vectors + nodestats_calc_angles_result = nodeStats.calc_angles(np.asarray(vectors)) + angles_between_vectors_along_branch: npt.NDArray[np.float64] = nodestats_calc_angles_result[0] + for branch_index, angle in enumerate(angles_between_vectors_along_branch): + if len(branch_start_coords) % 2 == 0 or self.pair_odd_branches: + matched_branches[branch_index]["angles"] = angle + else: + self.image_dict["grain"]["grain_skeleton"][node_coords[:, 0], node_coords[:, 1]] = 0 + + # Eg: length 2 array: [array([ nan, 79.00]), array([79.00, 0.0])] + # angles_between_vectors_along_branch + + except ResolutionError: + LOGGER.debug(f"Node stats skipped as resolution too low: {self.pixel_to_nm_scaling}nm per pixel") + error = True + + self.node_dicts[f"node_{real_node_count}"] = { + "error": error, + "pixel_to_nm_scaling": self.pixel_to_nm_scaling, + "branch_stats": matched_branches, + "unmatched_branch_stats": unmatched_branches, + "node_coords": node_coords, + "confidence": confidence, + } + + assert reduced_node_area is not None, "Reduced node area is not defined." + assert branch_image is not None, "Branch image is not defined." + assert avg_image is not None, "Average image is not defined." + node_images_dict: dict[str, npt.NDArray[np.int32]] = { + "node_area_skeleton": reduced_node_area, + "node_branch_mask": branch_image, + "node_avg_mask": avg_image, + } + self.image_dict["nodes"][f"node_{real_node_count}"] = node_images_dict + + self.all_connected_nodes[self.connected_nodes != 0] = self.connected_nodes[self.connected_nodes != 0] + + # pylint: disable=too-many-arguments + @staticmethod + def add_branches_to_labelled_image( + branch_under_over_order: npt.NDArray[np.int32], + matched_branches: dict[int, MatchedBranch], + masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]], + branch_start_coords: npt.NDArray[np.int32], + ordered_branches: list[npt.NDArray[np.int32]], + pairs: npt.NDArray[np.int32], + average_trace_advised: bool, + image_shape: tuple[int, int], + ) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]: + """ + Add branches to a labelled image. + + Parameters + ---------- + branch_under_over_order : npt.NDArray[np.int32] + The order of the branches. + matched_branches : dict[int, dict[str, MatchedBranch]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "ordered_coords" : npt.NDArray[np.int32]. + - "heights" : npt.NDArray[np.number]. Heights of the branches. + - "distances" : + - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches. + masked_image : dict[int, dict[str, npt.NDArray[np.bool_]]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches. + branch_start_coords : npt.NDArray[np.int32] + An Nx2 numpy array of the coordinates of the branches connected to the node. + ordered_branches : list[npt.NDArray[np.int32]] + List of numpy arrays of ordered branch coordinates. + pairs : npt.NDArray[np.int32] + Nx2 numpy array of pairs of branches that are matched through a node. + average_trace_advised : bool + Flag to determine whether to use the average trace. + image_shape : tuple[int] + The shape of the image, to create a mask from. + + Returns + ------- + tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]] + The branch image and the average image. + """ + branch_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32) + avg_image: npt.NDArray[np.int32] = np.zeros(image_shape).astype(np.int32) + + for i, branch_index in enumerate(branch_under_over_order): + branch_coords = matched_branches[branch_index]["ordered_coords"] + + # Add the matched branch to the image, starting at index 1 + branch_image[branch_coords[:, 0], branch_coords[:, 1]] = i + 1 + if average_trace_advised: + # For type safety, check if avg_image is None and skip if so. + # This is because the type hinting does not allow for None in the array. + avg_image[masked_image[branch_index]["avg_mask"] != 0] = i + 1 + + # Determine branches that were not able to be paired + unpaired_branches = np.delete(np.arange(0, branch_start_coords.shape[0]), pairs.flatten()) + LOGGER.debug(f"Unpaired branches: {unpaired_branches}") + # Ensure that unpaired branches start at index I where I is the number of paired branches. + branch_label = branch_image.max() + # Add the unpaired branches back to the branch image + for i in unpaired_branches: + branch_label += 1 + branch_image[ordered_branches[i][:, 0], ordered_branches[i][:, 1]] = branch_label + + return branch_image, avg_image + + @staticmethod + def analyse_node_branches( + p_to_nm: np.float64, + reduced_node_area: npt.NDArray[np.int32], + branch_start_coords: npt.NDArray[np.int32], + max_length_px: np.float64, + reduced_skeleton_graph: nx.classes.graph.Graph, + image: npt.NDArray[np.number], + average_trace_advised: bool, + node_coord: tuple[np.int32, np.int32], + pair_odd_branches: bool, + filename: str, + resolution_threshold: np.float64, + ) -> tuple[ + npt.NDArray[np.int32], + dict[int, MatchedBranch], + list[npt.NDArray[np.int32]], + dict[int, dict[str, npt.NDArray[np.bool_]]], + npt.NDArray[np.int32], + np.float64 | None, + ]: + """ + Analyse the branches of a single node. + + Parameters + ---------- + p_to_nm : np.float64 + The pixel to nm scaling factor. + reduced_node_area : npt.NDArray[np.int32] + An NxM numpy array of the node in question and the branches connected to it. + Node is marked by 3, and branches by 1. + branch_start_coords : npt.NDArray[np.int32] + An Nx2 numpy array of the coordinates of the branches connected to the node. + max_length_px : np.int32 + The maximum length in pixels to traverse along while ordering. + reduced_skeleton_graph : nx.classes.graph.Graph + The graph representation of the reduced node area. + image : npt.NDArray[np.number] + The full image of the grain. + average_trace_advised : bool + Flag to determine whether to use the average trace. + node_coord : tuple[np.int32, np.int32] + The node coordinates. + pair_odd_branches : bool + Whether to try and pair odd-branched nodes. + filename : str + The filename of the image. + resolution_threshold : np.float64 + The resolution threshold below which to warn the user that the node is difficult to analyse. + + Returns + ------- + pairs: npt.NDArray[np.int32] + Nx2 numpy array of pairs of branches that are matched through a node. + matched_branches: dict[int, MatchedBranch]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "ordered_coords" : npt.NDArray[np.int32]. + - "heights" : npt.NDArray[np.number]. Heights of the branches. + - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch. + - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches. + - "angles" : np.float64. The angle of the branch, added in later steps. + ordered_branches: list[npt.NDArray[np.int32]] + List of numpy arrays of ordered branch coordinates. + masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches. + branch_under_over_order: npt.NDArray[np.int32] + The order of the branches based on the FWHM. + confidence: np.float64 | None + The confidence of the crossing. Optional. + """ + if not p_to_nm <= resolution_threshold: + LOGGER.debug(f"Resolution {p_to_nm} is below suggested {resolution_threshold}, node difficult to analyse.") + + # Pixel-wise order the branches coming from the node and calculate the starting vector for each branch + ordered_branches, singlet_branch_vectors = nodeStats.get_ordered_branches_and_vectors( + reduced_node_area, branch_start_coords, max_length_px + ) + + # Pair the singlet branch vectors based on their suitability using vector orientation. + if len(branch_start_coords) % 2 == 0 or pair_odd_branches: + pairs = nodeStats.pair_vectors(np.asarray(singlet_branch_vectors)) + else: + pairs = np.array([], dtype=np.int32) + + # Match the branches up + matched_branches, masked_image = nodeStats.join_matching_branches_through_node( + pairs, + ordered_branches, + reduced_skeleton_graph, + image, + average_trace_advised, + node_coord, + filename, + ) + + # Redo the FWHMs after the processing for more accurate determination of under/overs. + hms = [] + for _, values in matched_branches.items(): + hms.append(values["fwhm"]["half_maxs"][2]) + for _, values in matched_branches.items(): + values["fwhm"] = nodeStats.calculate_fwhm(values["heights"], values["distances"], hm=max(hms)) + + # Get the confidence of the crossing + crossing_fwhms = [] + for _, values in matched_branches.items(): + crossing_fwhms.append(values["fwhm"]["fwhm"]) + if len(crossing_fwhms) <= 1: + confidence = None + else: + crossing_fwhm_combinations = list(combinations(crossing_fwhms, 2)) + confidence = np.float64(nodeStats.cross_confidence(crossing_fwhm_combinations)) + + # Order the branch indexes based on the FWHM of the branches. + branch_under_over_order = np.array(list(matched_branches.keys()))[np.argsort(np.array(crossing_fwhms))] + + return ( + pairs, + matched_branches, + ordered_branches, + masked_image, + branch_under_over_order, + confidence, + singlet_branch_vectors, + ) + + @staticmethod + def join_matching_branches_through_node( + pairs: npt.NDArray[np.int32], + ordered_branches: list[npt.NDArray[np.int32]], + reduced_skeleton_graph: nx.classes.graph.Graph, + image: npt.NDArray[np.number], + average_trace_advised: bool, + node_coords: tuple[np.int32, np.int32], + filename: str, + ) -> tuple[dict[int, MatchedBranch], dict[int, dict[str, npt.NDArray[np.bool_]]]]: + """ + Join branches that are matched through a node. + + Parameters + ---------- + pairs : npt.NDArray[np.int32] + Nx2 numpy array of pairs of branches that are matched through a node. + ordered_branches : list[npt.NDArray[np.int32]] + List of numpy arrays of ordered branch coordinates. + reduced_skeleton_graph : nx.classes.graph.Graph + Graph representation of the skeleton. + image : npt.NDArray[np.number] + The full image of the grain. + average_trace_advised : bool + Flag to determine whether to use the average trace. + node_coords : tuple[np.int32, np.int32] + The node coordinates. + filename : str + The filename of the image. + + Returns + ------- + matched_branches: dict[int, dict[str, npt.NDArray[np.number]]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "ordered_coords" : npt.NDArray[np.int32]. + - "heights" : npt.NDArray[np.number]. Heights of the branches. + - "distances" : + - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches. + masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] + Dictionary where the key is the index of the pair and the value is a dictionary containing the following + keys: + - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches. + """ + matched_branches: dict[int, MatchedBranch] = {} + masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = ( + {} + ) # Masked image is a dictionary of pairs of branches + for i, (branch_1, branch_2) in enumerate(pairs): + matched_branches[i] = MatchedBranch( + ordered_coords=np.array([], dtype=np.int32), + heights=np.array([], dtype=np.float64), + distances=np.array([], dtype=np.float64), + fwhm={}, + angles=None, + ) + masked_image[i] = {} + # find close ends by rearranging branch coords + branch_1_coords, branch_2_coords = nodeStats.order_branches( + ordered_branches[branch_1], ordered_branches[branch_2] + ) + # Get graphical shortest path between branch ends on the skeleton + crossing = nx.shortest_path( + reduced_skeleton_graph, + source=tuple(branch_1_coords[-1]), + target=tuple(branch_2_coords[0]), + weight="weight", + ) + crossing = np.asarray(crossing[1:-1]) # remove start and end points & turn into array + # Branch coords and crossing + if crossing.shape == (0,): + branch_coords = np.vstack([branch_1_coords, branch_2_coords]) + else: + branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords]) + # make images of single branch joined and multiple branches joined + single_branch_img: npt.NDArray[np.bool_] = np.zeros_like(image).astype(bool) + single_branch_img[branch_coords[:, 0], branch_coords[:, 1]] = True + single_branch_coords = order_branch(single_branch_img.astype(bool), [0, 0]) + # calc image-wide coords + matched_branches[i]["ordered_coords"] = single_branch_coords + # get heights and trace distance of branch + try: + assert average_trace_advised + distances, heights, mask, _ = nodeStats.average_height_trace( + image, single_branch_img, single_branch_coords, [node_coords[0], node_coords[1]] + ) + masked_image[i]["avg_mask"] = mask + except ( + AssertionError, + IndexError, + ) as e: # Assertion - avg trace not advised, Index - wiggy branches + LOGGER.debug(f"[{filename}] : avg trace failed with {e}, single trace only.") + average_trace_advised = False + distances = nodeStats.coord_dist_rad(single_branch_coords, np.array([node_coords[0], node_coords[1]])) + # distances = self.coord_dist(single_branch_coords) + zero_dist = distances[ + np.argmin( + np.sqrt( + (single_branch_coords[:, 0] - node_coords[0]) ** 2 + + (single_branch_coords[:, 1] - node_coords[1]) ** 2 + ) + ) + ] + heights = image[single_branch_coords[:, 0], single_branch_coords[:, 1]] # self.hess + distances = distances - zero_dist + distances, heights = nodeStats.average_uniques( + distances, heights + ) # needs to be paired with coord_dist_rad + matched_branches[i]["heights"] = heights + matched_branches[i]["distances"] = distances + # identify over/under + matched_branches[i]["fwhm"] = nodeStats.calculate_fwhm(heights, distances) + + return matched_branches, masked_image + + @staticmethod + def get_ordered_branches_and_vectors( + reduced_node_area: npt.NDArray[np.int32], + branch_start_coords: npt.NDArray[np.int32], + max_length_px: np.float64, + ) -> tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]]: + """ + Get ordered branches and vectors for a node. + + Branches are ordered so they are no longer just a disordered set of coordinates, and vectors are calculated to + represent the general direction tendency of the branch, this allows for alignment matching later on. + + Parameters + ---------- + reduced_node_area : npt.NDArray[np.int32] + An NxM numpy array of the node in question and the branches connected to it. + Node is marked by 3, and branches by 1. + branch_start_coords : npt.NDArray[np.int32] + An Px2 numpy array of coordinates representing the start of branches where P is the number of branches. + max_length_px : np.int32 + The maximum length in pixels to traverse along while ordering. + + Returns + ------- + tuple[list[npt.NDArray[np.int32]], list[npt.NDArray[np.int32]]] + A tuple containing a list of ordered branches and a list of vectors. + """ + ordered_branches = [] + vectors = [] + nodeless = np.where(reduced_node_area == 1, 1, 0) + for branch_start_coord in branch_start_coords: + # Order the branch coordinates so they're no longer just a disordered set of coordinates + ordered_branch = order_branch_from_start(nodeless.copy(), branch_start_coord, max_length=max_length_px) + ordered_branches.append(ordered_branch) + + # Calculate vector to represent the general direction tendency of the branch (for alignment matching) + vector = nodeStats.get_vector(ordered_branch, branch_start_coord) + vectors.append(vector) + + return ordered_branches, vectors + + @staticmethod + def cross_confidence(pair_combinations: list) -> float: + """ + Obtain the average confidence of the combinations using a reciprical function. + + Parameters + ---------- + pair_combinations : list + List of length 2 combinations of FWHM values. + + Returns + ------- + float + The average crossing confidence. + """ + c = 0 + for pair in pair_combinations: + c += nodeStats.recip(pair) + return c / len(pair_combinations) + + @staticmethod + def recip(vals: list) -> float: + """ + Compute 1 - (max / min) of the two values provided. + + Parameters + ---------- + vals : list + List of 2 values. + + Returns + ------- + float + Result of applying the 1-(min / max) function to the two values. + """ + try: + if min(vals) == 0: # means fwhm variation hasn't worked + return 0 + return 1 - min(vals) / max(vals) + except ZeroDivisionError: + return 0 + + @staticmethod + def get_vector(coords: npt.NDArray, origin: npt.NDArray) -> npt.NDArray: + """ + Calculate the normalised vector of the coordinate means in a branch. + + Parameters + ---------- + coords : npt.NDArray + 2xN array of x, y coordinates. + origin : npt.NDArray + 2x1 array of an x, y coordinate. + + Returns + ------- + npt.NDArray + Normalised vector from origin to the mean coordinate. + """ + vector = coords.mean(axis=0) - origin + norm = np.sqrt(vector @ vector) + return vector if norm == 0 else vector / norm # normalise vector so length=1 + + @staticmethod + def calc_angles(vectors: npt.NDArray) -> npt.NDArray[np.float64]: + """ + Calculate the angles between vectors in an array. + + Uses the formula: cos(theta) = |a|•|b|/|a||b| + + Parameters + ---------- + vectors : npt.NDArray + Array of 2x1 vectors. + + Returns + ------- + npt.NDArray + An array of the cosine of the angles between the vectors. + """ + dot = vectors @ vectors.T + norm = np.diag(dot) ** 0.5 + cos_angles = dot / (norm.reshape(-1, 1) @ norm.reshape(1, -1)) + return abs(np.arccos(cos_angles) / np.pi * 180) # angles in degrees + + @staticmethod + def pair_vectors(vectors: npt.NDArray) -> npt.NDArray[np.int32]: + """ + Take a list of vectors and pairs them based on the angle between them. + + Parameters + ---------- + vectors : npt.NDArray + Array of 2x1 vectors to be paired. + + Returns + ------- + npt.NDArray + An array of the matching pair indices. + """ + # calculate cosine of angle + angles = nodeStats.calc_angles(vectors) + # find highest values + np.fill_diagonal(angles, 0) # ensures not paired with itself + # match angles + return nodeStats.best_matches(angles) + + @staticmethod + def best_matches(arr: npt.NDArray, max_weight_matching: bool = True) -> npt.NDArray: + """ + Turn a matrix into a graph and calculates the best matching index pairs. + + Parameters + ---------- + arr : npt.NDArray + Transpose symmetric MxM array where the value of index i, j represents a weight between i and j. + max_weight_matching : bool + Whether to obtain best matching pairs via maximum weight, or minimum weight matching. + + Returns + ------- + npt.NDArray + Array of pairs of indexes. + """ + if max_weight_matching: + G = nodeStats.create_weighted_graph(arr) + matching = np.array(list(nx.max_weight_matching(G, maxcardinality=True))) + else: + np.fill_diagonal(arr, arr.max() + 1) + G = nodeStats.create_weighted_graph(arr) + matching = np.array(list(nx.min_weight_matching(G))) + return matching + + @staticmethod + def create_weighted_graph(matrix: npt.NDArray) -> nx.Graph: + """ + Create a bipartite graph connecting i <-> j from a square matrix of weights matrix[i, j]. + + Parameters + ---------- + matrix : npt.NDArray + Square array of weights between rows and columns. + + Returns + ------- + nx.Graph + Bipatrite graph with edge weight i->j matching matrix[i,j]. + """ + n = len(matrix) + G = nx.Graph() + for i in range(n): + for j in range(i + 1, n): + G.add_edge(i, j, weight=matrix[i, j]) + return G + + @staticmethod + def pair_angles(angles: npt.NDArray) -> list: + """ + Pair angles that are 180 degrees to each other and removes them before selecting the next pair. + + Parameters + ---------- + angles : npt.NDArray + Square array (i,j) of angles between i and j. + + Returns + ------- + list + A list of paired indexes in a list. + """ + angles_cp = angles.copy() + pairs = [] + for _ in range(int(angles.shape[0] / 2)): + pair = np.unravel_index(np.argmax(angles_cp), angles.shape) + pairs.append(pair) # add to list + angles_cp[[pair]] = 0 # set rows 0 to avoid picking again + angles_cp[:, [pair]] = 0 # set cols 0 to avoid picking again + + return np.asarray(pairs) + + @staticmethod + def gaussian(x: npt.NDArray, h: float, mean: float, sigma: float): + """ + Apply the gaussian function. + + Parameters + ---------- + x : npt.NDArray + X values to be passed into the gaussian. + h : float + The peak height of the gaussian. + mean : float + The mean of the x values. + sigma : float + The standard deviation of the image. + + Returns + ------- + npt.NDArray + The y-values of the gaussian performed on the x values. + """ + return h * np.exp(-((x - mean) ** 2) / (2 * sigma**2)) + + @staticmethod + def interpolate_between_yvalue(x: npt.NDArray, y: npt.NDArray, yvalue: float) -> float: + """ + Calculate the x value between the two points either side of yvalue in y. + + Parameters + ---------- + x : npt.NDArray + An array of length y. + y : npt.NDArray + An array of length x. + yvalue : float + A value within the bounds of the y array. + + Returns + ------- + float + The linearly interpolated x value between the arrays. + """ + for i in range(len(y) - 1): + if y[i] <= yvalue <= y[i + 1] or y[i + 1] <= yvalue <= y[i]: # if points cross through the hm value + return nodeStats.lin_interp([x[i], y[i]], [x[i + 1], y[i + 1]], yvalue=yvalue) + return 0 + + @staticmethod + def calculate_fwhm( + heights: npt.NDArray, distances: npt.NDArray, hm: float | None = None + ) -> dict[str, np.float64 | list[np.float64 | float | None]]: + """ + Calculate the FWHM value. + + First identifyies the HM then finding the closest values in the distances array and using + linear interpolation to calculate the FWHM. + + Parameters + ---------- + heights : npt.NDArray + Array of heights. + distances : npt.NDArray + Array of distances. + hm : Union[None, float], optional + The halfmax value to match (if wanting the same HM between curves), by default None. + + Returns + ------- + tuple[float, list, list] + The FWHM value, [distance at hm for 1st half of trace, distance at hm for 2nd half of trace, + HM value], [index of the highest point, distance at highest point, height at highest point]. + """ + centre_fraction = int(len(heights) * 0.2) # in case zone approaches another node, look around centre for max + if centre_fraction == 0: + high_idx = np.argmax(heights) + else: + high_idx = np.argmax(heights[centre_fraction:-centre_fraction]) + centre_fraction + # get array halves to find first points that cross hm + arr1 = heights[:high_idx][::-1] + dist1 = distances[:high_idx][::-1] + arr2 = heights[high_idx:] + dist2 = distances[high_idx:] + if hm is None: + # Get half max + hm = (heights.max() - heights.min()) / 2 + heights.min() + # half max value -> try to make it the same as other crossing branch? + # increase make hm = lowest of peak if it doesn’t hit one side + if np.min(arr1) > hm: + arr1_local_min = argrelextrema(arr1, np.less)[-1] # closest to end + try: + hm = arr1[arr1_local_min][0] + except IndexError: # index error when no local minima + hm = np.min(arr1) + elif np.min(arr2) > hm: + arr2_local_min = argrelextrema(arr2, np.less)[0] # closest to start + try: + hm = arr2[arr2_local_min][0] + except IndexError: # index error when no local minima + hm = np.min(arr2) + arr1_hm = nodeStats.interpolate_between_yvalue(x=dist1, y=arr1, yvalue=hm) + arr2_hm = nodeStats.interpolate_between_yvalue(x=dist2, y=arr2, yvalue=hm) + fwhm = np.float64(abs(arr2_hm - arr1_hm)) + return { + "fwhm": fwhm, + "half_maxs": [arr1_hm, arr2_hm, hm], + "peaks": [high_idx, distances[high_idx], heights[high_idx]], + } + + @staticmethod + def lin_interp(point_1: list, point_2: list, xvalue: float | None = None, yvalue: float | None = None) -> float: + """ + Linear interp 2 points by finding line equation and subbing. + + Parameters + ---------- + point_1 : list + List of an x and y coordinate. + point_2 : list + List of an x and y coordinate. + xvalue : Union[float, None], optional + Value at which to interpolate to get a y coordinate, by default None. + yvalue : Union[float, None], optional + Value at which to interpolate to get an x coordinate, by default None. + + Returns + ------- + float + Value of x or y linear interpolation. + """ + m = (point_1[1] - point_2[1]) / (point_1[0] - point_2[0]) + c = point_1[1] - (m * point_1[0]) + if xvalue is not None: + return m * xvalue + c # interp_y + if yvalue is not None: + return (yvalue - c) / m # interp_x + raise ValueError + + @staticmethod + def order_branches(branch1: npt.NDArray, branch2: npt.NDArray) -> tuple: + """ + Order the two ordered arrays based on the closest endpoint coordinates. + + Parameters + ---------- + branch1 : npt.NDArray + An Nx2 array describing coordinates. + branch2 : npt.NDArray + An Nx2 array describing coordinates. + + Returns + ------- + tuple + An tuple with the each coordinate array ordered to follow on from one-another. + """ + endpoints1 = np.asarray([branch1[0], branch1[-1]]) + endpoints2 = np.asarray([branch2[0], branch2[-1]]) + sum1 = abs(endpoints1 - endpoints2).sum(axis=1) + sum2 = abs(endpoints1[::-1] - endpoints2).sum(axis=1) + if sum1.min() < sum2.min(): + if np.argmin(sum1) == 0: + return branch1[::-1], branch2 + return branch1, branch2[::-1] + if np.argmin(sum2) == 0: + return branch1, branch2 + return branch1[::-1], branch2[::-1] + + @staticmethod + def binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray: + """ + Create a binary path following the straight line between 2 points. + + Parameters + ---------- + start : npt.NDArray + A coordinate. + end : npt.NDArray + Another coordinate. + + Returns + ------- + npt.NDArray + An Nx2 coordinate array that the line passes through. + """ + arr = [] + m_swap = False + x_swap = False + slope = (end - start)[1] / (end - start)[0] + + if abs(slope) > 1: # swap x and y if slope will cause skips + start, end = start[::-1], end[::-1] + slope = 1 / slope + m_swap = True + + if start[0] > end[0]: # swap x coords if coords wrong way around + start, end = end, start + x_swap = True + + # code assumes slope < 1 hence swap + x_start, y_start = start + x_end, _ = end + for x in range(x_start, x_end + 1): + y_true = slope * (x - x_start) + y_start + y_pixel = np.round(y_true) + arr.append([x, y_pixel]) + + if m_swap: # if swapped due to slope, return + arr = np.asarray(arr)[:, [1, 0]].reshape(-1, 2).astype(int) + if x_swap: + return arr[::-1] + return arr + arr = np.asarray(arr).reshape(-1, 2).astype(int) + if x_swap: + return arr[::-1] + return arr + + @staticmethod + def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray: + """ + Calculate the distance from the centre coordinate to a point along the ordered coordinates. + + This differs to traversal along the coordinates taken. This also averages any common distance + values and makes those in the trace before the node index negative. + + Parameters + ---------- + coords : npt.NDArray + Nx2 array of branch coordinates. + centre : npt.NDArray + A 1x2 array of the centre coordinates to identify a 0 point for the node. + pixel_to_nm_scaling : float, optional + The pixel to nanometer scaling factor to provide real units, by default 1. + + Returns + ------- + npt.NDArray + A Nx1 array of the distance from the node centre. + """ + diff_coords = coords - centre + if np.all(coords == centre, axis=1).sum() == 0: # if centre not in coords, reassign centre + diff_dists = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2) + centre = coords[np.argmin(diff_dists)] + cross_idx = np.argwhere(np.all(coords == centre, axis=1)) + rad_dist = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2) + rad_dist[0 : cross_idx[0][0]] *= -1 + return rad_dist * pixel_to_nm_scaling + + @staticmethod + def above_below_value_idx(array: npt.NDArray, value: float) -> list: + """ + Identify indices of the array neighbouring the specified value. + + Parameters + ---------- + array : npt.NDArray + Array of values. + value : float + Value to identify indices between. + + Returns + ------- + list + List of the lower index and higher index around the value. + + Raises + ------ + IndexError + When the value is in the array. + """ + idx1 = abs(array - value).argmin() + try: + if array[idx1] < value < array[idx1 + 1]: + idx2 = idx1 + 1 + elif array[idx1 - 1] < value < array[idx1]: + idx2 = idx1 - 1 + else: + raise IndexError # this will be if the number is the same + indices = [idx1, idx2] + indices.sort() + return indices + except IndexError: + return None + + @staticmethod + def average_height_trace( # noqa: C901 + img: npt.NDArray, branch_mask: npt.NDArray, branch_coords: npt.NDArray, centre=(0, 0) + ) -> tuple: + """ + Average two side-by-side ordered skeleton distance and height traces. + + Dilate the original branch to create two additional side-by-side branches + in order to get a more accurate average of the height traces. This function produces + the common distances between these 3 branches, and their averaged heights. + + Parameters + ---------- + img : npt.NDArray + An array of numbers pertaining to an image. + branch_mask : npt.NDArray + A binary array of the branch, must share the same dimensions as the image. + branch_coords : npt.NDArray + Ordered coordinates of the branch mask. + centre : Union[float, None] + The coordinates to centre the branch around. + + Returns + ------- + tuple + A tuple of the averaged heights from the linetrace and their corresponding distances + from the crossing. + """ + # get heights and dists of the original (middle) branch + branch_dist = nodeStats.coord_dist_rad(branch_coords, centre) + # branch_dist = self.coord_dist(branch_coords) + branch_heights = img[branch_coords[:, 0], branch_coords[:, 1]] + branch_dist, branch_heights = nodeStats.average_uniques( + branch_dist, branch_heights + ) # needs to be paired with coord_dist_rad + dist_zero_point = branch_dist[ + np.argmin(np.sqrt((branch_coords[:, 0] - centre[0]) ** 2 + (branch_coords[:, 1] - centre[1]) ** 2)) + ] + branch_dist_norm = branch_dist - dist_zero_point # - 0 # branch_dist[branch_heights.argmax()] + + # want to get a 3 pixel line trace, one on each side of orig + dilate = binary_dilation(branch_mask, iterations=1) + dilate = nodeStats.fill_holes(dilate) + dilate_minus = np.where(dilate != branch_mask, 1, 0) + dilate2 = binary_dilation(dilate, iterations=1) + dilate2[(dilate == 1) | (branch_mask == 1)] = 0 + labels = label(dilate2) + # Cleanup stages - re-entering, early terminating, closer traces + # if parallel trace out and back in zone, can get > 2 labels + labels = nodeStats._remove_re_entering_branches(labels, remaining_branches=2) + # if parallel trace doesn't exit window, can get 1 label + # occurs when skeleton has poor connections (extra branches which cut corners) + if labels.max() == 1: + conv = convolve_skeleton(branch_mask) + endpoints = np.argwhere(conv == 2) + for endpoint in endpoints: # may be >1 endpoint + para_trace_coords = np.argwhere(labels == 1) + abs_diff = np.absolute(para_trace_coords - endpoint).sum(axis=1) + min_idxs = np.where(abs_diff == abs_diff.min()) + trace_coords_remove = para_trace_coords[min_idxs] + labels[trace_coords_remove[:, 0], trace_coords_remove[:, 1]] = 0 + labels = label(labels) + # reduce binary dilation distance + parallel = np.zeros_like(branch_mask).astype(np.int32) + for i in range(1, labels.max() + 1): + single = labels.copy() + single[single != i] = 0 + single[single == i] = 1 + sing_dil = binary_dilation(single) + parallel[(sing_dil == dilate_minus) & (sing_dil == 1)] = i + labels = parallel.copy() + + binary = labels.copy() + binary[binary != 0] = 1 + binary += branch_mask + + # get and order coords, then get heights and distances relitive to node centre / highest point + heights = [] + distances = [] + for i in np.unique(labels)[1:]: + trace_img = np.where(labels == i, 1, 0) + trace_img = getSkeleton(img, trace_img, method="zhang").get_skeleton() + trace = order_branch(trace_img, branch_coords[0]) + height_trace = img[trace[:, 0], trace[:, 1]] + dist = nodeStats.coord_dist_rad(trace, centre) # self.coord_dist(trace) + dist, height_trace = nodeStats.average_uniques(dist, height_trace) # needs to be paired with coord_dist_rad + heights.append(height_trace) + distances.append( + dist - dist_zero_point # - 0 + ) # branch_dist[branch_heights.argmax()]) #dist[central_heights.argmax()]) + # Make like coord system using original branch + avg1 = [] + avg2 = [] + for mid_dist in branch_dist_norm: + for i, (distance, height) in enumerate(zip(distances, heights)): + # check if distance already in traces array + if (mid_dist == distance).any(): + idx = np.where(mid_dist == distance) + if i == 0: + avg1.append([mid_dist, height[idx][0]]) + else: + avg2.append([mid_dist, height[idx][0]]) + # if not, linearly interpolate the mid-branch value + else: + # get index after and before the mid branches' x coord + xidxs = nodeStats.above_below_value_idx(distance, mid_dist) + if xidxs is None: + pass # if indexes outside of range, pass + else: + point1 = [distance[xidxs[0]], height[xidxs[0]]] + point2 = [distance[xidxs[1]], height[xidxs[1]]] + y = nodeStats.lin_interp(point1, point2, xvalue=mid_dist) + if i == 0: + avg1.append([mid_dist, y]) + else: + avg2.append([mid_dist, y]) + avg1 = np.asarray(avg1) + avg2 = np.asarray(avg2) + # ensure arrays are same length to average + temp_x = branch_dist_norm[np.isin(branch_dist_norm, avg1[:, 0])] + common_dists = avg2[:, 0][np.isin(avg2[:, 0], temp_x)] + + common_avg_branch_heights = branch_heights[np.isin(branch_dist_norm, common_dists)] + common_avg1_heights = avg1[:, 1][np.isin(avg1[:, 0], common_dists)] + common_avg2_heights = avg2[:, 1][np.isin(avg2[:, 0], common_dists)] + + average_heights = (common_avg_branch_heights + common_avg1_heights + common_avg2_heights) / 3 + return ( + common_dists, + average_heights, + binary, + [[heights[0], branch_heights, heights[1]], [distances[0], branch_dist_norm, distances[1]]], + ) + + @staticmethod + def fill_holes(mask: npt.NDArray) -> npt.NDArray: + """ + Fill all holes within a binary mask. + + Parameters + ---------- + mask : npt.NDArray + Binary array of object. + + Returns + ------- + npt.NDArray + Binary array of object with any interior holes filled in. + """ + inv_mask = np.where(mask != 0, 0, 1) + lbl_inv = label(inv_mask, connectivity=1) + idxs, counts = np.unique(lbl_inv, return_counts=True) + max_idx = idxs[np.argmax(counts)] + return np.where(lbl_inv != max_idx, 1, 0) + + @staticmethod + def _remove_re_entering_branches(mask: npt.NDArray, remaining_branches: int = 1) -> npt.NDArray: + """ + Remove smallest branches which branches exit and re-enter the viewing area. + + Contninues until only remain. + + Parameters + ---------- + mask : npt.NDArray + Skeletonised binary mask of an object. + remaining_branches : int, optional + Number of objects (branches) to keep, by default 1. + + Returns + ------- + npt.NDArray + Mask with only a single skeletonised branch. + """ + rtn_image = mask.copy() + binary_image = mask.copy() + binary_image[binary_image != 0] = 1 + labels = label(binary_image) + + if labels.max() > remaining_branches: + lens = [labels[labels == i].size for i in range(1, labels.max() + 1)] + while len(lens) > remaining_branches: + smallest_idx = min(enumerate(lens), key=lambda x: x[1])[0] + rtn_image[labels == smallest_idx + 1] = 0 + lens.remove(min(lens)) + + return rtn_image + + @staticmethod + def only_centre_branches(node_image: npt.NDArray, node_coordinate: npt.NDArray) -> npt.NDArray[np.int32]: + """ + Remove all branches not connected to the current node. + + Parameters + ---------- + node_image : npt.NDArray + An image of the skeletonised area surrounding the node where + the background = 0, skeleton = 1, termini = 2, nodes = 3. + node_coordinate : npt.NDArray + 2x1 coordinate describing the position of a node. + + Returns + ------- + npt.NDArray[np.int32] + The initial node image but only with skeletal branches + connected to the middle node. + """ + node_image_cp = node_image.copy() + + # get node-only image + nodes = node_image_cp.copy() + nodes[nodes != 3] = 0 + labeled_nodes = label(nodes) + + # find which cluster is closest to the centre + node_coords = np.argwhere(nodes == 3) + min_coords = node_coords[abs(node_coords - node_coordinate).sum(axis=1).argmin()] + centre_idx = labeled_nodes[min_coords[0], min_coords[1]] + + # get nodeless image + nodeless = node_image_cp.copy() + nodeless = np.where( + (node_image == 1) | (node_image == 2), 1, 0 + ) # if termini, need this in the labeled branches too + nodeless[labeled_nodes == centre_idx] = 1 # return centre node + labeled_nodeless = label(nodeless) + + # apply to return image + for i in range(1, labeled_nodeless.max() + 1): + if (node_image_cp[labeled_nodeless == i] == 3).any(): + node_image_cp[labeled_nodeless != i] = 0 + break + + # remove small area around other nodes + labeled_nodes[labeled_nodes == centre_idx] = 0 + non_central_node_coords = np.argwhere(labeled_nodes != 0) + for coord in non_central_node_coords: + for j, coord_val in enumerate(coord): + if coord_val - 1 < 0: + coord[j] = 1 + if coord_val + 2 > node_image_cp.shape[j]: + coord[j] = node_image_cp.shape[j] - 2 + node_image_cp[coord[0] - 1 : coord[0] + 2, coord[1] - 1 : coord[1] + 2] = 0 + + return node_image_cp + + @staticmethod + def average_uniques(arr1: npt.NDArray, arr2: npt.NDArray) -> tuple: + """ + Obtain the unique values of both arrays, and the average of common values. + + Parameters + ---------- + arr1 : npt.NDArray + An array. + arr2 : npt.NDArray + An array. + + Returns + ------- + tuple + The unique values of both arrays, and the averaged common values. + """ + arr1_uniq, index = np.unique(arr1, return_index=True) + arr2_new = np.zeros_like(arr1_uniq).astype(np.float64) + for i, val in enumerate(arr1[index]): + mean = arr2[arr1 == val].mean() + arr2_new[i] += mean + + return arr1[index], arr2_new + + @staticmethod + def average_crossing_confs(node_dict) -> None | float: + """ + Return the average crossing confidence of all crossings in the molecule. + + Parameters + ---------- + node_dict : dict + A dictionary containing node statistics and information. + + Returns + ------- + Union[None, float] + The value of minimum confidence or none if not possible. + """ + sum_conf = 0 + valid_confs = 0 + for _, (_, values) in enumerate(node_dict.items()): + confidence = values["confidence"] + if confidence is not None: + sum_conf += confidence + valid_confs += 1 + try: + return sum_conf / valid_confs + except ZeroDivisionError: + return None + + @staticmethod + def minimum_crossing_confs(node_dict: dict) -> None | float: + """ + Return the minimum crossing confidence of all crossings in the molecule. + + Parameters + ---------- + node_dict : dict + A dictionary containing node statistics and information. + + Returns + ------- + Union[None, float] + The value of minimum confidence or none if not possible. + """ + confidences = [] + valid_confs = 0 + for _, (_, values) in enumerate(node_dict.items()): + confidence = values["confidence"] + if confidence is not None: + confidences.append(confidence) + valid_confs += 1 + try: + return min(confidences) + except ValueError: + return None + + def compile_metrics(self) -> None: + """Add the number of crossings, average and minimum crossing confidence to the metrics dictionary.""" + self.metrics["num_crossings"] = (self.node_centre_mask == 3).sum() + self.metrics["avg_crossing_confidence"] = nodeStats.average_crossing_confs(self.node_dicts) + self.metrics["min_crossing_confidence"] = nodeStats.minimum_crossing_confs(self.node_dicts) + + +def nodestats_image( + image: npt.NDArray, + disordered_tracing_direction_data: dict, + filename: str, + pixel_to_nm_scaling: float, + node_joining_length: float, + node_extend_dist: float, + branch_pairing_length: float, + pair_odd_branches: float, + pad_width: int, +) -> tuple: + """ + Initialise the nodeStats class. + + Parameters + ---------- + image : npt.NDArray + The array of pixels. + disordered_tracing_direction_data : dict + The images and bbox coordinates of the pruned skeletons. + filename : str + The name of the file being processed. For logging purposes. + pixel_to_nm_scaling : float + The pixel to nm scaling factor. + node_joining_length : float + The length over which to join skeletal intersections to be counted as one crossing. + node_joining_length : float + The distance over which to join nearby odd-branched nodes. + node_extend_dist : float + The distance under which to join odd-branched node regions. + branch_pairing_length : float + The length from the crossing point to pair and trace, obtaining FWHM's. + pair_odd_branches : bool + Whether to try and pair odd-branched nodes. + pad_width : int + The number of edge pixels to pad the image by. + + Returns + ------- + tuple[dict, pd.DataFrame, dict, dict] + The nodestats statistics for each crossing, crossing statistics to be added to the grain statistics, + an image dictionary of nodestats steps for the entire image, and single grain images. + """ + n_grains = len(disordered_tracing_direction_data) + img_base = np.zeros_like(image) + nodestats_data = {} + + # want to get each cropped image, use some anchor coords to match them onto the image, + # and compile all the grain images onto a single image + all_images = { + "convolved_skeletons": img_base.copy(), + "node_centres": img_base.copy(), + "connected_nodes": img_base.copy(), + } + nodestats_branch_images = {} + grainstats_additions = {} + + LOGGER.info(f"[{filename}] : Calculating NodeStats statistics for {n_grains} grains...") + + for n_grain, disordered_tracing_grain_data in disordered_tracing_direction_data.items(): + nodestats = None # reset the nodestats variable + try: + nodestats = nodeStats( + image=disordered_tracing_grain_data["original_image"], + mask=disordered_tracing_grain_data["original_grain"], + smoothed_mask=disordered_tracing_grain_data["smoothed_grain"], + skeleton=disordered_tracing_grain_data["pruned_skeleton"], + pixel_to_nm_scaling=pixel_to_nm_scaling, + filename=filename, + n_grain=n_grain, + node_joining_length=node_joining_length, + node_extend_dist=node_extend_dist, + branch_pairing_length=branch_pairing_length, + pair_odd_branches=pair_odd_branches, + ) + nodestats_dict, node_image_dict = nodestats.get_node_stats() + LOGGER.debug(f"[{filename}] : Nodestats processed {n_grain} of {n_grains}") + + # compile images + nodestats_images = { + "convolved_skeletons": nodestats.conv_skelly, + "node_centres": nodestats.node_centre_mask, + "connected_nodes": nodestats.connected_nodes, + } + nodestats_branch_images[n_grain] = node_image_dict + + # compile metrics + grainstats_additions[n_grain] = { + "image": filename, + "grain_number": int(n_grain.split("_")[-1]), + } + grainstats_additions[n_grain].update(nodestats.metrics) + if nodestats_dict: # if the grain's nodestats dict is not empty + nodestats_data[n_grain] = nodestats_dict + + # remap the cropped images back onto the original + for image_name, full_image in all_images.items(): + crop = nodestats_images[image_name] + bbox = disordered_tracing_grain_data["bbox"] + full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width] + + except Exception as e: # pylint: disable=broad-exception-caught + LOGGER.error( + f"[{filename}] : Nodestats for {n_grain} failed. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + nodestats_data[n_grain] = {} + + # turn the grainstats additions into a dataframe, # might need to do something for when everything is empty + grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index") + + return nodestats_data, grainstats_additions_df, all_images, nodestats_branch_images diff --git a/topostats/tracing/ordered_tracing.py b/topostats/tracing/ordered_tracing.py new file mode 100644 index 0000000000..8bca508c8a --- /dev/null +++ b/topostats/tracing/ordered_tracing.py @@ -0,0 +1,975 @@ +"""Order single pixel skeletons with or without NodeStats Statistics.""" + +from __future__ import annotations + +import importlib.metadata +import logging +from itertools import combinations + +import numpy as np +import numpy.typing as npt +import pandas as pd +from skimage.morphology import binary_dilation, label +from topoly import jones, translate_code + +from topostats.logs.logs import LOGGER_NAME +from topostats.tracing.tracingfuncs import coord_dist, genTracingFuncs, order_branch, reorderTrace +from topostats.utils import convolve_skeleton, coords_2_img + +LOGGER = logging.getLogger(LOGGER_NAME) + + +class OrderedTraceNodestats: # pylint: disable=too-many-instance-attributes + """ + Order single pixel thick skeleton coordinates via NodeStats results. + + Parameters + ---------- + image : npt.NDArray + A cropped image array. + nodestats_dict : dict + The nodestats results for a specific grain. + skeleton : npt.NDArray + The pruned skeleton mask array. + filename : str + The image filename (for logging purposes). + """ + + def __init__( + self, + image: npt.NDArray, + nodestats_dict: dict, + skeleton: npt.NDArray, + filename: str, + ) -> None: + """ + Initialise the OrderedTraceNodestats class. + + Parameters + ---------- + image : npt.NDArray + A cropped image array. + nodestats_dict : dict + The nodestats results for a specific grain. + skeleton : npt.NDArray + The pruned skeleton mask array. + filename : str + The image filename (for logging purposes). + """ + self.image = image + self.nodestats_dict = nodestats_dict + self.filename = filename + self.skeleton = skeleton + + self.grain_tracing_stats = { + "num_mols": 0, + "circular": None, + } + self.mol_tracing_stats = {"circular": None, "topology": None, "topology_flip": None, "processing": "nodestats"} + + self.images = { + "over_under": np.zeros_like(image), + "all_molecules": np.zeros_like(image), + "ordered_traces": np.zeros_like(image), + "trace_segments": np.zeros_like(image), + } + + self.profiles = {} + + self.img_idx_to_node = {} + + self.ordered_coordinates = [] + + # pylint: disable=too-many-locals + # pylint: disable=too-many-branches + def compile_trace(self, reverse_min_conf_crossing: bool = False) -> tuple[list, npt.NDArray]: # noqa: C901 + """ + Obtain the trace and diagnostic crossing trace and molecule trace images. + + This function uses the branches and full-width half-maximums (FWHMs) identified in the node_stats dictionary + to create a continuous trace of the molecule. + + Parameters + ---------- + reverse_min_conf_crossing : bool + Whether to reverse the stacking order of the lowest confidence crossing in the trace. + + Returns + ------- + tuple[list, npt.NDArray] + A list of each complete path's ordered coordinates, and labeled crossing image array. + """ + # iterate through the dict to get branch coords, heights and fwhms + node_coords = [ + [stats["node_coords"] for branch_stats in stats["branch_stats"].values() if branch_stats["fwhm"]["fwhm"]] + for stats in self.nodestats_dict.values() + ] + node_coords = [lst for lst in node_coords if lst] + + crossing_coords = [ + [ + branch_stats["ordered_coords"] + for branch_stats in stats["branch_stats"].values() + if branch_stats["fwhm"]["fwhm"] + ] + for stats in self.nodestats_dict.values() + ] + crossing_coords = [lst for lst in crossing_coords if lst] + + fwhms = [ + [ + branch_stats["fwhm"]["fwhm"] + for branch_stats in stats["branch_stats"].values() + if branch_stats["fwhm"]["fwhm"] + ] + for stats in self.nodestats_dict.values() + ] + fwhms = [lst for lst in fwhms if lst] + + confidences = [stats["confidence"] for stats in self.nodestats_dict.values()] + + # obtain the index of the underlying branch + try: + low_conf_idx = np.nanargmin(np.array(confidences, dtype=float)) + except ValueError: # when no crossings or only 3-branch crossings + low_conf_idx = None + + # Get the image minus the crossing regions + nodes = np.zeros_like(self.skeleton) + for node_no in node_coords: # this stops unpaired branches from interacting with the pairs + nodes[node_no[0][:, 0], node_no[0][:, 1]] = 1 + minus = np.where(binary_dilation(binary_dilation(nodes)) == self.skeleton, 0, self.skeleton) + # remove crossings from skeleton + for crossings in crossing_coords: + for crossing in crossings: + minus[crossing[:, 0], crossing[:, 1]] = 0 + minus = label(minus) + + # setup z array + z = [] + # order minus segments + ordered = [] + for non_cross_segment_idx in range(1, minus.max() + 1): + arr = np.where(minus, minus == non_cross_segment_idx, 0) + ordered.append(order_branch(arr, [0, 0])) # orientated later + z.append(0) + self.img_idx_to_node[non_cross_segment_idx] = {} + + # add crossing coords to ordered segment list + uneven_count = non_cross_segment_idx + 1 + for node_num, node_crossing_coords in enumerate(crossing_coords): + z_idx = np.argsort(fwhms[node_num]) + z_idx[z_idx == 0] = -1 + if reverse_min_conf_crossing and low_conf_idx == node_num: + z_idx = z_idx[::-1] + fwhms[node_num] = fwhms[node_num][::-1] + for node_cross_idx, single_cross in enumerate(node_crossing_coords): + # check current single cross has no duplicate coords with ordered, except crossing points + uncommon_single_cross = np.array(single_cross).copy() + for coords in ordered: + uncommon_single_cross = self.remove_common_values( + uncommon_single_cross, np.array(coords), retain=node_coords[node_num][node_cross_idx] + ) + if len(uncommon_single_cross) > 0: + ordered.append(uncommon_single_cross) + z.append(z_idx[node_cross_idx]) + self.img_idx_to_node[uneven_count + node_cross_idx] = { + "node_idx": node_num, + "coords": single_cross, + "z_idx": z_idx[node_cross_idx], + } + uneven_count += len(node_crossing_coords) + + # get an image of each ordered segment + cross_add = np.zeros_like(self.image) + for i, coords in enumerate(ordered): + single_cross_img = coords_2_img(np.array(coords), cross_add) + cross_add[single_cross_img != 0] = i + 1 + + coord_trace, simple_trace = self.trace(ordered, cross_add, z, n=100) + + # obtain topology from the simple trace + topology = self.get_topology(simple_trace) + if reverse_min_conf_crossing and low_conf_idx is None: # when there's nothing to reverse + topology = [None for _ in enumerate(topology)] + + return coord_trace, topology, cross_add, crossing_coords, fwhms + + def compile_images(self, coord_trace: list, cross_add: npt.NDArray, crossing_coords: list, fwhms: list) -> None: + """ + Obtain all the diagnostic images based on the produced traces, and values. + + Crossing coords and fwhms are used as arguments as reversing the minimum confidence can modify these. + + Parameters + ---------- + coord_trace : list + List of N molecule objects containing 2xM arrays of X, Y coordinates. + cross_add : npt.NDArray + A labelled array with segments of the ordered trace. + crossing_coords : list + A list of I nodes objects containing 2xJ arrays of X, Y coordinates for each crossing branch. + fwhms : list + A list of I nodes objects containing FWHM values for each crossing branch. + """ + # visual over under img + self.images["trace_segments"] = cross_add + try: + self.images["over_under"] = self.get_over_under_img(coord_trace, fwhms, crossing_coords) + self.images["all_molecules"] = self.get_mols_img(coord_trace, fwhms, crossing_coords) + except IndexError: + pass + self.images["ordered_traces"] = ordered_trace_mask(coord_trace, self.image.shape) + + @staticmethod + def remove_common_values( + ordered_array: npt.NDArray, common_value_check_array: npt.NDArray, retain: list = () + ) -> np.array: + """ + Remove common values in common_value_check_array from ordered_array while retaining specified coordinates. + + Parameters + ---------- + ordered_array : npt.NDArray + Coordinate array to remove / retain values from. Will retain its order. + common_value_check_array : npt.NDArray + Coordinate array containing any common values to be removed from ordered_array. + retain : list, optional + List of possible coordinates to keep, by default (). + + Returns + ------- + np.array + Unique ordered_array values and retained coordinates. Retains the order of ordered_array. + """ + # Convert the arrays to sets for faster common value lookup + set_arr2 = {tuple(row) for row in common_value_check_array} + set_retain = {tuple(row) for row in retain} + # Create a new filtered list while maintaining the order of the first array + filtered_arr1 = [] + for coord in ordered_array: + tup_coord = tuple(coord) + if tup_coord not in set_arr2 or tup_coord in set_retain: + filtered_arr1.append(coord) + + return np.asarray(filtered_arr1) + + def get_topology(self, nxyz: npt.NDArray) -> list: + """ + Obtain a topological classification from ordered XYZ coordinates. + + Parameters + ---------- + nxyz : npt.NDArray + A 4xN array of the order index (n), x, y and pseudo z coordinates. + + Returns + ------- + list + Topology(s) of the provided traced coordinates. + """ + # Topoly doesn't work when 2 mols don't actually cross + topology = [] + lin_idxs = [] + nxyz_cp = nxyz.copy() + # remove linear mols as are just reidmiester moves + for i, mol_trace in enumerate(nxyz): + if mol_trace[-1][0] != 0: # mol is not circular + topology.append("linear") + lin_idxs.append(i) + # remove from list in reverse order so no conflicts + lin_idxs.sort(reverse=True) + for i in lin_idxs: + del nxyz_cp[i] + # classify topology for non-reidmeister moves + if len(nxyz_cp) != 0: + try: + pd_code = translate_code( + nxyz_cp, output_type="pdcode" + ) # pd code helps prevents freezing and spawning multiple processes + LOGGER.debug(f"{self.filename} : PD Code is: {pd_code}") + top_class = jones(pd_code) + except (IndexError, KeyError): + LOGGER.debug(f"{self.filename} : PD Code could not be obtained from trace coordinates.") + top_class = "N/A" + + # don't separate catenanes / overlaps - used for distribution comparison + for _ in range(len(nxyz_cp)): + topology.append(top_class) + + return topology + + def trace(self, ordered_segment_coords: list, both_img: npt.NDArray, zs: npt.NDArray, n: int = 100) -> list: + # pylint: disable=too-many-locals + """ + Obtain an ordered trace of each complete path. + + Here a 'complete path' means following and removing connected segments until + there are no more segments to follow. + + Parameters + ---------- + ordered_segment_coords : list + Ordered coordinates of each labeled segment in 'both_img'. + both_img : npt.NDArray + A skeletonised labeled image of each path segment. + zs : npt.NDArray + Array of pseudo heights of the traces. -1 is lowest, 0 is skeleton, then ascending integers for + levels of overs. + n : int + The number of points to use for the simplified traces. + + Returns + ------- + list + Ordered trace coordinates of each complete path. + """ + mol_coords = [] + simple_coords = [] + remaining = both_img.copy().astype(np.int32) + endpoints = np.unique(remaining[convolve_skeleton(remaining.astype(bool)) == 2]) # unique in case of whole mol + prev_segment = None + n_points_p_seg = (n - 2 * remaining.max()) // remaining.max() + + while remaining.max() != 0: + # select endpoint to start if there is one + endpoints = [i for i in endpoints if i in np.unique(remaining)] # remove if removed from remaining + if endpoints: + coord_idx = endpoints.pop(0) - 1 + else: # if no endpoints, just a loop + coord_idx = np.unique(remaining)[1] - 1 # avoid choosing 0 + coord_trace = np.empty((0, 3)).astype(np.int32) + simple_trace = np.empty((0, 3)).astype(np.int32) + + while coord_idx > -1: # either cycled through all or hits terminus -> all will be just background + remaining[remaining == coord_idx + 1] = 0 + trace_segment = self.get_trace_segment(remaining, ordered_segment_coords, coord_idx) + full_trace_segment = trace_segment.copy() + if len(coord_trace) > 0: # can only order when there's a reference point / segment + trace_segment = self.remove_common_values( + trace_segment, prev_segment + ) # remove overlaps in trace (may be more efficient to do it on the previous segment) + trace_segment, flipped = self.order_from_end(coord_trace[-1, :2], trace_segment) + full_trace_segment = full_trace_segment[::-1] if flipped else full_trace_segment + # get vector if crossing + if self.img_idx_to_node[coord_idx + 1]: + segment_vector = full_trace_segment[-1] - full_trace_segment.mean( + axis=0 + ) # from start to mean coord + segment_vector /= np.sqrt(segment_vector @ segment_vector) # normalise + self.img_idx_to_node[coord_idx + 1]["vector"] = segment_vector + prev_segment = trace_segment.copy() # update previous segment + trace_segment_z = np.column_stack( + (trace_segment, np.ones((len(trace_segment), 1)) * zs[coord_idx]) + ).astype( + np.int32 + ) # add z's + coord_trace = np.append(coord_trace, trace_segment_z.astype(np.int32), axis=0) + + # obtain a reduced coord version of the traces for Topoly + simple_trace_temp = self.reduce_rows( + trace_segment.astype(np.int32), n=n_points_p_seg + ) # reducing rows here ensures no segments are skipped + simple_trace_temp_z = np.column_stack( + (simple_trace_temp, np.ones((len(simple_trace_temp), 1)) * zs[coord_idx]) + ).astype( + np.int32 + ) # add z's + simple_trace = np.append(simple_trace, simple_trace_temp_z, axis=0) + + x, y = coord_trace[-1, :2] + coord_idx = remaining[x - 1 : x + 2, y - 1 : y + 2].max() - 1 # should only be one value + mol_coords.append(coord_trace) + + # Issue in 0_5 where wrong nxyz[0] selected, and == nxyz[-1] so always duplicated + nxyz = np.column_stack((np.arange(0, len(simple_trace)), simple_trace)) + end_to_end_dist_squared = (nxyz[0][1] - nxyz[-1][1]) ** 2 + (nxyz[0][2] - nxyz[-1][2]) ** 2 + if len(nxyz) > 2 and end_to_end_dist_squared <= 2: # pylint: disable=chained-comparison + # single coord traces mean nxyz[0]==[1] so cause issues when duplicating for topoly + nxyz = np.append(nxyz, nxyz[0][np.newaxis, :], axis=0) + simple_coords.append(nxyz) + + # convert into lists for Topoly + simple_coords = [[list(row) for row in mol] for mol in simple_coords] + + return mol_coords, simple_coords + + @staticmethod + def reduce_rows(array: npt.NDArray, n: int = 300) -> npt.NDArray: + """ + Reduce the number of rows in the array to `n`, keeping the first and last indexes. + + Parameters + ---------- + array : npt.NDArray + An array to reduce the number of rows in. + n : int, optional + The number of indexes in the array to keep, by default 300. + + Returns + ------- + npt.NDArray + The `array` reduced to only `n` + 2 elements, or if shorter, the same array. + """ + # removes reduces the number of rows (but keeping the first and last ones) + if array.shape[0] < n or array.shape[0] < 4: + return array + + idxs_to_keep = np.unique(np.linspace(0, array[1:-1].shape[0] - 1, n).astype(np.int32)) + new_array = array[1:-1][idxs_to_keep] + new_array = np.append(array[0][np.newaxis, :], new_array, axis=0) + return np.append(new_array, array[-1][np.newaxis, :], axis=0) + + @staticmethod + def get_trace_segment(remaining_img: npt.NDArray, ordered_segment_coords: list, coord_idx: int) -> npt.NDArray: + """ + Return an ordered segment at the end of the current one. + + Check the branch of given index to see if it contains an endpoint. If it does, + the segment coordinates will be returned starting from the endpoint. + + Parameters + ---------- + remaining_img : npt.NDArray + A 2D array representing an image composed of connected segments of different integers. + ordered_segment_coords : list + A list of 2xN coordinates representing each segment. + coord_idx : int + The index of the current segment to look at. There is an index mismatch between the + remaining_img and ordered_segment_coords by -1. + + Returns + ------- + npt.NDArray + 2xN array of coordinates representing a skeletonised ordered trace segment. + """ + start_xy = ordered_segment_coords[coord_idx][0] + start_max = remaining_img[start_xy[0] - 1 : start_xy[0] + 2, start_xy[1] - 1 : start_xy[1] + 2].max() - 1 + if start_max == -1: + return ordered_segment_coords[coord_idx] # start is endpoint + return ordered_segment_coords[coord_idx][::-1] # end is endpoint + + @staticmethod + def order_from_end(last_segment_coord: npt.NDArray, current_segment: npt.NDArray) -> npt.NDArray: + """ + Order the current segment to follow from the end of the previous one. + + Parameters + ---------- + last_segment_coord : npt.NDArray + X and Y coordinates of the end of the last segment. + current_segment : npt.NDArray + A 2xN array of coordinates of the current segment to order. + + Returns + ------- + npt.NDArray + The current segment orientated to follow on from the last. + bool + Whether the order has been flipped. + """ + start_xy = current_segment[0] + dist = np.sum((start_xy - last_segment_coord) ** 2) ** 0.5 + if dist <= np.sqrt(2): + return current_segment, False + return current_segment[::-1], True + + def get_over_under_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> npt.NDArray: + """ + Obtain a labelled image according to the main trace (=1), under (=2), over (=3). + + Parameters + ---------- + coord_trace : list + Ordered coordinate trace of each molecule. + fwhms : list + List of full-width half-maximums (FWHMs) for each crossing in the trace. + crossing_coords : list + The crossing coordinates of each branch crossing. + + Returns + ------- + npt.NDArray + 2D crossing order labelled image. + """ + # put down traces + img = np.zeros_like(self.skeleton) + for coords in coord_trace: + temp_img = np.zeros_like(img) + temp_img[coords[:, 0], coords[:, 1]] = 1 + # temp_img = binary_dilation(temp_img) + img[temp_img != 0] = 1 + + # place over/under strands onto image array + lower_idxs, upper_idxs = self.get_trace_idxs(fwhms) + for i, type_idxs in enumerate([lower_idxs, upper_idxs]): + for crossing, type_idx in zip(crossing_coords, type_idxs): + temp_img = np.zeros_like(img) + cross_coords = crossing[type_idx] + temp_img[cross_coords[:, 0], cross_coords[:, 1]] = 1 + # temp_img = binary_dilation(temp_img) + img[temp_img != 0] = i + 2 + + return img + + # pylint: disable=too-many-locals + def get_mols_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> npt.NDArray: + # pylint: disable=too-many-locals + """ + Obtain a labelled image according to each molecule traced N=3 -> n=1,2,3. + + Parameters + ---------- + coord_trace : list + Ordered coordinate trace of each molecule. + fwhms : list + List of full-width half-maximums (FWHMs) for each crossing in the trace. + crossing_coords : list + The crossing coordinates of each branch crossing. + + Returns + ------- + npt.NDArray + 2D individual 'molecule' labelled image. + """ + img = np.zeros_like(self.skeleton) + for mol_no, coords in enumerate(coord_trace): + temp_img = np.zeros_like(img) + temp_img[coords[:, 0], coords[:, 1]] = 1 + img[temp_img != 0] = mol_no + 1 + lower_idxs, upper_idxs = self.get_trace_idxs(fwhms) + + # plot separate mols + for type_idxs in [lower_idxs, upper_idxs]: + for node_crossing_coords, type_idx in zip(crossing_coords, type_idxs): + temp_img = np.zeros_like(img) + cross_coords = node_crossing_coords[type_idx] + # decide which val + matching_coords = np.array([]) + for trace in coord_trace: + c = 0 + # get overlaps between segment coords and crossing under coords + for cross_coord in cross_coords: + c += ((trace[:, :2] == cross_coord).sum(axis=1) == 2).sum() + matching_coords = np.append(matching_coords, c) + val = matching_coords.argmax() + 1 + temp_img[cross_coords[:, 0], cross_coords[:, 1]] = 1 + img[temp_img != 0] = val + + return img + + @staticmethod + def get_trace_idxs(fwhms: list) -> tuple[list, list]: + """ + Split under-passing and over-passing indices. + + Parameters + ---------- + fwhms : list + List of arrays of full-width half-maximum (FWHM) values for each crossing point. + + Returns + ------- + tuple[list, list] + All the under, and over indices of the for each node FWHMs in the provided FWHM list. + """ + # node fwhms can be a list of different lengths so cannot use np arrays + under_idxs = [] + over_idxs = [] + for node_fwhms in fwhms: + order = np.argsort(node_fwhms) + under_idxs.append(order[0]) + over_idxs.append(order[-1]) + return under_idxs, over_idxs + + def check_node_errorless(self) -> bool: + """ + Check if an error has occurred while processing the node dictionary. + + Returns + ------- + bool + Whether the error is present. + """ + for vals in self.nodestats_dict.values(): + if vals["error"]: + return False + return True + + def identify_writhes(self) -> str | dict: + """ + Identify the writhe topology at each crossing in the image. + + Returns + ------- + str | dict + A string of the whole grain writhe sign, and a dictionary linking each node to it's sign. + """ + # compile all vectors for each node and their z_idx + # - want for each node, ordered vectors according to z_idx + writhe_string = "" + node_to_writhe = {} + idx2node_df = pd.DataFrame.from_dict(self.img_idx_to_node, orient="index") + if idx2node_df.empty: # for when no crossovers but still crossings (i.e. unpaired 3-way) + return "", {} + + for node_num, node_df in idx2node_df.groupby("node_idx"): + vector_series = node_df.sort_values(by=["z_idx"], ascending=False)["vector"] + vectors = list(vector_series) + # get pairs + vector_combinations = list(combinations(vectors, 2)) + # calculate the writhe + temp_writhes = "" + for vector_pair in vector_combinations: # if > 2 crossing branches + temp_writhes += self.writhe_direction(vector_pair[0], vector_pair[1]) + if len(temp_writhes) > 1: + temp_writhes = f"({temp_writhes})" + node_to_writhe[node_num] = temp_writhes + writhe_string += temp_writhes + + return writhe_string, node_to_writhe + + @staticmethod + def writhe_direction(first_vector: npt.NDArray, second_vector: npt.NDArray) -> str: + """ + Use the cross product of crossing vectors to determine the writhe sign. + + Parameters + ---------- + first_vector : npt.NDArray + An x,y component vector of the overlying strand. + second_vector : npt.NDArray + An x,y component vector of the underlying strand. + + Returns + ------- + str + '+', '-' or '0' for positive, negative, or no writhe. + """ + cross = np.cross(first_vector, second_vector) + if cross < 0: + return "-" + if cross > 0: + return "+" + return "0" + + def run_nodestats_tracing(self) -> tuple[list, dict, dict]: + """ + Run the nodestats tracing pipeline. + + Returns + ------- + tuple[list, dict, dict] + A list of each molecules ordered trace coordinates, the ordered_tracing stats, and the images. + """ + ordered_traces, topology, cross_add, crossing_coords, fwhms = self.compile_trace( + reverse_min_conf_crossing=False + ) + self.compile_images(ordered_traces, cross_add, crossing_coords, fwhms) + self.grain_tracing_stats["num_mols"] = len(ordered_traces) + + writhe_string, node_to_writhes = self.identify_writhes() + self.grain_tracing_stats["writhe_string"] = writhe_string + for node_num, node_writhes in node_to_writhes.items(): # should self update as the dicts are linked + self.nodestats_dict[f"node_{node_num+1}"]["writhe"] = node_writhes + + topology_flip = self.compile_trace(reverse_min_conf_crossing=True)[1] + + ordered_trace_data = {} + grain_mol_tracing_stats = {} + for i, mol_trace in enumerate(ordered_traces): + if len(mol_trace) > 3: # if > 4 coords to trace + np.save(f"trace_xyz_{i}", mol_trace) + self.mol_tracing_stats["circular"] = linear_or_circular(mol_trace[:, :2]) + self.mol_tracing_stats["topology"] = topology[i] + self.mol_tracing_stats["topology_flip"] = topology_flip[i] + ordered_trace_data[f"mol_{i}"] = { + "ordered_coords": mol_trace[:, :2], + "heights": self.image[mol_trace[:, 0], mol_trace[:, 1]], + "distances": coord_dist(mol_trace[:, :2]), + "mol_stats": self.mol_tracing_stats, + } + grain_mol_tracing_stats[f"{i}"] = self.mol_tracing_stats + + return ordered_trace_data, self.grain_tracing_stats, grain_mol_tracing_stats, self.images + + +class OrderedTraceTopostats: + """ + Order single pixel thick skeleton coordinates via TopoStats. + + Parameters + ---------- + image : npt.NDArray + A cropped image array. + skeleton : npt.NDArray + The pruned skeleton mask array. + """ + + def __init__( + self, + image, + skeleton, + ) -> None: + """ + Initialise the OrderedTraceTopostats class. + + Parameters + ---------- + image : npt.NDArray + A cropped image array. + skeleton : npt.NDArray + The pruned skeleton mask array. + """ + self.image = image + self.skeleton = skeleton + self.grain_tracing_stats = { + "num_mols": 1, + "circular": None, + } + self.mol_tracing_stats = {"circular": None, "topology": None, "topology_flip": None, "processing": "topostats"} + + self.images = { + "ordered_traces": np.zeros_like(image), + "all_molecules": skeleton.copy(), + "over_under": skeleton.copy(), + "trace_segments": skeleton.copy(), + } + + @staticmethod + def get_ordered_traces(disordered_trace_coords: npt.NDArray, mol_is_circular: bool) -> list: + """ + Obtain ordered traces from disordered traces. + + Parameters + ---------- + disordered_trace_coords : npt.NDArray + A Nx2 array of coordinates to order. + mol_is_circular : bool + A flag of whether the molecule has at least one coordinate with only one neighbour. + + Returns + ------- + list + A list of each molecules ordered trace coordinates. + """ + if mol_is_circular: + ordered_trace, trace_completed = reorderTrace.circularTrace(disordered_trace_coords) + + if not trace_completed: + mol_is_circular = False + try: + ordered_trace = reorderTrace.linearTrace(ordered_trace) + except UnboundLocalError: + pass + + elif not mol_is_circular: + ordered_trace = reorderTrace.linearTrace(disordered_trace_coords) + + return [ordered_trace] + + def run_topostats_tracing(self) -> tuple[list, dict, dict]: + """ + Run the topostats tracing pipeline. + + Returns + ------- + tuple[list, dict, dict] + A list of each molecules ordered trace coordinates, the ordered_traicing stats, and the images. + """ + disordered_trace_coords = np.argwhere(self.skeleton == 1) + + self.mol_tracing_stats["circular"] = linear_or_circular(disordered_trace_coords) + self.mol_tracing_stats["topology"] = "0_1" if self.mol_tracing_stats["circular"] else "linear" + + ordered_trace = self.get_ordered_traces(disordered_trace_coords, self.mol_tracing_stats["circular"]) + + self.images["ordered_traces"] = ordered_trace_mask(ordered_trace, self.image.shape) + + ordered_trace_data = {} + for i, mol_trace in enumerate(ordered_trace): + ordered_trace_data[f"mol_{i}"] = { + "ordered_coords": mol_trace, + "heights": self.image[ordered_trace[0][:, 0], ordered_trace[0][:, 1]], + "distances": coord_dist(ordered_trace[0]), + "mol_stats": self.mol_tracing_stats, + } + + return ordered_trace_data, self.grain_tracing_stats, {"0": self.mol_tracing_stats}, self.images + + +def linear_or_circular(traces) -> bool: + """ + Determine whether the molecule is circular or linear via >1 points in the local start area. + + This function is sensitive to branches from the skeleton because it is based on whether any given point has zero + neighbours or not so the traces should be pruned. + + Parameters + ---------- + traces : npt.NDArray + The array of coordinates to be assessed. + + Returns + ------- + bool + Whether a molecule is linear or not (True if linear, False otherwise). + """ + points_with_one_neighbour = 0 + fitted_trace_list = traces.tolist() + + # For loop determines how many neighbours a point has - if only one it is an end + for x, y in fitted_trace_list: + if genTracingFuncs.count_and_get_neighbours(x, y, fitted_trace_list)[0] == 1: + points_with_one_neighbour += 1 + else: + pass + + if points_with_one_neighbour == 0: + return True + return False + + +def ordered_trace_mask(ordered_coordinates: npt.NDArray, shape: tuple) -> npt.NDArray: + """ + Obtain a mask of the trace coordinates with each trace pixel. + + Parameters + ---------- + ordered_coordinates : npt.NDArray + Ordered array of coordinates. + + shape : tuple + The shape of the array bounding the coordinates. + + Returns + ------- + npt.NDArray + NxM image with each pixel in the ordered trace labeled in ascending order. + """ + ordered_mask = np.zeros(shape) + if isinstance(ordered_coordinates, list): + for mol_coords in ordered_coordinates: + ordered_mask[mol_coords[:, 0], mol_coords[:, 1]] = np.arange(len(mol_coords)) + 1 + + return ordered_mask + + +# pylint: disable=too-many-locals +def ordered_tracing_image( + image: npt.NDArray, + disordered_tracing_direction_data: dict, + nodestats_direction_data: dict, + filename: str, + ordering_method: str, + pad_width: int, +) -> tuple[dict, pd.DataFrame, pd.DataFrame, dict]: + # pylint: disable=too-many-locals + """ + Run ordered tracing for an entire image of >=1 grains. + + Parameters + ---------- + image : npt.NDArray + Whole FOV image. + disordered_tracing_direction_data : dict + Dictionary result from the disordered traces. Fields used are "original_image" and "pruned_skeleton". + nodestats_direction_data : dict + Dictionary result from the nodestats analysis. + filename : str + Image filename (for logging purposes). + ordering_method : str + The method to order the trace coordinates - "topostats" or "nodestats". + pad_width : int + Width to pad the images by. + + Returns + ------- + tuple[dict, pd.DataFrame, pd.DataFrame, dict] + Results containing the ordered_trace_data (coordinates), any grain-level metrics to be added to the grains + dataframe, a dataframe of molecule statistics and a dictionary of diagnostic images. + """ + topoly_version = importlib.metadata.version("topoly") + print(f"Topoly version: {topoly_version}") + ordered_trace_full_images = { + "ordered_traces": np.zeros_like(image), + "all_molecules": np.zeros_like(image), + "over_under": np.zeros_like(image), + "trace_segments": np.zeros_like(image), + } + grainstats_additions = {} + molstats = {} + all_traces_data = {} + + LOGGER.info( + f"[{filename}] : Calculating Ordered Traces and statistics for " + + f"{len(disordered_tracing_direction_data)} grains..." + ) + + # iterate through disordered_tracing_dict + for grain_no, disordered_trace_data in disordered_tracing_direction_data.items(): + try: + # check if want to do nodestats tracing or not + if grain_no in list(nodestats_direction_data["stats"].keys()) and ordering_method == "nodestats": + LOGGER.debug(f"[{filename}] : Grain {grain_no} present in NodeStats. Tracing via Nodestats.") + nodestats_tracing = OrderedTraceNodestats( + image=nodestats_direction_data["images"][grain_no]["grain"]["grain_image"], + filename=filename, + nodestats_dict=nodestats_direction_data["stats"][grain_no], + skeleton=nodestats_direction_data["images"][grain_no]["grain"]["grain_skeleton"], + ) + if nodestats_tracing.check_node_errorless(): + ordered_traces_data, tracing_stats, grain_molstats, images = ( + nodestats_tracing.run_nodestats_tracing() + ) + LOGGER.debug(f"[{filename}] : Grain {grain_no} ordered via NodeStats.") + else: + LOGGER.debug(f"Nodestats dict has an error ({nodestats_direction_data['stats'][grain_no]['error']}") + # if not doing nodestats ordering, do original TS ordering + else: + LOGGER.debug(f"[{filename}] : {grain_no} not in NodeStats. Tracing normally.") + topostats_tracing = OrderedTraceTopostats( + image=disordered_trace_data["original_image"], + skeleton=disordered_trace_data["pruned_skeleton"], + ) + ordered_traces_data, tracing_stats, grain_molstats, images = topostats_tracing.run_topostats_tracing() + LOGGER.debug(f"[{filename}] : Grain {grain_no} ordered via TopoStats.") + + # compile traces + all_traces_data[grain_no] = ordered_traces_data + for mol_no, _ in ordered_traces_data.items(): + all_traces_data[grain_no][mol_no].update({"bbox": disordered_trace_data["bbox"]}) + # compile metrics + grainstats_additions[grain_no] = { + "image": filename, + "grain_number": int(grain_no.split("_")[-1]), + } + tracing_stats.pop("circular") + grainstats_additions[grain_no].update(tracing_stats) + # compile molecule metrics + for mol_no, molstat_values in grain_molstats.items(): + molstats[f"{grain_no.split('_')[-1]}_{mol_no}"] = { + "image": filename, + "grain_number": int(grain_no.split("_")[-1]), + "molecule_number": int(mol_no.split("_")[-1]), # pylint: disable=use-maxsplit-arg + } + molstats[f"{grain_no.split('_')[-1]}_{mol_no}"].update(molstat_values) + + # remap the cropped images back onto the original + for image_name, full_image in ordered_trace_full_images.items(): + crop = images[image_name] + bbox = disordered_trace_data["bbox"] + full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width] + + except Exception as e: # pylint: disable=broad-exception-caught + LOGGER.error( + f"[{filename}] : Ordered tracing for {grain_no} failed. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + all_traces_data[grain_no] = {} + + grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index") + molstats_df = pd.DataFrame.from_dict(molstats, orient="index") + molstats_df.reset_index(drop=True, inplace=True) + + return all_traces_data, grainstats_additions_df, molstats_df, ordered_trace_full_images diff --git a/topostats/tracing/pruning.py b/topostats/tracing/pruning.py new file mode 100644 index 0000000000..b7d7685638 --- /dev/null +++ b/topostats/tracing/pruning.py @@ -0,0 +1,875 @@ +"""Prune branches from skeletons.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable + +import numpy as np +import numpy.typing as npt + +# from skimage.morphology import binary_dilation, label +from skimage import morphology + +from topostats.logs.logs import LOGGER_NAME +from topostats.tracing.skeletonize import getSkeleton +from topostats.tracing.tracingfuncs import coord_dist, genTracingFuncs, order_branch +from topostats.utils import convolve_skeleton + +LOGGER = logging.getLogger(LOGGER_NAME) + + +def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray: + """ + Pruning skeletons using different pruning methods. + + This is a thin wrapper to the methods provided within the pruning classes below. + + Parameters + ---------- + image : npt.NDArray + Original image as 2D numpy array. + skeleton : npt.NDArray + Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. + **kwargs + Pruning options passed to the respective method. + + Returns + ------- + npt.NDArray + An array of the skeleton with spurious branching artefacts removed. + """ + if image.shape != skeleton.shape: + raise AttributeError("Error image and skeleton are not the same size.") + return _prune_method(image, skeleton, pixel_to_nm_scaling, **kwargs) + + +def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> Callable: + """ + Determine which skeletonize method to use. + + Parameters + ---------- + image : npt.NDArray + Original image as 2D numpy array. + skeleton : npt.NDArray + Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. + **kwargs + Pruning options passed to the respective method. + + Returns + ------- + Callable + Returns the function appropriate for the required skeletonizing method. + + Raises + ------ + ValueError + Invalid method passed. + """ + method = kwargs.pop("method") + if method == "topostats": + return _prune_topostats(image, skeleton, pixel_to_nm_scaling, **kwargs) + # @maxgamill-sheffield I've read about a "Discrete Skeleton Evolultion" (DSE) method that might be useful + # @ns-rse (2024-06-04) : https://en.wikipedia.org/wiki/Discrete_skeleton_evolution + # https://link.springer.com/chapter/10.1007/978-3-540-74198-5_28 + # https://dl.acm.org/doi/10.5555/1780074.1780108 + # Python implementation : https://github.com/originlake/DSE-skeleton-pruning + raise ValueError(f"Invalid pruning method provided ({method}) please use one of 'topostats'.") + + +def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray: + """ + Prune using the original TopoStats method. + + This is a modified version of the pubhlished Zhang method. + + Parameters + ---------- + img : npt.NDArray + Image used to find skeleton, may be original heights or binary mask. + skeleton : npt.NDArray + Binary mask of the skeleton. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. + **kwargs + Pruning options passed to the topostatsPrune class. + + Returns + ------- + npt.NDArray + The skeleton with spurious branches removed. + """ + return topostatsPrune(img, skeleton, pixel_to_nm_scaling, **kwargs).prune_skeleton() + + +# class pruneSkeleton: pylint: disable=too-few-public-methods +# """ +# Class containing skeletonization pruning code from factory methods to functions dependent on the method. + +# Pruning is the act of removing spurious branches commonly found when implementing skeletonization algorithms. + +# Parameters +# ---------- +# image : npt.NDArray +# Original image from which the skeleton derives including heights. +# skeleton : npt.NDArray +# Single-pixel-thick skeleton pertaining to features of the image. +# """ + +# def __init__(self, image: npt.NDArray, skeleton: npt.NDArray) -> None: +# """ +# Initialise the class. + +# Parameters +# ---------- +# image : npt.NDArray +# Original image from which the skeleton derives including heights. +# skeleton : npt.NDArray +# Single-pixel-thick skeleton pertaining to features of the image. +# """ +# self.image = image +# self.skeleton = skeleton + +# def prune_skeleton( pylint: disable=dangerous-default-value +# self, +# prune_args: dict = {"pruning_method": "topostats"}, noqa: B006 +# ) -> npt.NDArray: +# """ +# Pruning skeletons. + +# This is a thin wrapper to the methods provided within the pruning classes below. + +# Parameters +# ---------- +# prune_args : dict +# Method to use, default is 'topostats'. + +# Returns +# ------- +# npt.NDArray +# An array of the skeleton with spurious branching artefacts removed. +# """ +# return self._prune_method(prune_args) + +# def _prune_method(self, prune_args: str = None) -> Callable: +# """ +# Determine which skeletonize method to use. + +# Parameters +# ---------- +# prune_args : str +# Method to use for skeletonizing, methods are 'topostats' other options are 'conv'. + +# Returns +# ------- +# Callable +# Returns the function appropriate for the required skeletonizing method. + +# Raises +# ------ +# ValueError +# Invalid method passed. +# """ +# method = prune_args.pop("pruning_method") +# if method == "topostats": +# return self._prune_topostats(self.image, self.skeleton, prune_args) +# I've read about a "Discrete Skeleton Evolultion" (DSE) method that might be useful +# @ns-rse (2024-06-04) : Citation or link? +# raise ValueError(method) + +# @staticmethod +# def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, prune_args: dict) -> npt.NDArray: +# """ +# Prune using the original TopoStats method. + +# This is a modified version of the pubhlished Zhang method. + +# Parameters +# ---------- +# img : npt.NDArray +# Image used to find skeleton, may be original heights or binary mask. +# skeleton : npt.NDArray +# Binary mask of the skeleton. +# prune_args : dict +# Dictionary of pruning arguments. ??? Needs expanding on what the valid arguments are. + +# Returns +# ------- +# npt.NDArray +# The skeleton with spurious branches removed. +# """ +# return topostatsPrune(img, skeleton, **prune_args).prune_skeleton() + + +# Might be worth renaming this to reflect what it does which is prune by length and height +class topostatsPrune: + """ + Prune spurious skeletal branches based on their length and/or height. + + Contains all the functions used in the original TopoStats pruning code written by Joe Betton. + + Parameters + ---------- + img : npt.NDArray + Original image. + skeleton : npt.NDArray + Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. + max_length : float + Maximum length of the branch to prune in nanometres (nm). + height_threshold : float + Absolute height value to remove branches below in nanometres (nm). + method_values : str + Method for obtaining the height thresholding values. Options are 'min' (minimum value of the branch), + 'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value). + method_outlier : str + Method for pruning brancvhes based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the + skeleton mean - absolute threshold) or 'iqr' (below 1.5 * inter-quartile range). + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + img: npt.NDArray, + skeleton: npt.NDArray, + pixel_to_nm_scaling: float, + max_length: float = None, + height_threshold: float = None, + method_values: str = None, + method_outlier: str = None, + ) -> None: + """ + Initialise the class. + + Parameters + ---------- + img : npt.NDArray + Original image. + skeleton : npt.NDArray + Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. + max_length : float + Maximum length of the branch to prune in nanometres (nm). + height_threshold : float + Absolute height value to remove branches below in nanometres (nm). + method_values : str + Method for obtaining the height thresholding values. Options are 'min' (minimum value of the branch), + 'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value). + method_outlier : str + Method for pruning brancvhes based on height. Options are 'abs' (below absolute value), 'mean_abs' (below + the skeleton mean - absolute threshold) or 'iqr' (below 1.5 * inter-quartile range). + """ + self.img = img + self.skeleton = skeleton.copy() + self.pixel_to_nm_scaling = pixel_to_nm_scaling + self.max_length = max_length + self.height_threshold = height_threshold + self.method_values = method_values + self.method_outlier = method_outlier + + # Diverges from the change in layout to apply skeletonisation/pruning/tracing to individual grains and then process + # all grains in an image (possibly in parallel). + def prune_skeleton(self) -> npt.NDArray: + """ + Prune skeleton by length and/or height. + + If the class was initialised with both `max_length is not None` an d `height_threshold is not None` then length + based pruning is performed prior to height based pruning. + + Returns + ------- + npt.NDArray + A pruned skeleton. + """ + pruned_skeleton_mask = np.zeros_like(self.skeleton, dtype=np.uint8) + # print(f"{pruned_skeleton_mask=}") + labeled_skel = morphology.label(self.skeleton) + for i in range(1, labeled_skel.max() + 1): + single_skeleton = np.where(labeled_skel == i, 1, 0) + if self.max_length is not None: + LOGGER.debug(f": pruning.py : Pruning by length < {self.max_length}.") + single_skeleton = self._prune_by_length(single_skeleton, max_length=self.max_length) + if self.height_threshold is not None: + LOGGER.debug(": pruning.py : Pruning by height.") + single_skeleton = heightPruning( + self.img, + single_skeleton, + height_threshold=self.height_threshold, + method_values=self.method_values, + method_outlier=self.method_outlier, + ).skeleton_pruned + # skeletonise to remove nibs + # Discovered this caused an error when writing tests... + # + # numpy.core._exceptions._UFuncOutputCastingError: Cannot cast ufunc 'add' output from dtype('int8') to + # dtype('bool') with casting... + # pruned_skeleton_mask += getSkeleton(self.img, single_skeleton, method="zhang").get_skeleton() + pruned_skeleton = getSkeleton(self.img, single_skeleton, method="zhang").get_skeleton() + pruned_skeleton_mask += pruned_skeleton.astype(dtype=np.uint8) + return pruned_skeleton_mask + + def _prune_by_length( # pylint: disable=too-many-locals # noqa: C901 + self, single_skeleton: npt.NDArray, max_length: float + ) -> npt.NDArray: + """ + Remove hanging branches from a skeleton by their length. + + This is an iterative process as these are a persistent problem in the overall tracing process. + + Parameters + ---------- + single_skeleton : npt.NDArray + Binary array of the skeleton. + max_length : float + Maximum length of the branch to prune in nanometers (nm). + + Returns + ------- + npt.NDArray + Pruned skeleton as binary array. + """ + # get segments via convolution and removing junctions + conv_skeleton = convolve_skeleton(single_skeleton) + conv_skeleton[conv_skeleton == 3] = 0 + labeled_segments = morphology.label(conv_skeleton.astype(bool)) + + for segment_idx in range(1, labeled_segments.max() + 1): + # get single segment with endpoints==2 + segment = np.where(labeled_segments == segment_idx, conv_skeleton, 0) + # get segment length + ordered_coords = order_branch(np.where(segment != 0, 1, 0), [0, 0]) + segment_length = coord_dist(ordered_coords, self.pixel_to_nm_scaling)[-1] + # check if endpoint + if 2 in segment and segment_length < max_length: + # prune + single_skeleton[labeled_segments == segment_idx] = 0 + + return rm_nibs(single_skeleton) + + @staticmethod + def _find_branch_ends(coordinates: list) -> list: + """ + Identify branch ends. + + This is achieved by iterating through the coordinates and assessing the local pixel area. Ends have only one + adjacent pixel. + + Parameters + ---------- + coordinates : list + List of x, y coordinates of a branch. + + Returns + ------- + list + List of x, y coordinates of the branch ends. + """ + branch_ends = [] + + # Most of the branch ends are just points with one neighbour + for x, y in coordinates: + if genTracingFuncs.count_and_get_neighbours(x, y, coordinates)[0] == 1: + branch_ends.append([x, y]) + return branch_ends + + +class heightPruning: # pylint: disable=too-many-instance-attributes + """ + Pruning of branches based on height. + + Parameters + ---------- + image : npt.NDArray + Original image, typically the height data. + skeleton : npt.NDArray + Skeleton to prune branches from. + max_length : float + Maximum length of the branch to prune in nanometres (nm). + height_threshold : float + Absolute height value to remove branches below in nanometers (nm). + method_values : str + Method of obtaining the height thresholding values. Options are 'min' (minimum value of the branch), + 'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value). + method_outlier : str + Method to prune branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the + skeleton mean - absolute threshold) or 'iqr' (below 1.5 * inter-quartile range). + """ # numpydoc: ignore=PR01 + + def __init__( + self, + image: npt.NDArray, + skeleton: npt.NDArray, + max_length: float = None, + height_threshold: float = None, + method_values: str = None, + method_outlier: str = None, + ) -> None: + """ + Initialise the class. + + Parameters + ---------- + image : npt.NDArray + Original image, typically the height data. + skeleton : npt.NDArray + Skeleton to prune branches from. + max_length : float + Maximum length of the branch to prune in nanometres (nm). + height_threshold : float + Absolute height value to remove branches below in nanometers (nm). + method_values : str + Method of obtaining the height thresholding values. Options are 'min' (minimum value of the branch), + 'median' (median value of the branch) or 'mid' (ordered branch middle coordinate value). + method_outlier : str + Method to prune branches based on height. Options are 'abs' (below absolute value), 'mean_abs' (below the + skeleton mean - absolute threshold) or 'iqr' (below 1.5 * inter-quartile range). + """ + self.image = image + self.skeleton = skeleton + self.skeleton_convolved = None + self.skeleton_branches = None + self.skeleton_branches_labelled = None + self.max_length = max_length + self.height_threshold = height_threshold + self.method_values = method_values + self.method_outlier = method_outlier + self.convolve_skeleton() + self.segment_skeleton() + self.label_branches() + self.skeleton_pruned = self.height_prune() + + def convolve_skeleton(self) -> None: + """Convolve skeleton.""" + self.skeleton_convolved = convolve_skeleton(self.skeleton) + + def segment_skeleton(self) -> None: + """Convolve skeleton and break into segments at nodes/junctions.""" + self.skeleton_branches = np.where(self.skeleton_convolved == 3, 0, self.skeleton) + + def label_branches(self) -> None: + """Label segmented branches.""" + self.skeleton_branches_labelled = morphology.label(self.skeleton_branches) + + def _get_branch_mins(self, segments: npt.NDArray) -> npt.NDArray: + """ + Collect the minimum height value of each individually labeled branch. + + Parameters + ---------- + segments : npt.NDArray + Integer labeled array matching the dimensions of the image. + + Returns + ------- + npt.NDArray + Array of minimum values of each branch index -1. + """ + return np.array([np.min(self.image[segments == i]) for i in range(1, segments.max() + 1)]) + + def _get_branch_medians(self, segments: npt.NDArray) -> npt.NDArray: + """ + Collect the median height value of each labeled branch. + + Parameters + ---------- + segments : npt.NDArray + Integer labeled array matching the dimensions of the image. + + Returns + ------- + npt.NDArray + Array of median values of each branch index -1. + """ + return np.array([np.median(self.image[segments == i]) for i in range(1, segments.max() + 1)]) + + def _get_branch_middles(self, segments: npt.NDArray) -> npt.NDArray: + """ + Collect the positionally ordered middle height value of each labeled branch. + + Where the branch has an even amount of points, average the two middle heights. + + Parameters + ---------- + segments : npt.NDArray + Integer labeled array matching the dimensions of the image. + + Returns + ------- + npt.NDArray + Array of middle values of each branch. + """ + branch_middles = np.zeros(segments.max()) + for i in range(1, segments.max() + 1): + segment = np.where(segments == i, 1, 0) + if segment.sum() > 2: + # sometimes start is not found ? + start = np.argwhere(convolve_skeleton(segment) == 2)[0] + ordered_coords = order_branch_from_end(segment, start) + # if even no. points, average two middles + middle_idx, middle_remainder = (len(ordered_coords) + 1) // 2 - 1, (len(ordered_coords) + 1) % 2 + mid_coord = ordered_coords[[middle_idx, middle_idx + middle_remainder]] + # height = image[mid_coord[:, 0], mid_coord[:, 1]].mean() + height = self.image[mid_coord[:, 0], mid_coord[:, 1]].mean() + else: + # if 2 points, need to average them + height = self.image[segment == 1].mean() + branch_middles[i - 1] += height + return branch_middles + + @staticmethod + def _get_abs_thresh_idx(height_values: npt.NDArray, threshold: float | int) -> npt.NDArray: + """ + Identify indices of labelled branches whose height values are less than a given threshold. + + Parameters + ---------- + height_values : npt.NDArray + Array of each branches heights. + threshold : float | int + Threshold for heights. + + Returns + ------- + npt.NDArray + Branch indices which are less than threshold. + """ + return np.asarray(np.where(height_values < threshold))[0] + 1 + + @staticmethod + def _get_mean_abs_thresh_idx( + height_values: npt.NDArray, threshold: float | int, image: npt.NDArray, skeleton: npt.NDArray + ) -> npt.NDArray: + """ + Identify indices of labelled branch whose height values are less than mean skeleton height - absolute threshold. + + For DNA a threshold of 0.85nm (the depth of the major groove) would ideally remove all segments whose lowest + point is < mean(height) - 0.85nm, i.e. 1.15nm. + + Parameters + ---------- + height_values : npt.NDArray + Array of branches heights. + threshold : float | int + Threshold to be subtracted from mean heights. + image : npt.NDArray + Original image of heights. + skeleton : npt.NDArray + Binary array of skeleton used to identify heights from original image to use. + + Returns + ------- + npt.NDArray + Branch indices which are less than mean(height) - threshold. + """ + avg = image[skeleton == 1].mean() + print(f"{avg=}") + print(f"{(avg-threshold)=}") + return np.asarray(np.where(np.asarray(height_values) < (avg - threshold)))[0] + 1 + + @staticmethod + def _get_iqr_thresh_idx(image: npt.NDArray, segments: npt.NDArray) -> npt.NDArray: + """ + Identify labelled branch indices whose heights are less than 1.5 x interquartile range of all heights. + + Parameters + ---------- + image : npt.NDArray + Original image with heights. + segments : npt.NDArray + Array of skeleton branches. + + Returns + ------- + npt.NDArray + Branch indices where heights are < 1.5 * inter-quartile range. + """ + coords = np.argwhere(segments != 0) + heights = image[coords[:, 0], coords[:, 1]] # all skel heights else distribution isn't representitive + q75, q25 = np.percentile(heights, [75, 25]) + iqr = q75 - q25 + threshold = q25 - 1.5 * iqr + print(f"{q25=}") + print(f"{q75=}") + print(f"{threshold=}") + low_coords = coords[heights < threshold] + low_segment_idxs = [] + low_segment_mins = [] + # iterate through each branch segment and see if any low_coords are in a branch + for segment_num in range(1, segments.max() + 1): + segment_coords = np.argwhere(segments == segment_num) + for low_coord in low_coords: + place = np.isin(segment_coords, low_coord).all(axis=1) + if place.any(): + low_segment_idxs.append(segment_num) + low_segment_mins.append(image[segments == segment_num].min()) + break + return np.array(low_segment_idxs)[np.argsort(low_segment_mins)] # sort in order of ascending mins + + @staticmethod + def check_skeleton_one_object(skeleton: npt.NDArray) -> bool: + """ + Ensure that the skeleton hasn't been broken up upon removing a segment. + + Parameters + ---------- + skeleton : npt.NDArray + 2D single pixel thick array. + + Returns + ------- + bool + True or False depending on whether there is 1 or !1 objects. + """ + skeleton = np.where(skeleton != 0, 1, 0) + return morphology.label(skeleton).max() == 1 + + def filter_segments(self, segments: npt.NDArray) -> npt.NDArray: + """ + Identify and remove segments of a skeleton based on the underlying image height. + + Parameters + ---------- + segments : npt.NDArray + A labelled 2D array of skeleton segments. + + Returns + ------- + npt.NDArray + The original skeleton without the segments identified by the height criteria. + """ + # Obtain the height of each branch via the min | median | mid methods + if self.method_values == "min": + height_values = self._get_branch_mins(segments) + elif self.method_values == "median": + height_values = self._get_branch_medians(segments) + elif self.method_values == "mid": + height_values = self._get_branch_middles(segments) + # threshold heights to obtain indexes of branches to be removed + if self.method_outlier == "abs": + idxs = self._get_abs_thresh_idx(height_values, self.height_threshold) + elif self.method_outlier == "mean_abs": + idxs = self._get_mean_abs_thresh_idx(height_values, self.height_threshold, self.image, self.skeleton) + elif self.method_outlier == "iqr": + idxs = self._get_iqr_thresh_idx(self.image, segments) + + # Only remove the bridge if the skeleton remains a single object. + skeleton_rtn = self.skeleton.copy() + for i in idxs: + temp_skel = self.skeleton.copy() + temp_skel[segments == i] = 0 + if self.check_skeleton_one_object(temp_skel): + skeleton_rtn[segments == i] = 0 + + return skeleton_rtn + + # def remove_bridges(self) -> npt.NDArray: + # """ + # Identify and remove skeleton bridges using the underlying image height. + + # Bridges cross the skeleton in places they shouldn't and are defined as an internal branch and thus have no + # endpoints. They occur due to poor thresholding creating holes in the mask, creating false "bridges" which + # misrepresent the skeleton of the molecule. + + # Returns + # ------- + # npt.NDArray + # A skeleton with internal branches removed by height. + # """ + # conv = convolve_skeleton(self.skeleton) + # # Split the skeleton into branches by removing junctions/nodes and label + # nodeless = np.where(conv == 3, 0, conv) + # segments = morphology.label(np.where(nodeless != 0, 1, 0)) + # # bridges should not concern endpoints so remove these + # for i in range(1, segments.max() + 1): + # if (conv[segments == i] == 2).any(): + # segments[segments == i] = 0 + # segments = morphology.label(np.where(segments != 0, 1, 0)) + + # # filter the segments based on height criteria + # return self.filter_segments(segments) + + def height_prune(self) -> npt.NDArray: + """ + Identify and remove spurious branches (containing endpoints) using the underlying image height. + + Returns + ------- + npt.NDArray + A skeleton with outer branches removed by height. + """ + conv = convolve_skeleton(self.skeleton) + segments = self._split_skeleton(conv) + # height pruning should only concern endpoints so remove internal connections + for i in range(1, segments.max() + 1): + if not (conv[segments == i] == 2).any(): + segments[segments == i] = 0 + segments = morphology.label(np.where(segments != 0, 1, 0)) + + # filter the segments based on height criteria + return self.filter_segments(segments) + + @staticmethod + def _split_skeleton(skeleton: npt.NDArray) -> npt.NDArray: + """ + Split the skeleton into branches by removing junctions/nodes and label branches. + + Parameters + ---------- + skeleton : npt.NDArray + Convolved skeleton to be split. This should have nodes labelled as 3, ends as 2 and all other points as 1. + + Returns + ------- + npt.NDArray + Removes the junctions (3) and returns all remaining sections as labelled segments. + """ + nodeless = np.where(skeleton == 3, 0, skeleton) + return morphology.label(np.where(nodeless != 0, 1, 0)) + + +def order_branch_from_end(nodeless: npt.NDArray, start: list, max_length: float = np.inf) -> npt.NDArray: + """ + Take a linear branch and orders its coordinates starting from a specific endpoint. + + NB - It may be possible to use np.lexsort() to order points, see topostats.measure.feret.sort_coords() for an + example of how to sort by row or column coordinates, which end of the branch this is from probably doesn't matter + as one only wants to find the mid-point I think. + + Parameters + ---------- + nodeless : npt.NDArray + A 2D binary array where there are no crossing pixels. + start : list + A coordinate to start closest to / at. + max_length : float, optional + The maximum length to order along the branch, in pixels, by default np.inf. + + Returns + ------- + npt.NDArray + The input linear branch ordered from the start coordinate. + """ + dist = 0 + # add starting point to ordered array + ordered = [] + ordered.append(start) + nodeless[start[0], start[1]] = 0 # remove from array + + # iterate to order the rest of the points + current_point = ordered[-1] # get last point + area, _ = local_area_sum(nodeless, current_point) # look at local area + local_next_point = np.argwhere( + area.reshape( + ( + 3, + 3, + ) + ) + == 1 + ) - (1, 1) + dist += np.sqrt(2) if abs(local_next_point).sum() > 1 else 1 + + while len(local_next_point) != 0 and dist <= max_length: + next_point = (current_point + local_next_point)[0] + # find where to go next + ordered.append(next_point) + nodeless[next_point[0], next_point[1]] = 0 # set value to zero + current_point = ordered[-1] # get last point + area, _ = local_area_sum(nodeless, current_point) # look at local area + local_next_point = np.argwhere( + area.reshape( + ( + 3, + 3, + ) + ) + == 1 + ) - (1, 1) + dist += np.sqrt(2) if abs(local_next_point).sum() > 1 else 1 + + return np.array(ordered) + + +def rm_nibs(skeleton): # pylint: disable=too-many-locals + """ + Remove single pixel branches (nibs) not identified by nearest neighbour algorithms as there may be >2 neighbours. + + Parameters + ---------- + skeleton : npt.NDArray + A single pixel thick trace. + + Returns + ------- + npt.NDArray + A skeleton with single pixel nibs removed. + """ + conv_skel = convolve_skeleton(skeleton) + nodes = np.where(conv_skel == 3, 1, 0) + labeled_nodes = morphology.label(nodes) + nodeless = np.where((conv_skel == 1) | (conv_skel == 2), 1, 0) + labeled_nodeless = morphology.label(nodeless) + size_1_idxs = [] + + for node_num in range(1, labeled_nodes.max() + 1): + node = np.where(labeled_nodes == node_num, 1, 0) + dil = morphology.binary_dilation(node, footprint=np.ones((3, 3))) + minus = np.where(dil != node, 1, 0) + + idxs = labeled_nodeless[minus == 1] + idxs = idxs[idxs != 0] + for nodeless_num in np.unique(idxs): + # if all of the branch is in surrounding node area + branch_size = (labeled_nodeless == nodeless_num).sum() + branch_idx_in_surr_area = (idxs == nodeless_num).sum() + if branch_size == branch_idx_in_surr_area: + size_1_idxs.append(nodeless_num) + + unique, counts = np.unique(np.array(size_1_idxs), return_counts=True) + + for k, count in enumerate(counts): + if count == 1: + skeleton[labeled_nodeless == unique[k]] = 0 + + return skeleton + + +def local_area_sum(img: npt.NDArray, point: list | tuple | npt.NDArray) -> tuple: + """ + Evaluate the local area around a point in a binary map. + + Parameters + ---------- + img : npt.NDArray + Binary array of image. + point : list | tuple | npt.NDArray + Coordinates of a point within the binary_map. + + Returns + ------- + tuple + Tuple consisting of an array values of the local coordinates around the point and the number of neighbours + around the point. + """ + if img[point[0], point[1]] > 1: + raise ValueError("binary_map is not binary!") + # Capture if point is on the top or left edge or array + try: + local_pixels = img[point[0] - 1 : point[0] + 2, point[1] - 1 : point[1] + 2].flatten() + except IndexError as exc: + raise IndexError("Point can not be on the edge of an array.") from exc + # Above does not capture points on right or bottom since slicing arrays beyond their indexes simply extends them + # Therefore check that we have an array of length 9 + if local_pixels.shape[0] == 9: + local_pixels[4] = 0 # ensure centre is 0 + if local_pixels.sum() <= 8: + return local_pixels, local_pixels.sum() + raise ValueError("'binary_map' is not binary!") + raise IndexError("'point' is on right or bottom edge of 'binary_map'") diff --git a/topostats/tracing/skeletonize.py b/topostats/tracing/skeletonize.py index 74e005d552..a3ccdd1083 100644 --- a/topostats/tracing/skeletonize.py +++ b/topostats/tracing/skeletonize.py @@ -3,111 +3,580 @@ import logging from collections.abc import Callable +import numpy as np import numpy.typing as npt -from skimage.morphology import skeletonize, thin +from skimage.morphology import medial_axis, skeletonize, thin from topostats.logs.logs import LOGGER_NAME LOGGER = logging.getLogger(LOGGER_NAME) -def get_skeleton(image: npt.NDArray, method: str) -> npt.NDArray: +class getSkeleton: # pylint: disable=too-few-public-methods """ - Skeletonizing masked molecules. + Class skeletonising images. Parameters ---------- image : npt.NDArray - Image of molecule to be skeletonized. + Image used to generate the mask. + mask : npt.NDArray + Binary mask of features. method : str - Method to use, default is 'zhang' other options are 'lee', and 'thin'. - - Returns - ------- - npt.NDArray - Skeletonised version of the image.all($0). - - Notes - ----- - This is a thin wrapper to the methods provided - by the `skimage.morphology - `_ - module. See also the `examples - _ + Method for skeletonizing. Options 'zhang' (default), 'lee', 'medial_axis', 'thin' and 'topostats'. + height_bias : float + Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all pixels + smiilar to Zhang. """ - skeletonizer = _get_skeletonize(method) - return skeletonizer(image) + def __init__(self, image: npt.NDArray, mask: npt.NDArray, method: str = "zhang", height_bias: float = 0.6): + """ + Initialise the class. -def _get_skeletonize(method: str = "zhang") -> Callable: - """ - Creator component which determines which skeletonize method to use. + This is a thin wrapper to the methods provided by the `skimage.morphology + `_ + module. See also the `examples + _ - Parameters - ---------- - method : str - Method to use for skeletonizing, methods are 'zhang' (default), 'lee', and 'thin'. + Parameters + ---------- + image : npt.NDArray + Image used to generate the mask. + mask : npt.NDArray + Binary mask of features. + method : str + Method for skeletonizing. Options 'zhang' (default), 'lee', 'medial_axis', 'thin' and 'topostats'. + height_bias : float + Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all + pixels smiilar to Zhang. + """ + # Q What benefit is there to having a class getSkeleton over the get_skeleton() function? Ostensibly the class + # is doing only one thing, we don't need to change state/modify anything here. Beyond encapsulating all + # functions in a single class this feels like overkill. + self.image = image + self.mask = mask + self.method = method + self.height_bias = height_bias - Returns - ------- - Callable - Returns the function appropriate for the required skeletonizing method. - """ - if method == "zhang": - return _skeletonize_zhang - if method == "lee": - return _skeletonize_lee - if method == "thin": - return _skeletonize_thin - raise ValueError(method) + def get_skeleton(self) -> npt.NDArray: + """ + Skeletonise molecules. + Returns + ------- + npt.NDArray + Skeletonised version of the binary mask (possibly using criteria from the image). + """ + return self._get_skeletonize() -def _skeletonize_zhang(image: npt.NDArray) -> npt.NDArray: - """ - Skeletonize using Zhang method. + def _get_skeletonize(self) -> Callable: + """ + Determine which skeletonise method to use. - Parameters - ---------- - image : npt.NDArray - Numpy array to be skeletonized. + Returns + ------- + Callable + Returns the function appropriate for the required skeletonizing method. + """ + if self.method == "zhang": + return self._skeletonize_zhang(mask=self.mask).astype(np.int8) + if self.method == "lee": + return self._skeletonize_lee(mask=self.mask).astype(np.int8) + if self.method == "medial_axis": + return self._skeletonize_medial_axis(mask=self.mask).astype(np.int8) + if self.method == "thin": + return self._skeletonize_thin(mask=self.mask).astype(np.int8) + if self.method == "topostats": + return self._skeletonize_topostats(image=self.image, mask=self.mask, height_bias=self.height_bias).astype( + np.int8 + ) + raise ValueError(self.method) - Returns - ------- - npt.NDArray - Skeletonized Numpy array. - """ - return skeletonize(image, method="zhang") + @staticmethod + def _skeletonize_zhang(mask: npt.NDArray) -> npt.NDArray: + """ + Use scikit-image implementation of the Zhang skeletonisation method. + Parameters + ---------- + mask : npt.NDArray + Binary array to skeletonise. -def _skeletonize_lee(image: npt.NDArray) -> npt.NDArray: - """ - Skeletonize using Lee method. + Returns + ------- + npt.NDArray + Mask array reduced to a single pixel thickness. + """ + return skeletonize(mask, method="zhang") - Parameters - ---------- - image : npt.NDArray - Numpy array to be skeletonized. + @staticmethod + def _skeletonize_lee(mask: npt.NDArray) -> npt.NDArray: + """ + Use scikit-image implementation of the Lee skeletonisation method. - Returns - ------- - npt.NDArray - Skeletonized Numpy array. - """ - return skeletonize(image, method="lee") + Parameters + ---------- + mask : npt.NDArray + Binary array to skeletonise. + + Returns + ------- + npt.NDArray + Mask array reduced to a single pixel thickness. + """ + return skeletonize(mask, method="lee") + + @staticmethod + def _skeletonize_medial_axis(mask: npt.NDArray) -> npt.NDArray: + """ + Use scikit-image implementation of the Medial axis skeletonisation method. + + Parameters + ---------- + mask : npt.NDArray + Binary array to skeletonise. + + Returns + ------- + npt.NDArray + Mask array reduced to a single pixel thickness. + """ + return medial_axis(mask, return_distance=False) + + @staticmethod + def _skeletonize_thin(mask: npt.NDArray) -> npt.NDArray: + """ + Use scikit-image implementation of the thinning skeletonisation method. + + Parameters + ---------- + mask : npt.NDArray + Binary array to skeletonise. + + Returns + ------- + npt.NDArray + Mask array reduced to a single pixel thickness. + """ + return thin(mask) + @staticmethod + def _skeletonize_topostats(image: npt.NDArray, mask: npt.NDArray, height_bias: float = 0.6) -> npt.NDArray: + """ + Use scikit-image implementation of the Zhang skeletonisation method. -def _skeletonize_thin(image: npt.NDArray) -> npt.NDArray: + This method is based on Zhang's method but produces different results (less branches but slightly less + accurate). + + Parameters + ---------- + image : npt.NDArray + Original image with heights. + mask : npt.NDArray + Binary array to skeletonise. + height_bias : float + Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all + pixels smiilar to Zhang. + + Returns + ------- + npt.NDArray + Masked array reduced to a single pixel thickness. + """ + return topostatsSkeletonize(image, mask, height_bias).do_skeletonising() + + +class topostatsSkeletonize: # pylint: disable=too-many-instance-attributes """ - Skeletonize using thinning method. + Skeletonise a binary array following Zhang's algorithm (Zhang and Suen, 1984). + + Modifications are made to the published algorithm during the removal step to remove a fraction of the smallest pixel + values opposed to all of them in the aforementioned algorithm. All operations are performed on the mask entered. Parameters ---------- image : npt.NDArray - Numpy array to be skeletonized. - - Returns - ------- - npt.NDArray - Skeletonized Numpy array. + Original 2D image containing the height data. + mask : npt.NDArray + Binary image containing the object to be skeletonised. Dimensions should match those of 'image'. + height_bias : float + Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all pixels + smiilar to Zhang. """ - return thin(image) + + def __init__(self, image: npt.NDArray, mask: npt.NDArray, height_bias: float = 0.6): + """ + Initialise the class. + + Parameters + ---------- + image : npt.NDArray + Original 2D image containing the height data. + mask : npt.NDArray + Binary image containing the object to be skeletonised. Dimensions should match those of 'image'. + height_bias : float + Ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 is all + pixels smiilar to Zhang. + """ + self.image = image + self.mask = mask.copy() + self.height_bias = height_bias + + self.skeleton_converged = False + self.p2 = None + self.p3 = None + self.p4 = None + self.p5 = None + self.p6 = None + self.p7 = None + self.p8 = None + self.p9 = None + self.counter = 0 + + def do_skeletonising(self) -> npt.NDArray: + """ + Perform skeletonisation. + + Returns + ------- + npt.NDArray + The single pixel thick, skeletonised array. + """ + while not self.skeleton_converged: + self._do_skeletonising_iteration() + # When skeleton converged do an additional iteration of thinning to remove hanging points + self.final_skeletonisation_iteration() + self.mask = getSkeleton( + image=self.image, mask=self.mask, method="zhang" + ).get_skeleton() # not sure if this is needed? + + return self.mask + + def _do_skeletonising_iteration(self) -> None: + """ + Obtain the local binary pixel environment and assess the local pixel values. + + This determines whether to delete a point according to the Zhang algorithm. + + Then removes ratio of lowest intensity (height) pixels to total pixels fitting the skeletonisation criteria. 1 + is all pixels smiilar to Zhang. + """ + skel_img = self.mask.copy() + pixels_to_delete = [] + # Sub-iteration 1 - binary check + mask_coordinates = np.argwhere(self.mask == 1).tolist() + for point in mask_coordinates: + if self._delete_pixel_subit1(point): + pixels_to_delete.append(point) + + # remove points based on height (lowest height_bias%) + pixels_to_delete = np.asarray(pixels_to_delete) # turn into array + if pixels_to_delete.shape != (0,): # ensure array not empty + skel_img[pixels_to_delete[:, 0], pixels_to_delete[:, 1]] = 2 + heights = self.image[pixels_to_delete[:, 0], pixels_to_delete[:, 1]] # get heights of pixels + height_sort_idx = self.sort_and_shuffle(heights)[1][ + : int(np.ceil(len(heights) * self.height_bias)) + ] # idx of lowest height_bias% + self.mask[pixels_to_delete[height_sort_idx, 0], pixels_to_delete[height_sort_idx, 1]] = ( + 0 # remove lowest height_bias% + ) + + pixels_to_delete = [] + # Sub-iteration 2 - binary check + mask_coordinates = np.argwhere(self.mask == 1).tolist() + for point in mask_coordinates: + if self._delete_pixel_subit2(point): + pixels_to_delete.append(point) + + # remove points based on height (lowest height_bias%) + pixels_to_delete = np.asarray(pixels_to_delete) + if pixels_to_delete.shape != (0,): + skel_img[pixels_to_delete[:, 0], pixels_to_delete[:, 1]] = 3 + heights = self.image[pixels_to_delete[:, 0], pixels_to_delete[:, 1]] + height_sort_idx = self.sort_and_shuffle(heights)[1][ + : int(np.ceil(len(heights) * self.height_bias)) + ] # idx of lowest height_bias% + self.mask[pixels_to_delete[height_sort_idx, 0], pixels_to_delete[height_sort_idx, 1]] = ( + 0 # remove lowest height_bias% + ) + + if len(pixels_to_delete) == 0: + self.skeleton_converged = True + + def _delete_pixel_subit1(self, point: list) -> bool: + """ + Check whether a single point (P1) should be deleted based on its local binary environment. + + (a) 2 ≤ B(P1) ≤ 6, where B(P1) is the number of non-zero neighbours of P1. + (b) A(P1) = 1, where A(P1) is the # of 01's around P1. + (C) P2 * P4 * P6 = 0 + (d) P4 * P6 * P8 = 0 + + Parameters + ---------- + point : list + List of [x, y] coordinate positions. + + Returns + ------- + bool + Indicates whether to delete depending on whether the surrounding points have met the criteria of the binary + thin a, b returncount, c and d checks below. + """ + self.p7, self.p8, self.p9, self.p6, self.p2, self.p5, self.p4, self.p3 = self.get_local_pixels_binary( + self.mask, point[0], point[1] + ) + return ( + self._binary_thin_check_a() + and self._binary_thin_check_b_returncount() == 1 + # c and d remove only north-west corner points and south-east boundary points. + and self._binary_thin_check_c() + and self._binary_thin_check_d() + ) + + def _delete_pixel_subit2(self, point: list) -> bool: + """ + Check whether a single point (P1) should be deleted based on its local binary environment. + + (a) 2 ≤ B(P1) ≤ 6, where B(P1) is the number of non-zero neighbours of P1. + (b) A(P1) = 1, where A(P1) is the # of 01's around P1. + (c') P2 * P4 * P8 = 0 + (d') P2 * P6 * P8 = 0 + + Parameters + ---------- + point : list + List of [x, y] coordinate positions. + + Returns + ------- + bool + Whether surrounding points have met the criteria of the binary thin a, b returncount, csharp and dsharp + checks below. + """ + self.p7, self.p8, self.p9, self.p6, self.p2, self.p5, self.p4, self.p3 = self.get_local_pixels_binary( + self.mask, point[0], point[1] + ) + # Add in generic code here to protect high points from being deleted + return ( + self._binary_thin_check_a() + and self._binary_thin_check_b_returncount() == 1 + # c' and d' remove only north-west boundary points or south-east corner points. + and self._binary_thin_check_csharp() + and self._binary_thin_check_dsharp() + ) + + def _binary_thin_check_a(self) -> bool: + """ + Check the surrounding area to see if the point lies on the edge of the grain. + + Condition A protects the endpoints (which will be < 2) + + Returns + ------- + bool + If point lies on edge of graph and isn't an endpoint. + """ + return 2 <= self.p2 + self.p3 + self.p4 + self.p5 + self.p6 + self.p7 + self.p8 + self.p9 <= 6 + + def _binary_thin_check_b_returncount(self) -> int: + """ + Count local area 01's in order around P1. + + ??? What does this mean? + + Returns + ------- + int + The number of 01's around P1. + """ + return sum( + [ + [self.p2, self.p3] == [0, 1], + [self.p3, self.p4] == [0, 1], + [self.p4, self.p5] == [0, 1], + [self.p5, self.p6] == [0, 1], + [self.p6, self.p7] == [0, 1], + [self.p7, self.p8] == [0, 1], + [self.p8, self.p9] == [0, 1], + [self.p9, self.p2] == [0, 1], + ] + ) + + def _binary_thin_check_c(self) -> bool: + """ + Check if p2, p4 or p6 is 0. + + Returns + ------- + bool + If p2, p4 or p6 is 0. + """ + return self.p2 * self.p4 * self.p6 == 0 + + def _binary_thin_check_d(self) -> bool: + """ + Check if p4, p6 or p8 is 0. + + Returns + ------- + bool + If p4, p6 or p8 is 0. + """ + return self.p4 * self.p6 * self.p8 == 0 + + def _binary_thin_check_csharp(self) -> bool: + """ + Check if p2, p4 or p8 is 0. + + Returns + ------- + bool + If p2, p4 or p8 is 0. + """ + return self.p2 * self.p4 * self.p8 == 0 + + def _binary_thin_check_dsharp(self) -> bool: + """ + Check if p2, p6 or p8 is 0. + + Returns + ------- + bool + If p2, p6 or p8 is 0. + """ + return self.p2 * self.p6 * self.p8 == 0 + + def final_skeletonisation_iteration(self) -> None: + """ + Remove "hanging" pixels. + + Examples of such pixels are: + [0, 0, 0] [0, 1, 0] [0, 0, 0] + [0, 1, 1] [0, 1, 1] [0, 1, 1] + case 1: [0, 1, 0] or case 2: [0, 1, 0] or case 3: [1, 1, 0] + + This is useful for the future functions that rely on local pixel environment + to make assessments about the overall shape/structure of traces. + """ + remaining_coordinates = np.argwhere(self.mask).tolist() + + for x, y in remaining_coordinates: + self.p7, self.p8, self.p9, self.p6, self.p2, self.p5, self.p4, self.p3 = self.get_local_pixels_binary( + self.mask, x, y + ) + + # Checks for case 1 and 3 pixels + if ( + self._binary_thin_check_b_returncount() == 2 + and self._binary_final_thin_check_a() + and not self.binary_thin_check_diag() + ): + self.mask[x, y] = 0 + # Checks for case 2 pixels + elif self._binary_thin_check_b_returncount() == 3 and self._binary_final_thin_check_b(): + self.mask[x, y] = 0 + + def _binary_final_thin_check_a(self) -> bool: + """ + Assess if local area has 4-connectivity. + + Returns + ------- + bool + Logical indicator of whether if any neighbours of the 4-connections have a near pixel. + """ + return 1 in (self.p2 * self.p4, self.p4 * self.p6, self.p6 * self.p8, self.p8 * self.p2) + + def _binary_final_thin_check_b(self) -> bool: + """ + Assess if local area 4-connectivity is connected to multiple branches. + + Returns + ------- + bool + Logical indicator of whether if any neighbours of the 4-connections have a near pixel. + """ + return 1 in ( + self.p2 * self.p4 * self.p6, + self.p4 * self.p6 * self.p8, + self.p6 * self.p8 * self.p2, + self.p8 * self.p2 * self.p4, + ) + + def binary_thin_check_diag(self) -> bool: + """ + Check if opposite corner diagonals are present. + + Returns + ------- + bool + Whether a diagonal exists. + """ + return 1 in (self.p7 * self.p3, self.p5 * self.p9) + + @staticmethod + def get_local_pixels_binary(binary_map: npt.NDArray, x: int, y: int) -> npt.NDArray: + """ + Value of pixels in the local 8-connectivity area around the coordinate (P1) described by x and y. + + P1 must not lie on the edge of the binary map. + + [[p7, p8, p9], [[0,1,2], + [p6, P1, p2], -> [3,4,5], -> [0,1,2,3,5,6,7,8] + [p5, p4, p3]] [6,7,8]] + + delete P1 to only get local area. + + Parameters + ---------- + binary_map : npt.NDArray + Binary mask of image. + x : int + X coordinate within the binary map. + y : int + Y coordinate within the binary map. + + Returns + ------- + npt.NDArray + Flattened 8-long array describing the values in the binary map around the x,y point. + """ + local_pixels = binary_map[x - 1 : x + 2, y - 1 : y + 2].flatten() + return np.delete(local_pixels, 4) + + @staticmethod + def sort_and_shuffle(arr: npt.NDArray, seed: int = 23790101) -> tuple[npt.NDArray, npt.NDArray]: + """ + Sort array in ascending order and shuffle the order of identical values are the same. + + Parameters + ---------- + arr : npt.NDArray + A flattened (1D) array. + seed : int + Seed for random number generator. + + Returns + ------- + npt.NDArray + An ascending order array where identical value orders are also shuffled. + npt.NDArray + An ascending order index array of above where identical value orders are also shuffled. + """ + # Find unique values + unique_values_r = np.unique(arr) + + rng = np.random.default_rng(seed) + + # Shuffle the order of elements with the same value + sorted_and_shuffled_indices: list = [] + for val in unique_values_r: + indices = np.where(arr == val)[0] + rng.shuffle(indices) + sorted_and_shuffled_indices.extend(indices) + + # Rearrange the sorted array according to shuffled indices + sorted_and_shuffled_arr: list = arr[sorted_and_shuffled_indices] + + return sorted_and_shuffled_arr, sorted_and_shuffled_indices diff --git a/topostats/tracing/splining.py b/topostats/tracing/splining.py new file mode 100644 index 0000000000..22c9266c94 --- /dev/null +++ b/topostats/tracing/splining.py @@ -0,0 +1,648 @@ +"""Order single pixel skeletons with or without NodeStats Statistics.""" + +from __future__ import annotations + +import logging +import math + +import numpy as np +import numpy.typing as npt +import pandas as pd +from scipy import interpolate as interp + +from topostats.logs.logs import LOGGER_NAME + +LOGGER = logging.getLogger(LOGGER_NAME) + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-locals +# pylint: disable=too-many-positional-arguments + + +# pylint: disable=too-many-instance-attributes +class splineTrace: + # pylint: disable=too-many-instance-attributes + """ + Smooth the ordered trace via an average of splines. + + Parameters + ---------- + image : npt.NDArray + Whole image containing all molecules and grains. + mol_ordered_tracing_data : dict + Molecule ordered trace dictionary containing Nx2 ordered coords and molecule statistics. + pixel_to_nm_scaling : float + The pixel to nm scaling factor, by default 1. + spline_step_size : float + Step length in meters to use a coordinate for splining. + spline_linear_smoothing : float + Amount of linear spline smoothing. + spline_circular_smoothing : float + Amount of circular spline smoothing. + spline_degree : int + Degree of the spline. Cubic splines are recommended. Even values of k should be avoided especially with a + small s-value. + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + image: npt.NDArray, + mol_ordered_tracing_data: dict, + pixel_to_nm_scaling: float, + spline_step_size: float, + spline_linear_smoothing: float, + spline_circular_smoothing: float, + spline_degree: int, + ) -> None: + # pylint: disable=too-many-arguments + """ + Initialise the splineTrace class. + + Parameters + ---------- + image : npt.NDArray + Whole image containing all molecules and grains. + mol_ordered_tracing_data : dict + Nx2 ordered trace coordinates. + pixel_to_nm_scaling : float + The pixel to nm scaling factor, by default 1. + spline_step_size : float + Step length in meters to use a coordinate for splining. + spline_linear_smoothing : float + Amount of linear spline smoothing. + spline_circular_smoothing : float + Amount of circular spline smoothing. + spline_degree : int + Degree of the spline. Cubic splines are recommended. Even values of k should be avoided especially with a + small s-value. + """ + self.image = image + self.number_of_rows, self.number_of_columns = image.shape + self.mol_ordered_trace = mol_ordered_tracing_data["ordered_coords"] + self.mol_is_circular = mol_ordered_tracing_data["mol_stats"]["circular"] + self.pixel_to_nm_scaling = pixel_to_nm_scaling + self.spline_step_size = spline_step_size + self.spline_linear_smoothing = spline_linear_smoothing + self.spline_circular_smoothing = spline_circular_smoothing + self.spline_degree = spline_degree + + self.tracing_stats = { + "contour_length": None, + "end_to_end_distance": None, + } + + def get_splined_traces( + self, + fitted_trace: npt.NDArray, + ) -> npt.NDArray: + """ + Get a splined version of the fitted trace - useful for finding the radius of gyration etc. + + This function actually calculates the average of several splines which is important for getting a good fit on + the lower resolution data. + + Parameters + ---------- + fitted_trace : npt.NDArray + Numpy array of the fitted trace. + + Returns + ------- + npt.NDArray + Splined (smoothed) array of trace. + """ + # Calculate the step size in pixels from the step size in metres. + # Should always be at least 1. + # Note that step_size_m is in m and pixel_to_nm_scaling is in m because of the legacy code which seems to almost + # always have pixel_to_nm_scaling be set in metres using the flag convert_nm_to_m. No idea why this is the case. + step_size_px = max(int(self.spline_step_size / (self.pixel_to_nm_scaling * 1e-9)), 1) + # Splines will be totalled and then divived by number of splines to calculate the average spline + spline_sum = None + + # Get the length of the fitted trace + fitted_trace_length = fitted_trace.shape[0] + + # If the fitted trace is less than the degree plus one, then there is no + # point in trying to spline it, just return the fitted trace + if fitted_trace_length < self.spline_degree + 1: + LOGGER.debug( + f"Fitted trace for grain {step_size_px} too small ({fitted_trace_length}), returning fitted trace" + ) + + return fitted_trace + + # There cannot be fewer than degree + 1 points in the spline + # Decrease the step size to ensure more than this number of points + while fitted_trace_length / step_size_px < self.spline_degree + 1: + # Step size cannot be less than 1 + if step_size_px <= 1: + step_size_px = 1 + break + step_size_px = -1 + + # Set smoothness and periodicity appropriately for linear / circular molecules. + spline_smoothness, spline_periodicity = ( + (self.spline_circular_smoothing, 2) if self.mol_is_circular else (self.spline_linear_smoothing, 0) + ) + + # Create an array of evenly spaced points between 0 and 1 for the splines to be evaluated at. + # This is needed to ensure that the splines are all the same length as the number of points + # in the spline is controlled by the ev_array variable. + ev_array = np.linspace(0, 1, fitted_trace_length * step_size_px) + + # Find as many splines as there are steps in step size, this allows for a better spline to be obtained + # by averaging the splines. Think of this like weaving a lot of splines together along the course of + # the trace. Example spline coordinate indexes: [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], where spline + # 1 takes every 4th coordinate, starting at position 0, then spline 2 takes every 4th coordinate + # starting at position 1, etc... + for i in range(step_size_px): + # Sample the fitted trace at every step_size_px pixels + sampled = [fitted_trace[j, :] for j in range(i, fitted_trace_length, step_size_px)] + + # Scipy.splprep cannot handle duplicate consecutive x, y tuples, so remove them. + # Get rid of any consecutive duplicates in the sampled coordinates + sampled = self.remove_duplicate_consecutive_tuples(tuple_list=sampled) + + x_sampled = sampled[:, 0] + y_sampled = sampled[:, 1] + + # Use scipy's B-spline functions + # tck is a tuple, (t,c,k) containing the vector of knots, the B-spline coefficients + # and the degree of the spline. + # s is the smoothing factor, per is the periodicity, k is the degree of the spline + tck = interp.splprep( + [x_sampled, y_sampled], + s=spline_smoothness, + per=spline_periodicity, + # quiet=self.spline_quiet, + k=self.spline_degree, + )[0] + # splev returns a tuple (x_coords ,y_coords) containing the smoothed coordinates of the + # spline, constructed from the B-spline coefficients and knots. The number of points in + # the spline is controlled by the ev_array variable. + # ev_array is an array of evenly spaced points between 0 and 1. + # This is to ensure that the splines are all the same length. + # Tck simply provides the coefficients for the spline. + out = interp.splev(ev_array, tck) + splined_trace = np.column_stack((out[0], out[1])) + + # Add the splined trace to the spline_sum array for averaging later + if spline_sum is None: + spline_sum = np.array(splined_trace) + else: + spline_sum = np.add(spline_sum, splined_trace) + + # Find the average spline between the set of splines + # This is an attempt to find a better spline by averaging our candidates + + return np.divide(spline_sum, [step_size_px, step_size_px]) + + @staticmethod + # Perhaps we need a module for array functions? + def remove_duplicate_consecutive_tuples(tuple_list: list[tuple | npt.NDArray]) -> list[tuple]: + """ + Remove duplicate consecutive tuples from a list. + + Parameters + ---------- + tuple_list : list[tuple | npt.NDArray] + List of tuples or numpy ndarrays to remove consecutive duplicates from. + + Returns + ------- + list[Tuple] + List of tuples with consecutive duplicates removed. + + Examples + -------- + For the list of tuples [(1, 2), (1, 2), (1, 2), (2, 3), (2, 3), (3, 4)], this function will return + [(1, 2), (2, 3), (3, 4)] + """ + duplicates_removed = [] + for index, tup in enumerate(tuple_list): + if index == 0 or not np.array_equal(tuple_list[index - 1], tup): + duplicates_removed.append(tup) + return np.array(duplicates_removed) + + def run_spline_trace(self) -> tuple[npt.NDArray, dict]: + """ + Pipeline to run the splining smoothing and obtaining smoothing stats. + + Returns + ------- + tuple[npt.NDArray, dict] + Tuple of Nx2 smoothed trace coordinates, and smoothed trace statistics. + """ + # fitted trace + # fitted_trace = self.get_fitted_traces(self.ordered_trace, mol_is_circular) + # splined trace + splined_trace = self.get_splined_traces(self.mol_ordered_trace) + # compile CL & E2E distance + self.tracing_stats["contour_length"] = ( + measure_contour_length(splined_trace, self.mol_is_circular, self.pixel_to_nm_scaling) * 1e-9 + ) + self.tracing_stats["end_to_end_distance"] = ( + measure_end_to_end_distance(splined_trace, self.mol_is_circular, self.pixel_to_nm_scaling) * 1e-9 + ) + + return splined_trace, self.tracing_stats + + +class windowTrace: + """ + Obtain a smoothed trace of a molecule. + + Parameters + ---------- + mol_ordered_tracing_data : dict + Molecule ordered trace dictionary containing Nx2 ordered coords and molecule statistics. + pixel_to_nm_scaling : float, optional + The pixel to nm scaling factor, by default 1. + rolling_window_size : np.float64, optional + The length of the rolling window too average over, by default 6.0. + """ + + def __init__( + self, + mol_ordered_tracing_data: dict, + pixel_to_nm_scaling: float, + rolling_window_size: float, + ) -> None: + """ + Initialise the windowTrace class. + + Parameters + ---------- + mol_ordered_tracing_data : dict + Molecule ordered trace dictionary containing Nx2 ordered coords and molecule statistics. + pixel_to_nm_scaling : float, optional + The pixel to nm scaling factor, by default 1. + rolling_window_size : np.float64, optional + The length of the rolling window too average over, by default 6.0. + """ + self.mol_ordered_trace = mol_ordered_tracing_data["ordered_coords"] + self.mol_is_circular = mol_ordered_tracing_data["mol_stats"]["circular"] + self.pixel_to_nm_scaling = pixel_to_nm_scaling + self.rolling_window_size = rolling_window_size / 1e-9 # for nm scaling factor + + self.tracing_stats = { + "contour_length": None, + "end_to_end_distance": None, + } + + @staticmethod + def pool_trace_circular( + pixel_trace: npt.NDArray[np.int32], rolling_window_size: np.float64 = 6.0, pixel_to_nm_scaling: float = 1 + ) -> npt.NDArray[np.float64]: + """ + Smooth a pixelwise ordered trace of circular molecules via a sliding window. + + Parameters + ---------- + pixel_trace : npt.NDArray[np.int32] + Nx2 ordered trace coordinates. + rolling_window_size : np.float64, optional + The length of the rolling window too average over, by default 6.0. + pixel_to_nm_scaling : float, optional + The pixel to nm scaling factor, by default 1. + + Returns + ------- + npt.NDArray[np.float64] + MxN Smoothed ordered trace coordinates. + """ + # Pool the trace points + pooled_trace = [] + + for i in range(len(pixel_trace)): + binned_points = [] + current_length = 0 + j = 1 + + # compile rolling window + while current_length < rolling_window_size: + current_index = i + j + previous_index = i + j - 1 + while current_index >= len(pixel_trace): + current_index -= len(pixel_trace) + while previous_index >= len(pixel_trace): + previous_index -= len(pixel_trace) + current_length += ( + np.linalg.norm(pixel_trace[current_index] - pixel_trace[previous_index]) * pixel_to_nm_scaling + ) + binned_points.append(pixel_trace[current_index]) + j += 1 + + # Get the mean of the binned points + pooled_trace.append(np.mean(binned_points, axis=0)) + + return np.array(pooled_trace) + + @staticmethod + def pool_trace_linear( + pixel_trace: npt.NDArray[np.int32], rolling_window_size: np.float64 = 6.0, pixel_to_nm_scaling: float = 1 + ) -> npt.NDArray[np.float64]: + """ + Smooth a pixelwise ordered trace of linear molecules via a sliding window. + + Parameters + ---------- + pixel_trace : npt.NDArray[np.int32] + Nx2 ordered trace coordinates. + rolling_window_size : np.float64, optional + The length of the rolling window too average over, by default 6.0. + pixel_to_nm_scaling : float, optional + The pixel to nm scaling factor, by default 1. + + Returns + ------- + npt.NDArray[np.float64] + MxN Smoothed ordered trace coordinates. + """ + pooled_trace = [pixel_trace[0]] # Add first coord as to not cut it off + + # Get average point for trace in rolling window + for i in range(0, len(pixel_trace)): + binned_points = [] + current_length = 0 + j = 0 + # Compile rolling window + while current_length < rolling_window_size: + current_index = i + j + previous_index = i + j - 1 + if current_index >= len(pixel_trace): # exit if exceeding the trace + break + current_length += ( + np.linalg.norm(pixel_trace[current_index] - pixel_trace[previous_index]) * pixel_to_nm_scaling + ) + binned_points.append(pixel_trace[current_index]) + j += 1 + else: + # Get the mean of the binned points + pooled_trace.append(np.mean(binned_points, axis=0)) + + # Exit if reached the end of the trace + if current_index + 1 >= len(pixel_trace): + break + + pooled_trace.append(pixel_trace[-1]) # Add last coord as to not cut it off + + # Check if the first two points are the same and remove the first point if they are + # This can happen if the algorithm happens to add the first point naturally due to having a small + # rolling window size. + if np.array_equal(pooled_trace[0], pooled_trace[1]): + pooled_trace.pop(0) + + # Check if the last two points are the same and remove the last point if they are + # This can happen if the algorithm happens to add the last point naturally due to having a small + # rolling window size. + if np.array_equal(pooled_trace[-1], pooled_trace[-2]): + pooled_trace.pop(-1) + + return np.array(pooled_trace) + + def run_window_trace(self) -> tuple[npt.NDArray, dict]: + """ + Pipeline to run the rolling window smoothing and obtaining smoothing stats. + + Returns + ------- + tuple[npt.NDArray, dict] + Tuple of Nx2 smoothed trace coordinates, and smoothed trace statistics. + """ + # fitted trace + # fitted_trace = self.get_fitted_traces(self.ordered_trace, mol_is_circular) + # splined trace + if self.mol_is_circular: + splined_trace = self.pool_trace_circular( + self.mol_ordered_trace, self.rolling_window_size, self.pixel_to_nm_scaling + ) + else: + splined_trace = self.pool_trace_linear( + self.mol_ordered_trace, self.rolling_window_size, self.pixel_to_nm_scaling + ) + # compile CL & E2E distance + self.tracing_stats["contour_length"] = ( + measure_contour_length(splined_trace, self.mol_is_circular, self.pixel_to_nm_scaling) * 1e-9 + ) + self.tracing_stats["end_to_end_distance"] = ( + measure_end_to_end_distance(splined_trace, self.mol_is_circular, self.pixel_to_nm_scaling) * 1e-9 + ) + + return splined_trace, self.tracing_stats + + +def measure_contour_length(splined_trace: npt.NDArray, mol_is_circular: bool, pixel_to_nm_scaling: float) -> float: + """ + Contour length for each of the splined traces accounting for whether the molecule is circular or linear. + + Contour length units are nm. + + Parameters + ---------- + splined_trace : npt.NDArray + The splined trace. + mol_is_circular : bool + Whether the molecule is circular or not. + pixel_to_nm_scaling : float + Scaling factor from pixels to nanometres. + + Returns + ------- + float + Length of molecule in nanometres (nm). + """ + if mol_is_circular: + for num in range(len(splined_trace)): + x1 = splined_trace[num - 1, 0] + y1 = splined_trace[num - 1, 1] + x2 = splined_trace[num, 0] + y2 = splined_trace[num, 1] + + try: + hypotenuse_array.append(math.hypot((x1 - x2), (y1 - y2))) + except NameError: + hypotenuse_array = [math.hypot((x1 - x2), (y1 - y2))] + + contour_length = np.sum(np.array(hypotenuse_array)) * pixel_to_nm_scaling + del hypotenuse_array + + else: + for num in range(len(splined_trace)): + try: + x1 = splined_trace[num, 0] + y1 = splined_trace[num, 1] + x2 = splined_trace[num + 1, 0] + y2 = splined_trace[num + 1, 1] + + try: + hypotenuse_array.append(math.hypot((x1 - x2), (y1 - y2))) + except NameError: + hypotenuse_array = [math.hypot((x1 - x2), (y1 - y2))] + except IndexError: # IndexError happens at last point in array + contour_length = np.sum(np.array(hypotenuse_array)) * pixel_to_nm_scaling + del hypotenuse_array + break + return contour_length + + +def measure_end_to_end_distance(splined_trace, mol_is_circular, pixel_to_nm_scaling: float): + """ + Euclidean distance between the start and end of linear molecules. + + The hypotenuse is calculated between the start ([0,0], [0,1]) and end ([-1,0], [-1,1]) of linear + molecules. If the molecule is circular then the distance is set to zero (0). + + Parameters + ---------- + splined_trace : npt.NDArray + The splined trace. + mol_is_circular : bool + Whether the molecule is circular or not. + pixel_to_nm_scaling : float + Scaling factor from pixels to nanometres. + + Returns + ------- + float + Length of molecule in nanometres (nm). + """ + if not mol_is_circular: + return ( + math.hypot((splined_trace[0, 0] - splined_trace[-1, 0]), (splined_trace[0, 1] - splined_trace[-1, 1])) + * pixel_to_nm_scaling + ) + return 0 + + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +def splining_image( + image: npt.NDArray, + ordered_tracing_direction_data: dict, + pixel_to_nm_scaling: float, + filename: str, + method: str, + rolling_window_size: float, + spline_step_size: float, + spline_linear_smoothing: float, + spline_circular_smoothing: float, + spline_degree: int, +) -> tuple[dict, pd.DataFrame, pd.DataFrame]: + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals + """ + Obtain smoothed traces of pixel-wise ordered traces for molecules in an image. + + Parameters + ---------- + image : npt.NDArray + Whole image containing all molecules and grains. + ordered_tracing_direction_data : dict + Dictionary result from the ordered traces. + pixel_to_nm_scaling : float + Scaling factor from pixels to nanometres. + filename : str + Name of the image file. + method : str + Method of trace smoothing, options are 'splining' and 'rolling_window'. + rolling_window_size : float + Length in meters to average coordinates over in the rolling window. + spline_step_size : float + Step length in meters to use a coordinate for splining. + spline_linear_smoothing : float + Amount of linear spline smoothing. + spline_circular_smoothing : float + Amount of circular spline smoothing. + spline_degree : int + Degree of the spline. Cubic splines are recommended. Even values of k should be avoided especially with a + small s-value. + + Returns + ------- + tuple[dict, pd.DataFrame, pd.DataFrame] + A spline data dictionary for all molecules, and a grainstats dataframe additions dataframe and molecule + statistics dataframe. + """ + grainstats_additions = {} + molstats = {} + all_splines_data = {} + + mol_count = 0 + for mol_trace_data in ordered_tracing_direction_data.values(): + mol_count += len(mol_trace_data) + LOGGER.info(f"[{filename}] : Calculating Splining statistics for {mol_count} molecules...") + + # iterate through disordered_tracing_dict + for grain_no, ordered_grain_data in ordered_tracing_direction_data.items(): + grain_trace_stats = {"total_contour_length": 0, "average_end_to_end_distance": 0} + all_splines_data[grain_no] = {} + mol_no = None + for mol_no, mol_trace_data in ordered_grain_data.items(): + try: + LOGGER.debug(f"[{filename}] : Splining {grain_no} - {mol_no}") + # check if want to do nodestats tracing or not + if method == "rolling_window": + splined_data, tracing_stats = windowTrace( + mol_ordered_tracing_data=mol_trace_data, + pixel_to_nm_scaling=pixel_to_nm_scaling, + rolling_window_size=rolling_window_size, + ).run_window_trace() + + # if not doing nodestats ordering, do original TS ordering + else: # method == "spline": + splined_data, tracing_stats = splineTrace( + image=image, + mol_ordered_tracing_data=mol_trace_data, + pixel_to_nm_scaling=pixel_to_nm_scaling, + spline_step_size=spline_step_size, + spline_linear_smoothing=spline_linear_smoothing, + spline_circular_smoothing=spline_circular_smoothing, + spline_degree=spline_degree, + ).run_spline_trace() + + # get combined stats for the grains + grain_trace_stats["total_contour_length"] += tracing_stats["contour_length"] + grain_trace_stats["average_end_to_end_distance"] += tracing_stats["end_to_end_distance"] + + # get individual mol stats + all_splines_data[grain_no][mol_no] = { + "spline_coords": splined_data, + "bbox": mol_trace_data["bbox"], + "tracing_stats": tracing_stats, + } + molstats[grain_no.split("_")[-1] + "_" + mol_no.split("_")[-1]] = { + "image": filename, + "grain_number": int(grain_no.split("_")[-1]), + "molecule_number": int(mol_no.split("_")[-1]), + } + molstats[grain_no.split("_")[-1] + "_" + mol_no.split("_")[-1]].update(tracing_stats) + LOGGER.debug(f"[{filename}] : Finished splining {grain_no} - {mol_no}") + + except Exception as e: # pylint: disable=broad-exception-caught + LOGGER.error( + f"[{filename}] : Splining for {grain_no} failed. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + all_splines_data[grain_no] = {} + + if mol_no is None: + LOGGER.warning(f"[{filename}] : No molecules found for grain {grain_no}") + else: + # average the e2e dists -> mol_no should always be in the grain dict + grain_trace_stats["average_end_to_end_distance"] /= len(ordered_grain_data) + + # compile metrics + grainstats_additions[grain_no] = { + "image": filename, + "grain_number": int(grain_no.split("_")[-1]), + } + grainstats_additions[grain_no].update(grain_trace_stats) + + # convert grainstats metrics to dataframe + splining_stats_df = pd.DataFrame.from_dict(grainstats_additions, orient="index") + molstats_df = pd.DataFrame.from_dict(molstats, orient="index") + molstats_df.reset_index(drop=True, inplace=True) + return all_splines_data, splining_stats_df, molstats_df diff --git a/topostats/tracing/tracingfuncs.py b/topostats/tracing/tracingfuncs.py index 512f3dcb2a..97cc4c4e7d 100644 --- a/topostats/tracing/tracingfuncs.py +++ b/topostats/tracing/tracingfuncs.py @@ -1,847 +1,19 @@ """Miscellaneous tracing functions.""" +from __future__ import annotations import numpy as np import numpy.typing as npt import matplotlib.pyplot as plt import math +from topostats.utils import convolve_skeleton -class getSkeleton: - """ - Skeltonisation : "A Fast Parallel Algorithm for Thinning Digital Patterns" by Zhang et al., 1984. - Parameters - ---------- - image_data : npt.NDArray - Image to be traced. - binary_map : npt.NDArray - Image mask. - number_of_columns : int - Number of columns. - number_of_rows : int - Number of rows. - pixel_size : float - Pixel to nm scaling. +class reorderTrace: + """ + Class to aid the consecutive ordering of adjacent coordinates of a pixel grid. """ - def __init__( - self, - image_data: npt.NDArray, - binary_map: npt.NDArray, - number_of_columns: int, - number_of_rows: int, - pixel_size: float, - ) -> None: - """ - Initialise the class. - - Parameters - ---------- - image_data : npt.NDArray - Image to be traced. - binary_map : npt.NDArray - Image mask. - number_of_columns : int - Number of columns. - number_of_rows : int - Number of rows. - pixel_size : float - Pixel to nm scaling. - """ - self.image_data = image_data - self.binary_map = binary_map - self.number_of_columns = number_of_columns - self.number_of_rows = number_of_rows - self.pixel_size = pixel_size - - self.p2 = 0 - self.p3 = 0 - self.p4 = 0 - self.p5 = 0 - self.p6 = 0 - self.p7 = 0 - self.p8 = 0 - - # skeletonising variables - self.mask_being_skeletonised = [] - self.output_skeleton = [] - self.skeleton_converged = False - self.pruning = True - - # Height checking variables - self.average_height = 0 - # self.cropping_dict = self._initialiseHeightFindingDict() - self.highest_points = {} - self.search_window = int(3 / (pixel_size * 1e9)) - # Check that the search window is bigger than 0: - if self.search_window < 2: - self.search_window = 3 - self.dir_search = int(0.75 / (pixel_size * 1e9)) - if self.dir_search < 3: - self.dir_search = 3 - - self.getDNAmolHeightStats() - self.doSkeletonising() - - def getDNAmolHeightStats(self): - """Get molecule heights.""" - coordinates = np.argwhere(self.binary_map == 1) - flat_indices = np.ravel_multi_index(coordinates.T, self.image_data.shape) - heights = self.image_data.flat[flat_indices] - self.average_height = np.average(heights) - - def doSkeletonising(self): - """Check if the skeletonising is finished.""" - - self.mask_being_skeletonised = self.binary_map - - while not self.skeleton_converged: - self._doSkeletonisingIteration() - - # When skeleton converged do an additional iteration of thinning to remove hanging points - self.finalSkeletonisationIteration() - - self.pruning = True - while self.pruning: - self.pruneSkeleton() - - self.output_skeleton = np.argwhere(self.mask_being_skeletonised == 1) - - def _doSkeletonisingIteration(self): - """ - Do an iteration of skeletonisation. - - Check for the local binary pixel environment and assess the local height values to decide whether to delete a - point. - """ - - number_of_deleted_points = 0 - pixels_to_delete = [] - - # Sub-iteration 1 - binary check - mask_coordinates = np.argwhere(self.mask_being_skeletonised == 1).tolist() - for point in mask_coordinates: - if self._deletePixelSubit1(point): - pixels_to_delete.append(point) - - # Check the local height values to determine if pixels should be deleted - # pixels_to_delete = self._checkHeights(pixels_to_delete) - - for x, y in pixels_to_delete: - number_of_deleted_points += 1 - self.mask_being_skeletonised[x, y] = 0 - pixels_to_delete = [] - - # Sub-iteration 2 - binary check - mask_coordinates = np.argwhere(self.mask_being_skeletonised == 1).tolist() - for point in mask_coordinates: - if self._deletePixelSubit2(point): - pixels_to_delete.append(point) - - # Check the local height values to determine if pixels should be deleted - # pixels_to_delete = self._checkHeights(pixels_to_delete) - - for x, y in pixels_to_delete: - number_of_deleted_points += 1 - self.mask_being_skeletonised[x, y] = 0 - - if number_of_deleted_points == 0: - self.skeleton_converged = True - - def _deletePixelSubit1(self, point: npt.NDArray) -> bool: - """ - Check whether a point should be deleted based on local binary environment and local height values. - - Parameters - ---------- - point : npt.NDArray - Point to be checked. - - Returns - ------- - bool - Whether the point should be deleted. - """ - - self.p2, self.p3, self.p4, self.p5, self.p6, self.p7, self.p8, self.p9 = genTracingFuncs.getLocalPixelsBinary( - self.mask_being_skeletonised, point[0], point[1] - ) - - if ( - self._binaryThinCheck_a() - and self._binaryThinCheck_b() - and self._binaryThinCheck_c() - and self._binaryThinCheck_d() - ): - return True - else: - return False - - def _deletePixelSubit2(self, point: npt.NDArray) -> bool: - """ - Check whether a point should be deleted based on local binary environment and local height values. - - Parameters - ---------- - point : npt.NDArray - Point to be checked. - - Returns - ------- - bool - Whether the point should be deleted. - """ - - self.p2, self.p3, self.p4, self.p5, self.p6, self.p7, self.p8, self.p9 = genTracingFuncs.getLocalPixelsBinary( - self.mask_being_skeletonised, point[0], point[1] - ) - - # Add in generic code here to protect high points from being deleted - if ( - self._binaryThinCheck_a() - and self._binaryThinCheck_b() - and self._binaryThinCheck_csharp() - and self._binaryThinCheck_dsharp() - ): - return True - else: - return False - - """These functions are ripped from the Zhang et al. paper and do the basic - skeletonisation steps - - I can use the information from the c,d,c' and d' tests to determine a good - direction to search for higher height values """ - - def _binaryThinCheck_a(self) -> bool: - """ - Binary thin check A. - - Returns - ------- - bool: - Whether the condition is met. - """ - # Condition A protects the endpoints (which will be > 2) - add in code here to prune low height points - if 2 <= self.p2 + self.p3 + self.p4 + self.p5 + self.p6 + self.p7 + self.p8 + self.p9 <= 6: - return True - else: - return False - - def _binaryThinCheck_b(self) -> bool: - """ - Binary thin check B. - - Returns - ------- - bool: - Whether the condition is met.""" - count = 0 - - if [self.p2, self.p3] == [0, 1]: - count += 1 - if [self.p3, self.p4] == [0, 1]: - count += 1 - if [self.p4, self.p5] == [0, 1]: - count += 1 - if [self.p5, self.p6] == [0, 1]: - count += 1 - if [self.p6, self.p7] == [0, 1]: - count += 1 - if [self.p7, self.p8] == [0, 1]: - count += 1 - if [self.p8, self.p9] == [0, 1]: - count += 1 - if [self.p9, self.p2] == [0, 1]: - count += 1 - - if count == 1: - return True - else: - return False - - def _binaryThinCheck_c(self) -> bool: - """ - Binary thin check C. - - Returns - ------- - bool: - Whether the condition is met. - """ - if self.p2 * self.p4 * self.p6 == 0: - return True - else: - return False - - def _binaryThinCheck_d(self) -> bool: - """ - Binary thin check D. - - Returns - ------- - bool: - Whether the condition is met. - """ - if self.p4 * self.p6 * self.p8 == 0: - return True - else: - return False - - def _binaryThinCheck_csharp(self) -> bool: - """ - Binary thin check C#. - - Returns - ------- - bool: - Whether the condition is met. - """ - if self.p2 * self.p4 * self.p8 == 0: - return True - else: - return False - - def _binaryThinCheck_dsharp(self) -> bool: - """ - Binary thin check D# - - Returns - ------- - bool: - Whether the condition is met. - """ - if self.p2 * self.p6 * self.p8 == 0: - return True - else: - return False - - def _checkHeights(self, candidate_points: npt.NDArray) -> npt.NDArray: - """Check heights. - - Parameters - ---------- - candidate_points : npt.NDArray) - > npt.NDArra - Candidate points to be checked. - - Returns - ------- - npt.NDArray - Candidate points. - """ - try: - candidate_points = candidate_points.tolist() - except AttributeError: - pass - - for x, y in candidate_points: - # if point is basically at background don't bother assessing height and just delete: - if self.image_data[x, y] < 1e-9: - continue - - # Check if the point has already been identified as a high point - try: - self.highest_points[(x, y)] - candidate_points.pop(candidate_points.index([x, y])) - # print(x,y) - continue - except KeyError: - pass - - ( - self.p2, - self.p3, - self.p4, - self.p5, - self.p6, - self.p7, - self.p8, - self.p9, - ) = genTracingFuncs.getLocalPixelsBinary(self.mask_being_skeletonised, x, y) - - print([self.p9, self.p2, self.p3], [self.p8, 1, self.p4], [self.p7, self.p6, self.p5]) - - height_points_to_check = self._checkWhichHeightPoints() - height_points = np.around(self.cropping_dict[height_points_to_check](x, y), decimals=11) - test_value = np.around(self.image_data[x, y], decimals=11) - # print(height_points_to_check, [x,y], self.image_data[x,y], height_points) - - # if the candidate points is the highest local point don't delete it - if test_value >= sorted(height_points)[-1]: - print([self.p9, self.p2, self.p3], [self.p8, 1, self.p4], [self.p7, self.p6, self.p5]) - print(height_points_to_check, [x, y], self.image_data[x, y], height_points) - self.highest_points[(x, y)] = height_points_to_check - candidate_points.pop(candidate_points.index([x, y])) - print(height_points_to_check, (x, y)) - else: - x_n, y_n = self._identifyHighestPoint(x, y, height_points_to_check, height_points) - self.highest_points[(x_n, y_n)] = height_points_to_check - pass - - return candidate_points - - def _checkWhichHeightPoints(self): - """Check which height points.""" - # Is the point on the left hand edge? - # if (self.p8 == 1 and self.p4 == 0 and self.p2 == self.p6): - if self.p7 + self.p8 + self.p9 == 3 and self.p3 + self.p4 + self.p5 == 0 and self.p2 == self.p6: - """e.g. [1, 1, 0] - [1, 1, 0] - [1, 1, 0]""" - return "horiz_left" - # elif (self.p8 == 0 and self.p4 == 1 and self.p2 == self.p6): - elif self.p7 + self.p8 + self.p9 == 0 and self.p3 + self.p4 + self.p5 == 3 and self.p2 == self.p6: - """e.g. [0, 1, 1] - [0, 1, 1] - [0, 1, 1]""" - return "horiz_right" - # elif (self.p2 == 1 and self.p6 == 0 and self.p4 == self.p8): - elif self.p9 + self.p2 + self.p3 == 3 and self.p5 + self.p6 + self.p7 == 0 and self.p4 == self.p8: - """e.g. [1, 1, 1] - [1, 1, 1] - [0, 0, 0]""" - return "vert_up" - # elif (self.p2 == 0 and self.p6 == 1 and self.p4 == self.p8): - elif ( - self.p9 + self.p2 + self.p3 == 0 and self.p5 + self.p6 + self.p7 == 3 and self.p4 == self.p8 - ): # and self.p4 == self.p8): - """e.g. [0, 0, 0] - [1, 1, 1] - [1, 1, 1]""" - return "vert_down" - elif self.p2 + self.p8 <= 1 and self.p4 + self.p5 + self.p6 >= 2: - """e.g. [0, 0, 1] [0, 0, 0] - [0, 1, 1] [0, 1, 1] - [1, 1, 1] or [0, 1, 1]""" - return "diagright_down" - elif self.p4 + self.p6 <= 1 and self.p8 + self.p9 + self.p2 >= 2: - """e.g. [1, 1, 1] [1, 1, 0] - [1, 1, 0] [1, 1, 0] - [1, 0, 0] or [0, 0, 0]""" - return "diagright_up" - elif self.p2 + self.p4 <= 1 and self.p8 + self.p7 + self.p6 >= 2: - """e.g. [1, 0, 0] [0, 0, 0] - [1, 1, 0] [1, 1, 0] - [1, 1, 1] or [1, 1, 0]""" - return "diagleft_down" - elif self.p8 + self.p6 <= 1 and self.p2 + self.p3 + self.p4 >= 2: - """e.g. [1, 1, 1] [0, 1, 1] - [0, 1, 1] [0, 1, 1] - [0, 0, 1] or [0, 0, 0]""" - return "diagleft_up" - # else: - # return 'save' - - def _initialiseHeightFindingDict(self): - height_cropping_funcs = {} - - height_cropping_funcs["horiz_left"] = self._getHorizontalLeftHeights - height_cropping_funcs["horiz_right"] = self._getHorizontalRightHeights - height_cropping_funcs["vert_up"] = self._getVerticalUpwardHeights - height_cropping_funcs["vert_down"] = self._getVerticalDonwardHeights - height_cropping_funcs["diagleft_up"] = self._getDiaganolLeftUpwardHeights - height_cropping_funcs["diagleft_down"] = self._getDiaganolLeftDownwardHeights - height_cropping_funcs["diagright_up"] = self._getHorizontalRightHeights - height_cropping_funcs["diagright_down"] = self._getHorizontalRightHeights - height_cropping_funcs["save"] = self._savePoint - - return height_cropping_funcs - - def _getHorizontalLeftHeights(self, x: int, y: int) -> float: - """ - Calculate heights left (west). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height left (west). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x - i, y]) - return heights - - def _getHorizontalRightHeights(self, x, y): - """ - Calculate heights right (east). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height right (east). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x + i, y]) - return heights - - def _getVerticalUpwardHeights(self, x, y): - """ - Calculate heights upwards (north). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height upwards (north). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x, y + i]) - return heights - - def _getVerticalDonwardHeights(self, x, y): - """ - Calculate heights downwards (south). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height downwards (south). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x, y - i]) - return heights - - def _getDiaganolLeftUpwardHeights(self, x, y): - """ - Calculate heights diagonal left upwards (north east). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height to diagonal left upwards (north east). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x + i, y + i]) - return heights - - def _getDiaganolLeftDownwardHeights(self, x, y): - """ - Calculate heights diagonal left downwards (south west). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height diagonal left downwards (south west). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x - i, y - i]) - return heights - - def _getDiaganolRightUpwardHeights(self, x: int, y: int) -> float: - """ - Calculate heights diagonal right upwards (north east). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height diagonal right upwards (north east). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x - i, y + i]) - return heights - - def _getDiaganolRightDownwardHeights(self, x, y): - """ - Calculate heights diagonal right downwards (south east). - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height heights diagonal right downwards (south east). - """ - heights = [] # [self.image_data[x,y]] - - for i in range(-self.search_window, self.search_window): - if i == 0: - continue - heights.append(self.image_data[x + i, y - i]) - return heights - - def _condemnPoint(self, x: int, y: int) -> float: - """ - Condemn a point. - - Parameters - ---------- - x: int - X coordinate. - y: int - Y coordinate. - - Returns - ------- - float - Height to be condemned. - """ - heights = [] # [self.image_data[x,y]] - - for i in range(1, self.search_window): - heights.append(10) - return heights - - def _identifyHighestPoint(self, x, y, index_direction, indexed_heights): - highest_value = 0 - - offset = len(indexed_heights) / 2 - - for num, height_value in enumerate(indexed_heights): - if height_value > highest_value: - highest_point = height_value - index_position = (num + 1) - offset - - if index_direction == "horiz_left": - return x - num, y - elif index_direction == "horiz_right": - return x + num, y - elif index_direction == "vert_up": - return x, y + num - elif index_direction == "vert_down": - return x, y - num - elif index_direction == "diagleft_up": - return x + num, y + num - elif index_direction == "diagleft_down": - return x + num, y - num - elif index_direction == "diagright_up": - return x - num, y + num - elif index_direction == "diagright_down": - return x - num, y - num - - def finalSkeletonisationIteration(self): - """A final skeletonisation iteration that removes "hanging" pixels. - - Examples of such pixels are: - - [0, 0, 0] [0, 1, 0] [0, 0, 0] - [0, 1, 1] [0, 1, 1] [0, 1, 1] - case 1: [0, 1, 0] or case 2: [0, 1, 0] or case 3: [1, 1, 0] - - This is useful for the future functions that rely on local pixel environment - to make assessments about the overall shape/structure of traces""" - - remaining_coordinates = np.argwhere(self.mask_being_skeletonised).tolist() - - for x, y in remaining_coordinates: - ( - self.p2, - self.p3, - self.p4, - self.p5, - self.p6, - self.p7, - self.p8, - self.p9, - ) = genTracingFuncs.getLocalPixelsBinary(self.mask_being_skeletonised, x, y) - - # Checks for case 1 pixels - if self._binaryThinCheck_b_returncount() == 2 and self._binaryFinalThinCheck_a(): - self.mask_being_skeletonised[x, y] = 0 - # Checks for case 2 pixels - elif self._binaryThinCheck_b_returncount() == 3 and self._binaryFinalThinCheck_b(): - self.mask_being_skeletonised[x, y] = 0 - - def _binaryFinalThinCheck_a(self): - """Binary final thin check A.""" - if self.p2 * self.p4 == 1: - return True - elif self.p4 * self.p6 == 1: - return True - elif self.p6 * self.p8 == 1: - return True - elif self.p8 * self.p2 == 1: - return True - - def _binaryFinalThinCheck_b(self): - """Binary final thin check B.""" - if self.p2 * self.p4 * self.p6 == 1: - return True - elif self.p4 * self.p6 * self.p8 == 1: - return True - elif self.p6 * self.p8 * self.p2 == 1: - return True - elif self.p8 * self.p2 * self.p4 == 1: - return True - - def _binaryThinCheck_b_returncount(self): - """Binary final thin check B return count.""" - count = 0 - - if [self.p2, self.p3] == [0, 1]: - count += 1 - if [self.p3, self.p4] == [0, 1]: - count += 1 - if [self.p4, self.p5] == [0, 1]: - count += 1 - if [self.p5, self.p6] == [0, 1]: - count += 1 - if [self.p6, self.p7] == [0, 1]: - count += 1 - if [self.p7, self.p8] == [0, 1]: - count += 1 - if [self.p8, self.p9] == [0, 1]: - count += 1 - if [self.p9, self.p2] == [0, 1]: - count += 1 - - return count - - def pruneSkeleton(self): - """Function to remove the hanging branches from the skeletons. - - These are a persistent problem in the overall tracing process.""" - - number_of_branches = 0 - coordinates = np.argwhere(self.mask_being_skeletonised == 1).tolist() - - # The branches are typically short so if a branch is longer than a quarter - # of the total points its assumed to be part of the real data - length_of_trace = len(coordinates) - max_branch_length = int(length_of_trace * 0.15) - - # _deleteSquareEnds(coordinates) - - # first check to find all the end coordinates in the trace - potential_branch_ends = self._findBranchEnds(coordinates) - - # Now check if its a branch - and if it is delete it - for x_b, y_b in potential_branch_ends: - branch_coordinates = [[x_b, y_b]] - branch_continues = True - temp_coordinates = coordinates[:] - temp_coordinates.pop(temp_coordinates.index([x_b, y_b])) - - count = 0 - - while branch_continues: - no_of_neighbours, neighbours = genTracingFuncs.countandGetNeighbours(x_b, y_b, temp_coordinates) - - # If branch continues - if no_of_neighbours == 1: - x_b, y_b = neighbours[0] - branch_coordinates.append([x_b, y_b]) - temp_coordinates.pop(temp_coordinates.index([x_b, y_b])) - - # If the branch reaches the edge of the main trace - elif no_of_neighbours > 1: - branch_coordinates.pop(branch_coordinates.index([x_b, y_b])) - branch_continues = False - is_branch = True - # Weird case that happens sometimes - elif no_of_neighbours == 0: - is_branch = True - branch_continues = False - - if len(branch_coordinates) > max_branch_length: - branch_continues = False - is_branch = False - - if is_branch: - number_of_branches += 1 - for x, y in branch_coordinates: - self.mask_being_skeletonised[x, y] = 0 - - remaining_coordinates = np.argwhere(self.mask_being_skeletonised) - - if number_of_branches == 0: - self.pruning = False - - def _findBranchEnds(self, coordinates): - potential_branch_ends = [] - - # Most of the branch ends are just points with one neighbour - for x, y in coordinates: - if genTracingFuncs.countNeighbours(x, y, coordinates) == 1: - potential_branch_ends.append([x, y]) - # Find the ends that are 3/4 neighbouring points - return potential_branch_ends - - def _deleteSquareEnds(self, coordinates): - for x, y in coordinates: - pass - - -class reorderTrace: @staticmethod def linearTrace(trace_coordinates): """My own function to order the points from a linear trace. @@ -867,7 +39,7 @@ def linearTrace(trace_coordinates): # Find one of the end points for i, (x, y) in enumerate(trace_coordinates): - if genTracingFuncs.countNeighbours(x, y, trace_coordinates) == 1: + if genTracingFuncs.count_and_get_neighbours(x, y, trace_coordinates)[0] == 1: ordered_points = [[x, y]] trace_coordinates.pop(i) break @@ -880,7 +52,7 @@ def linearTrace(trace_coordinates): x_n, y_n = ordered_points[-1] # get the last point to be added to the array and find its neighbour - no_of_neighbours, neighbour_array = genTracingFuncs.countandGetNeighbours( + no_of_neighbours, neighbour_array = genTracingFuncs.count_and_get_neighbours( x_n, y_n, remaining_unordered_coords ) @@ -896,7 +68,7 @@ def linearTrace(trace_coordinates): remaining_unordered_coords.pop(remaining_unordered_coords.index(best_next_pixel)) continue elif no_of_neighbours == 0: - # nn, neighbour_array_all_coords = genTracingFuncs.countandGetNeighbours(x_n, y_n, trace_coordinates) + # nn, neighbour_array_all_coords = genTracingFuncs.count_and_get_neighbours(x_n, y_n, trace_coordinates) # best_next_pixel = genTracingFuncs.checkVectorsCandidatePoints(x_n, y_n, ordered_points, neighbour_array_all_coords) best_next_pixel = genTracingFuncs.findBestNextPoint( x_n, y_n, ordered_points, remaining_unordered_coords @@ -908,7 +80,7 @@ def linearTrace(trace_coordinates): ordered_points.append(best_next_pixel) # If the tracing has reached the other end of the trace then its finished - if genTracingFuncs.countNeighbours(x_n, y_n, trace_coordinates) == 1: + if genTracingFuncs.count_and_get_neighbours(x_n, y_n, trace_coordinates)[0] == 1: break return np.array(ordered_points) @@ -927,7 +99,7 @@ def circularTrace(trace_coordinates): # Find a sensible point to start of the end points for i, (x, y) in enumerate(trace_coordinates): - if genTracingFuncs.countNeighbours(x, y, trace_coordinates) == 2: + if genTracingFuncs.count_and_get_neighbours(x, y, trace_coordinates)[0] == 2: ordered_points = [[x, y]] remaining_unordered_coords.pop(i) break @@ -935,7 +107,9 @@ def circularTrace(trace_coordinates): # Randomly choose one of the neighbouring points as the next point x_n = ordered_points[0][0] y_n = ordered_points[0][1] - no_of_neighbours, neighbour_array = genTracingFuncs.countandGetNeighbours(x_n, y_n, remaining_unordered_coords) + no_of_neighbours, neighbour_array = genTracingFuncs.count_and_get_neighbours( + x_n, y_n, remaining_unordered_coords + ) ordered_points.append(neighbour_array[0]) remaining_unordered_coords.pop(remaining_unordered_coords.index(neighbour_array[0])) @@ -944,7 +118,7 @@ def circularTrace(trace_coordinates): while remaining_unordered_coords: x_n, y_n = ordered_points[-1] # get the last point to be added to the array and find its neighbour - no_of_neighbours, neighbour_array = genTracingFuncs.countandGetNeighbours( + no_of_neighbours, neighbour_array = genTracingFuncs.count_and_get_neighbours( x_n, y_n, remaining_unordered_coords ) @@ -975,7 +149,7 @@ def circularTrace(trace_coordinates): elif no_of_neighbours == 0: # Check if the tracing is finished - nn, neighbour_array_all_coords = genTracingFuncs.countandGetNeighbours(x_n, y_n, trace_coordinates) + nn, neighbour_array_all_coords = genTracingFuncs.count_and_get_neighbours(x_n, y_n, trace_coordinates) if ordered_points[0] in neighbour_array_all_coords: break @@ -1070,55 +244,7 @@ def getLocalPixelsBinary(binary_map, x, y): return p2, p3, p4, p5, p6, p7, p8, p9 @staticmethod - def countNeighbours(x, y, trace_coordinates): - """Counts the number of neighbouring points for a given coordinate in - a list of points""" - - number_of_neighbours = 0 - if [x, y + 1] in trace_coordinates: - number_of_neighbours += 1 - if [x + 1, y + 1] in trace_coordinates: - number_of_neighbours += 1 - if [x + 1, y] in trace_coordinates: - number_of_neighbours += 1 - if [x + 1, y - 1] in trace_coordinates: - number_of_neighbours += 1 - if [x, y - 1] in trace_coordinates: - number_of_neighbours += 1 - if [x - 1, y - 1] in trace_coordinates: - number_of_neighbours += 1 - if [x - 1, y] in trace_coordinates: - number_of_neighbours += 1 - if [x - 1, y + 1] in trace_coordinates: - number_of_neighbours += 1 - return number_of_neighbours - - @staticmethod - def getNeighbours(x, y, trace_coordinates): - """Returns an array containing the neighbouring points for a given - coordinate in a list of points""" - - neighbour_array = [] - if [x, y + 1] in trace_coordinates: - neighbour_array.append([x, y + 1]) - if [x + 1, y + 1] in trace_coordinates: - neighbour_array.append([x + 1, y + 1]) - if [x + 1, y] in trace_coordinates: - neighbour_array.append([x + 1, y]) - if [x + 1, y - 1] in trace_coordinates: - neighbour_array.append([x + 1, y - 1]) - if [x, y - 1] in trace_coordinates: - neighbour_array.append([x, y - 1]) - if [x - 1, y - 1] in trace_coordinates: - neighbour_array.append([x - 1, y - 1]) - if [x - 1, y] in trace_coordinates: - neighbour_array.append([x - 1, y]) - if [x - 1, y + 1] in trace_coordinates: - neighbour_array.append([x - 1, y + 1]) - return neighbour_array - - @staticmethod - def countandGetNeighbours(x, y, trace_coordinates): + def count_and_get_neighbours(x, y, trace_coordinates) -> tuple[int, list]: """Returns the number of neighbouring points for a coordinate and an array containing the those points""" @@ -1257,3 +383,150 @@ def checkVectorsCandidatePoints(x, y, ordered_points, candidate_points): ordered_x_y_theta = sorted(x_y_theta, key=lambda x: x[2]) return [ordered_x_y_theta[0][0], ordered_x_y_theta[0][1]] + + +def order_branch(binary_image: npt.NDArray, anchor: list): + """ + Order a linear branch by identifying an endpoint, and looking at the local area of the point to find the next. + + Parameters + ---------- + binary_image : npt.NDArray + A binary image of a skeleton segment to order it's points. + anchor : list + A list of 2 integers representing the coordinate to order the branch from the endpoint closest to this. + + Returns + ------- + npt.NDArray + An array of ordered coordinates. + """ + skel = binary_image.copy() + + if len(np.argwhere(skel == 1)) < 3: # if < 3 coords just return them + return np.argwhere(skel == 1) + + # get branch starts + endpoints_highlight = convolve_skeleton(skel) + endpoints = np.argwhere(endpoints_highlight == 2) + if len(endpoints) != 0: # if any endpoints, start closest to anchor + dist_vals = abs(endpoints - anchor).sum(axis=1) + start = endpoints[np.argmin(dist_vals)] + else: # will be circular so pick the first coord (is this always the case?) + start = np.argwhere(skel == 1)[0] + # order the points according to what is nearby + ordered = order_branch_from_start(skel, start) + + return np.array(ordered) + + +def order_branch_from_start( + nodeless: npt.NDArray, start: npt.NDArray, max_length: float | np.inf = np.inf +) -> npt.NDArray: + """ + Order an unbranching skeleton from an end (startpoint) along a specified length. + + Parameters + ---------- + nodeless : npt.NDArray + A 2D array of a binary unbranching skeleton. + start : npt.NDArray + 2x1 coordinate that must exist in 'nodeless'. + max_length : float | np.inf, optional + Maximum length to traverse along while ordering, by default np.inf. + + Returns + ------- + npt.NDArray + Ordered coordinates. + """ + dist = 0 + # add starting point to ordered array + ordered = [] + ordered.append(start) + nodeless[start[0], start[1]] = 0 # remove from array + + # iterate to order the rest of the points + current_point = ordered[-1] # get last point + area, _ = local_area_sum(nodeless, current_point) # look at local area + local_next_point = np.argwhere( + area.reshape( + ( + 3, + 3, + ) + ) + == 1 + ) - (1, 1) + dist += np.sqrt(2) if abs(local_next_point).sum() > 1 else 1 + + while len(local_next_point) != 0 and dist <= max_length: + next_point = (current_point + local_next_point)[0] + # find where to go next + ordered.append(next_point) + nodeless[next_point[0], next_point[1]] = 0 # set value to zero + current_point = ordered[-1] # get last point + area, _ = local_area_sum(nodeless, current_point) # look at local area + local_next_point = np.argwhere( + area.reshape( + ( + 3, + 3, + ) + ) + == 1 + ) - (1, 1) + dist += np.sqrt(2) if abs(local_next_point).sum() > 1 else 1 + + return np.array(ordered) + + +def local_area_sum(binary_map: npt.NDArray, point: list | tuple | npt.NDArray) -> npt.NDArray: + """ + Evaluate the local area around a point in a binary map. + + Parameters + ---------- + binary_map : npt.NDArray + A binary array of an image. + point : Union[list, tuple, npt.NDArray] + A single object containing 2 integers relating to a point within the binary_map. + + Returns + ------- + npt.NDArray + An array values of the local coordinates around the point. + int + A value corresponding to the number of neighbours around the point in the binary_map. + """ + x, y = point + local_pixels = binary_map[x - 1 : x + 2, y - 1 : y + 2].flatten() + local_pixels[4] = 0 # ensure centre is 0 + return local_pixels, local_pixels.sum() + + +def coord_dist(coords: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray: + """ + Accumulate a real distance traversing from pixel to pixel from a list of coordinates. + + Parameters + ---------- + coords : npt.NDArray + A Nx2 integer array corresponding to the ordered coordinates of a binary trace. + pixel_to_nm_scaling : float + The pixel to nanometer scaling factor. + + Returns + ------- + npt.NDArray + An array of length N containing thcumulative sum of the distances. + """ + dist_list = [0] + dist = 0 + for i in range(len(coords) - 1): + if abs(coords[i] - coords[i + 1]).sum() == 2: + dist += 2**0.5 + else: + dist += 1 + dist_list.append(dist) + return np.asarray(dist_list) * pixel_to_nm_scaling diff --git a/topostats/utils.py b/topostats/utils.py index 23ea0e18e3..e9ad1134f0 100644 --- a/topostats/utils.py +++ b/topostats/utils.py @@ -11,6 +11,7 @@ import numpy as np import numpy.typing as npt import pandas as pd +from scipy.ndimage import convolve from topostats.logs.logs import LOGGER_NAME from topostats.thresholds import threshold @@ -21,7 +22,7 @@ ALL_STATISTICS_COLUMNS = ( "image", "basename", - "molecule_number", + "grain_number", "area", "area_cartesian_bbox", "aspect_ratio", @@ -123,13 +124,20 @@ def update_plotting_config(plotting_config: dict) -> dict: f"Main plotting options that need updating/adding to plotting dict :\n{pformat(main_config, indent=4)}" ) for image, options in plotting_config["plot_dict"].items(): + main_config_temp = main_config.copy() LOGGER.debug(f"Dictionary for image : {image}") LOGGER.debug(f"{pformat(options, indent=4)}") # First update options with values that exist in main_config - plotting_config["plot_dict"][image] = update_config(options, main_config) + # We must however be careful not to update the colourmap for diagnostic traces + if ( + not plotting_config["plot_dict"][image]["core_set"] + and "mask_cmap" in plotting_config["plot_dict"][image].keys() + ): + main_config_temp.pop("mask_cmap") + plotting_config["plot_dict"][image] = update_config(options, main_config_temp) LOGGER.debug(f"Updated values :\n{pformat(plotting_config['plot_dict'][image])}") # Then combine the remaining key/values we need from main_config that don't already exist - for key_main, value_main in main_config.items(): + for key_main, value_main in main_config_temp.items(): if key_main not in plotting_config["plot_dict"][image]: plotting_config["plot_dict"][image][key_main] = value_main LOGGER.debug(f"After adding missing configuration options :\n{pformat(plotting_config['plot_dict'][image])}") @@ -230,7 +238,7 @@ def get_thresholds( # noqa: C901 Returns ------- - Dict + dict Dictionary of thresholds, contains keys 'below' and optionally 'above'. """ thresholds = defaultdict() @@ -261,7 +269,7 @@ def get_thresholds( # noqa: C901 return thresholds -def create_empty_dataframe(columns: set = ALL_STATISTICS_COLUMNS, index: str = "molecule_number") -> pd.DataFrame: +def create_empty_dataframe(columns: set = ALL_STATISTICS_COLUMNS, index: str = "grain_number") -> pd.DataFrame: """ Create an empty data frame for returning when no results are found. @@ -339,3 +347,62 @@ def check(coord: npt.NDArray, max_val: int, padding: int) -> npt.NDArray: return coord return check(row_coord, max_row, padding), check(col_coord, max_col, padding) + + +def convolve_skeleton(skeleton: npt.NDArray) -> npt.NDArray: + """ + Convolve skeleton with a 3x3 kernel. + + This produces an array where the branches of the skeleton are denoted with '1', endpoints are denoted as '2', and + pixels at nodes as '3'. + + Parameters + ---------- + skeleton : npt.NDArray + Single pixel thick binary trace(s) within an array. + + Returns + ------- + npt.NDArray + The skeleton (=1) with endpoints (=2), and crossings (=3) highlighted. + """ + conv = convolve(skeleton.astype(np.int32), np.ones((3, 3))) + conv[skeleton == 0] = 0 # remove non-skeleton points + conv[conv == 3] = 1 # skelly = 1 + conv[conv > 3] = 3 # nodes = 3 + return conv + + +class ResolutionError(Exception): + """Raised when the image resolution is too small for accuurate tracing.""" + + pass # pylint: disable=unnecessary-pass + + +def coords_2_img(coords, image, ordered=False) -> np.ndarray: + """ + Convert coordinates to a binary image. + + Parameters + ---------- + coords : np.ndarray + An array of 2xN integer coordinates. + image : np.ndarray + An MxL array to assign the above coordinates onto. + ordered : bool, optional + If True, incremements the value of each coord to show order. + + Returns + ------- + np.ndarray + An array the same shape as 'image' with the coordinates highlighted. + """ + comb = np.zeros_like(image) + if ordered: + comb[coords[:, 0].astype(np.int32), coords[:, 1].astype(np.int32)] = np.arange(1, len(coords) + 1) + else: + coords = coords[ + (coords[:, 0] < image.shape[0]) & (coords[:, 1] < image.shape[1]) & (coords[:, 0] > 0) & (coords[:, 1] > 0) + ] + comb[np.floor(coords[:, 0]).astype(np.int32), np.floor(coords[:, 1]).astype(np.int32)] = 1 + return comb diff --git a/topostats/validation.py b/topostats/validation.py index a0bf4fcdee..1353da07fe 100644 --- a/topostats/validation.py +++ b/topostats/validation.py @@ -11,6 +11,7 @@ LOGGER = logging.getLogger(LOGGER_NAME) # pylint: disable=line-too-long +# pylint: disable=too-many-lines def validate_config(config: dict, schema: Schema, config_type: str) -> None: @@ -207,25 +208,100 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "valid values are 'True' or 'False'", ), }, - "dnatracing": { + "disordered_tracing": { "run": Or( True, False, - error="Invalid value in config for 'dnatracing.run', valid values are 'True' or 'False'", + error="Invalid value in config for 'disordered_tracing.run', valid values are 'True' or 'False'", ), "min_skeleton_size": lambda n: n > 0.0, - "skeletonisation_method": Or( - "zhang", - "lee", - "thin", - "topostats", - error="Invalid value in config for 'dnatracing.skeletonisation_method'," - "valid values are 'zhang' or 'lee', 'thin' or 'topostats'", + "pad_width": lambda n: n > 0.0, + "mask_smoothing_params": { + "gaussian_sigma": Or( + float, + int, + None, + ), + "dilation_iterations": Or( + int, + None, + ), + "holearea_min_max": [ + Or( + int, + float, + None, + error=( + "Invalid value in config for 'disordered_tracing.mask_smoothing_params.holearea_min_max', valid values " + "are int, float or null" + ), + ), + ], + }, + "skeletonisation_params": { + "method": Or( + "zhang", + "lee", + "thin", + "medial_axis", + "topostats", + error="Invalid value in config for 'disordered_tracing.skeletonisation_method'," + "valid values are 'zhang', 'lee', 'thin', 'medial_axis', 'topostats'", + ), + "height_bias": lambda n: 0 < n <= 1, + }, + "pruning_params": { + "method": Or( + "topostats", + error="Invalid value in config for 'disordered_tracing.pruning_method', valid values are 'topostats'", + ), + "max_length": lambda n: n >= 0, + "method_values": Or("min", "median", "mid"), + "method_outlier": Or("abs", "mean_abs", "iqr"), + "height_threshold": Or(int, float, None), + }, + }, + "nodestats": { + "run": Or( + True, + False, + error="Invalid value in config for 'nodestats.run', valid values are 'True' or 'False'", + ), + "node_joining_length": float, + "node_extend_dist": float, + "branch_pairing_length": float, + "pair_odd_branches": bool, + "pad_width": lambda n: n > 0.0, + }, + "ordered_tracing": { + "run": Or( + True, + False, + error="Invalid value in config for 'ordered_tracing.run', valid values are 'True' or 'False'", + ), + "ordering_method": Or( + "nodestats", + "original", + error="Invalid value in config for 'ordered_tracing.ordering_method', valid values are 'nodestats' or 'original'", + ), + "pad_width": lambda n: n > 0.0, + }, + "splining": { + "run": Or( + True, + False, + error="Invalid value in config for 'splining.run', valid values are 'True' or 'False'", ), + "method": Or( + "spline", + "rolling_window", + error="Invalid value in config for 'splining.method', valid values are 'spline' or 'rolling_window'", + ), + "rolling_window_size": lambda n: n > 0.0, "spline_step_size": lambda n: n > 0.0, "spline_linear_smoothing": lambda n: n >= 0.0, "spline_circular_smoothing": lambda n: n >= 0.0, - "pad_width": lambda n: n > 0.0, + "spline_degree": int, # "cores": lambda n: n > 0.0, }, "plotting": { @@ -665,6 +741,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "labelled_regions_01": { "filename": str, @@ -683,6 +760,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "tidied_border": { "filename": str, @@ -700,6 +778,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "removed_noise": { "filename": str, @@ -711,6 +790,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "Invalid value in config 'removed_noise.image_type', valid values " "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -729,6 +809,26 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": Or( + lambda n: n > 0, + "figure", + error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", + ), + }, + "removed_objects_too_small_to_process": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'removed_objects_too_small_to_process.image_type', valid values " + "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -763,6 +863,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -781,6 +882,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -821,17 +923,29 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, - "all_molecule_traces": { - "title": str, + "grain_image": { "image_type": Or( "binary", "non-binary", error=( - "Invalid value in config 'all_molecule_traces.image_type', valid values " - "are 'binary' or 'non-binary'" + "Invalid value in config 'grain_image.image_type', valid values " "are 'binary' or 'non-binary'" ), ), + "core_set": False, + "savefig_dpi": Or( + lambda n: n > 0, + "figure", + error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", + ), + }, + "grain_mask": { + "image_type": Or( + "binary", + "non-binary", + error=("Invalid value in config 'grain_mask.image_type', valid values " "are 'binary' or 'non-binary'"), + ), "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -839,27 +953,209 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), }, - "grain_image": { + "grain_mask_image": { "image_type": Or( "binary", "non-binary", error=( - "Invalid value in config 'grain_image.image_type', valid values " "are 'binary' or 'non-binary'" + "Invalid value in config 'grain_mask_image.image_type', valid values " + "are 'binary' or 'non-binary'" ), ), - "core_set": False, + "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), }, - "grain_mask": { + "orig_grain": { + "filename": str, + "title": str, "image_type": Or( "binary", "non-binary", - error=("Invalid value in config 'grain_mask.image_type', valid values " "are 'binary' or 'non-binary'"), + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + }, + "smoothed_grain": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + }, + "skeleton": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "pruned_skeleton": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "branch_indexes": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'branch_indexes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "branch_types": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "convolved_skeletons": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'convolved_skeleton.image_type', valid values " + "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "node_centres": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'node_centres.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "connected_nodes": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'node_branch_mask.image_type', valid values " + "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "node_area_skeleton": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'node_area_skeleton.image_type', valid values " + "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "node_branch_mask": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'node_branch_mask.image_type', valid values " + "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "node_avg_mask": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'node_avg_mask.image_type', valid values " "are 'binary' or 'non-binary'" + ), ), + "mask_cmap": str, + "core_set": bool, + }, + "node_line_trace": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + }, + "ordered_traces": { + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'all_molecule_traces.image_type', valid values " + "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -867,15 +1163,18 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), }, - "grain_mask_image": { + "trace_segments": { + "filename": str, + "title": str, "image_type": Or( "binary", "non-binary", error=( - "Invalid value in config 'grain_mask_image.image_type', valid values " + "Invalid value in config 'all_molecule_traces.image_type', valid values " "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -883,15 +1182,62 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), }, - "single_molecule_trace": { + "over_under": { + "filename": str, + "title": str, "image_type": Or( "binary", "non-binary", error=( - "Invalid value in config 'single_molecule_trace.image_type', valid values " + "Invalid value in config 'all_molecule_traces.image_type', valid values " "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": Or( + lambda n: n > 0, + "figure", + error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", + ), + }, + "all_molecules": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "fitted_trace": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'coloured_boxes.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, + "splined_trace": { + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'splined_trace.image_type', valid values " "are 'binary' or 'non-binary'" + ), + ), + "title": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0,