# __init__.py (1757 lines, 1520 loc, 56.2 KB)
import json
import multiprocessing
import os
import pathlib
import random
import re
import tempfile
import webbrowser
from collections import defaultdict
from multiprocessing.sharedctypes import RawArray
import pandas as pd
import mss
from downloadunzip import extract
from list_all_files_recursively import get_folder_file_complete_path
from tolerant_isinstance import isinstance_tolerant
from touchtouch import touch
import cv2
from a_cv_imwrite_imread_plus import open_image_in_cv, save_cv_image
from functools import cache, reduce
from collections import deque
from intersectioncalc import find_rectangle_intersections
import math
from a_cv2_easy_resize import add_easy_resize_to_cv2
add_easy_resize_to_cv2()
import inspect
import functools
import sys
import warnings
from collections.abc import Iterable
import numexpr
import numpy as np
haystackpic = []
needlepics = []
minimum_match = 0.5
tempmodule = sys.modules[__name__]
tempmodule.allftconvolve = {}
import scipy.fft
import scipy.signal
tempmodule.allftconvolve["normal"] = scipy.signal.fftconvolve
usecupy = False
try:
import cupy as cp
from cupyx.scipy.signal import fftconvolve as fftconvolv
tempmodule.allftconvolve["cp"] = fftconvolv
except Exception as fe:
print(fe)
try:
import pyfftw
    scipy.fft.set_global_backend(pyfftw.interfaces.scipy_fft)  # set_backend() is only a context manager; set_global_backend() keeps pyfftw active for all scipy.fft calls
pyfftw.interfaces.cache.enable()
except Exception as fe:
print(fe)
class skimage_deprecation(Warning):
"""Create our own deprecation class, since Python >= 2.7
silences deprecations by default.
"""
pass
def _get_stack_rank(func):
"""Return function rank in the call stack."""
if _is_wrapped(func):
return 1 + _get_stack_rank(func.__wrapped__)
else:
return 0
def _is_wrapped(func):
return "__wrapped__" in dir(func)
def _get_stack_length(func):
"""Return function call stack length."""
return _get_stack_rank(func.__globals__.get(func.__name__, func))
class _DecoratorBaseClass:
"""Used to manage decorators' warnings stacklevel.
The `_stack_length` class variable is used to store the number of
times a function is wrapped by a decorator.
Let `stack_length` be the total number of times a decorated
function is wrapped, and `stack_rank` be the rank of the decorator
in the decorators stack. The stacklevel of a warning is then
`stacklevel = 1 + stack_length - stack_rank`.
"""
_stack_length = {}
def get_stack_length(self, func):
return self._stack_length.get(func.__name__, _get_stack_length(func))
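# Illustrative note (not part of the original module): with
# `stacklevel = 1 + stack_length - stack_rank`, warnings always point at the
# caller of the fully decorated function, however deep the decorator stack is.
# Hypothetical sketch with two stacked decorators from this module:
#
#   @deprecate_kwarg({"old": "new"}, deprecated_version="1.0")        # rank 1
#   @change_default_value("new", new_value=2, changed_version="2.0")  # rank 0
#   def f(new=1):
#       return new
#
#   # stack_length == 2, so the inner wrapper warns with stacklevel 3 and the
#   # outer wrapper with stacklevel 2; both resolve to f's caller.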
class change_default_value(_DecoratorBaseClass):
"""Decorator for changing the default value of an argument.
Parameters
----------
arg_name: str
The name of the argument to be updated.
new_value: any
The argument new value.
changed_version : str
The package version in which the change will be introduced.
warning_msg: str
Optional warning message. If None, a generic warning message
is used.
"""
def __init__(self, arg_name, *, new_value, changed_version, warning_msg=None):
self.arg_name = arg_name
self.new_value = new_value
self.warning_msg = warning_msg
self.changed_version = changed_version
def __call__(self, func):
parameters = inspect.signature(func).parameters
arg_idx = list(parameters.keys()).index(self.arg_name)
old_value = parameters[self.arg_name].default
stack_rank = _get_stack_rank(func)
if self.warning_msg is None:
self.warning_msg = (
f"The new recommended value for {self.arg_name} is "
f"{self.new_value}. Until version {self.changed_version}, "
f"the default {self.arg_name} value is {old_value}. "
f"From version {self.changed_version}, the {self.arg_name} "
f"default value will be {self.new_value}. To avoid "
f"this warning, please explicitly set {self.arg_name} value."
)
@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys():
# warn that arg_name default value changed:
warnings.warn(self.warning_msg, FutureWarning, stacklevel=stacklevel)
return func(*args, **kwargs)
return fixed_func
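# Usage sketch (illustrative only; `_example_blur` and its arguments are
# hypothetical, not part of this module):
#
#   @change_default_value("sigma", new_value=2.0, changed_version="3.0")
#   def _example_blur(image, sigma=1.0):
#       return image
#
#   _example_blur(np.zeros((3, 3)))           # FutureWarning: default will change
#   _example_blur(np.zeros((3, 3)), sigma=1)  # no warning, value given explicitly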
class remove_arg(_DecoratorBaseClass):
"""Decorator to remove an argument from function's signature.
Parameters
----------
arg_name: str
The name of the argument to be removed.
changed_version : str
The package version in which the warning will be replaced by
an error.
help_msg: str
Optional message appended to the generic warning message.
"""
def __init__(self, arg_name, *, changed_version, help_msg=None):
self.arg_name = arg_name
self.help_msg = help_msg
self.changed_version = changed_version
def __call__(self, func):
parameters = inspect.signature(func).parameters
arg_idx = list(parameters.keys()).index(self.arg_name)
warning_msg = (
f"{self.arg_name} argument is deprecated and will be removed "
f"in version {self.changed_version}. To avoid this warning, "
f"please do not use the {self.arg_name} argument. Please "
f"see {func.__name__} documentation for more details."
)
if self.help_msg is not None:
warning_msg += f" {self.help_msg}"
stack_rank = _get_stack_rank(func)
@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
if len(args) > arg_idx or self.arg_name in kwargs.keys():
# warn that arg_name is deprecated
warnings.warn(warning_msg, FutureWarning, stacklevel=stacklevel)
return func(*args, **kwargs)
return fixed_func
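# Usage sketch (illustrative only; the decorated function is hypothetical):
#
#   @remove_arg("multichannel", changed_version="3.0",
#               help_msg="Use `channel_axis` instead.")
#   def _example_filter(image, multichannel=False):
#       return image
#
#   img = np.zeros((3, 3))
#   _example_filter(img, multichannel=True)  # FutureWarning: argument deprecated
#   _example_filter(img)                     # no warning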
def docstring_add_deprecated(func, kwarg_mapping, deprecated_version):
"""Add deprecated kwarg(s) to the "Other Params" section of a docstring.
Parameters
    ----------
func : function
The function whose docstring we wish to update.
kwarg_mapping : dict
A dict containing {old_arg: new_arg} key/value pairs as used by
`deprecate_kwarg`.
deprecated_version : str
A major.minor version string specifying when old_arg was
deprecated.
Returns
-------
new_doc : str
The updated docstring. Returns the original docstring if numpydoc is
not available.
"""
if func.__doc__ is None:
return None
try:
from numpydoc.docscrape import FunctionDoc, Parameter
except ImportError:
# Return an unmodified docstring if numpydoc is not available.
return func.__doc__
Doc = FunctionDoc(func)
for old_arg, new_arg in kwarg_mapping.items():
desc = [
f"Deprecated in favor of `{new_arg}`.",
"",
f".. deprecated:: {deprecated_version}",
]
Doc["Other Parameters"].append(
Parameter(name=old_arg, type="DEPRECATED", desc=desc)
)
new_docstring = str(Doc)
# new_docstring will have a header starting with:
#
# .. function:: func.__name__
#
# and some additional blank lines. We strip these off below.
split = new_docstring.split("\n")
no_header = split[1:]
while not no_header[0].strip():
no_header.pop(0)
# Store the initial description before any of the Parameters fields.
# Usually this is a single line, but the while loop covers any case
# where it is not.
descr = no_header.pop(0)
while no_header[0].strip():
descr += "\n " + no_header.pop(0)
descr += "\n\n"
# '\n ' rather than '\n' here to restore the original indentation.
final_docstring = descr + "\n ".join(no_header)
# strip any extra spaces from ends of lines
final_docstring = "\n".join([line.rstrip() for line in final_docstring.split("\n")])
return final_docstring
class deprecate_kwarg(_DecoratorBaseClass):
"""Decorator ensuring backward compatibility when argument names are
modified in a function definition.
Parameters
----------
kwarg_mapping: dict
Mapping between the function's old argument names and the new
ones.
deprecated_version : str
The package version in which the argument was first deprecated.
warning_msg: str
Optional warning message. If None, a generic warning message
is used.
removed_version : str
The package version in which the deprecated argument will be
removed.
"""
def __init__(
self, kwarg_mapping, deprecated_version, warning_msg=None, removed_version=None
):
self.kwarg_mapping = kwarg_mapping
if warning_msg is None:
self.warning_msg = (
"`{old_arg}` is a deprecated argument name " "for `{func_name}`. "
)
if removed_version is not None:
self.warning_msg += (
f"It will be removed in " f"version {removed_version}. "
)
self.warning_msg += "Please use `{new_arg}` instead."
else:
self.warning_msg = warning_msg
self.deprecated_version = deprecated_version
def __call__(self, func):
stack_rank = _get_stack_rank(func)
@functools.wraps(func)
def fixed_func(*args, **kwargs):
stacklevel = 1 + self.get_stack_length(func) - stack_rank
for old_arg, new_arg in self.kwarg_mapping.items():
if old_arg in kwargs:
# warn that the function interface has changed:
warnings.warn(
self.warning_msg.format(
old_arg=old_arg, func_name=func.__name__, new_arg=new_arg
),
FutureWarning,
stacklevel=stacklevel,
)
# Substitute new_arg to old_arg
kwargs[new_arg] = kwargs.pop(old_arg)
# Call the function with the fixed arguments
return func(*args, **kwargs)
if func.__doc__ is not None:
newdoc = docstring_add_deprecated(
func, self.kwarg_mapping, self.deprecated_version
)
fixed_func.__doc__ = newdoc
return fixed_func
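# Usage sketch (illustrative only; names below are hypothetical):
#
#   @deprecate_kwarg({"multichannel": "channel_axis"}, deprecated_version="0.19")
#   def _example_resize(image, channel_axis=None):
#       return image
#
#   _example_resize(np.zeros((3, 3)), multichannel=True)
#   # warns once, then forwards the value to the function as channel_axis=True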
class channel_as_last_axis:
"""Decorator for automatically making channels axis last for all arrays.
This decorator reorders axes for compatibility with functions that only
support channels along the last axis. After the function call is complete
the channels axis is restored back to its original position.
Parameters
----------
channel_arg_positions : tuple of int, optional
Positional arguments at the positions specified in this tuple are
assumed to be multichannel arrays. The default is to assume only the
first argument to the function is a multichannel array.
channel_kwarg_names : tuple of str, optional
A tuple containing the names of any keyword arguments corresponding to
multichannel arrays.
multichannel_output : bool, optional
        A boolean that should be True if the output of the function is a
        multichannel array and False otherwise. This decorator does not
currently support the general case of functions with multiple outputs
where some or all are multichannel.
"""
def __init__(
self,
channel_arg_positions=(0,),
channel_kwarg_names=(),
multichannel_output=True,
):
self.arg_positions = set(channel_arg_positions)
self.kwarg_names = set(channel_kwarg_names)
self.multichannel_output = multichannel_output
def __call__(self, func):
@functools.wraps(func)
def fixed_func(*args, **kwargs):
channel_axis = kwargs.get("channel_axis", None)
if channel_axis is None:
return func(*args, **kwargs)
# TODO: convert scalars to a tuple in anticipation of eventually
# supporting a tuple of channel axes. Right now, only an
# integer or a single-element tuple is supported, though.
if np.isscalar(channel_axis):
channel_axis = (channel_axis,)
if len(channel_axis) > 1:
raise ValueError("only a single channel axis is currently supported")
if channel_axis == (-1,) or channel_axis == -1:
return func(*args, **kwargs)
if self.arg_positions:
new_args = []
for pos, arg in enumerate(args):
if pos in self.arg_positions:
new_args.append(np.moveaxis(arg, channel_axis[0], -1))
else:
new_args.append(arg)
new_args = tuple(new_args)
else:
new_args = args
for name in self.kwarg_names:
kwargs[name] = np.moveaxis(kwargs[name], channel_axis[0], -1)
# now that we have moved the channels axis to the last position,
# change the channel_axis argument to -1
kwargs["channel_axis"] = -1
# Call the function with the fixed arguments
out = func(*new_args, **kwargs)
if self.multichannel_output:
out = np.moveaxis(out, -1, channel_axis[0])
return out
return fixed_func
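# Usage sketch (illustrative only): the decorator moves the requested channel
# axis to the last position before calling the wrapped function and moves it
# back afterwards. `_example_per_channel` is hypothetical.
#
#   @channel_as_last_axis()
#   def _example_per_channel(image, *, channel_axis=None):
#       return image        # sees channels on the last axis, channel_axis=-1
#
#   rgb_first = np.zeros((3, 64, 64))
#   out = _example_per_channel(rgb_first, channel_axis=0)
#   # out.shape == (3, 64, 64); inside the function the array was (64, 64, 3)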
class deprecated:
"""Decorator to mark deprecated functions with warning.
Adapted from <http://wiki.python.org/moin/PythonDecoratorLibrary>.
Parameters
----------
alt_func : str
If given, tell user what function to use instead.
behavior : {'warn', 'raise'}
Behavior during call to deprecated function: 'warn' = warn user that
function is deprecated; 'raise' = raise error.
removed_version : str
The package version in which the deprecated function will be removed.
"""
def __init__(self, alt_func=None, behavior="warn", removed_version=None):
self.alt_func = alt_func
self.behavior = behavior
self.removed_version = removed_version
def __call__(self, func):
alt_msg = ""
if self.alt_func is not None:
alt_msg = f" Use ``{self.alt_func}`` instead."
rmv_msg = ""
if self.removed_version is not None:
rmv_msg = f" and will be removed in version {self.removed_version}"
msg = f"Function ``{func.__name__}`` is deprecated{rmv_msg}.{alt_msg}"
@functools.wraps(func)
def wrapped(*args, **kwargs):
if self.behavior == "warn":
func_code = func.__code__
warnings.simplefilter("always", skimage_deprecation)
warnings.warn_explicit(
msg,
category=skimage_deprecation,
filename=func_code.co_filename,
lineno=func_code.co_firstlineno + 1,
)
elif self.behavior == "raise":
raise skimage_deprecation(msg)
return func(*args, **kwargs)
# modify doc string to display deprecation warning
doc = "**Deprecated function**." + alt_msg
if wrapped.__doc__ is None:
wrapped.__doc__ = doc
else:
wrapped.__doc__ = doc + "\n\n " + wrapped.__doc__
return wrapped
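# Usage sketch (illustrative only; `_old_api` / `new_api` are hypothetical):
#
#   @deprecated(alt_func="new_api", behavior="warn", removed_version="3.0")
#   def _old_api(x):
#       return x
#
#   _old_api(1)  # emits a skimage_deprecation warning pointing at the
#                # deprecated function's definition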
def get_bound_method_class(m):
"""Return the class for a bound method."""
return m.im_class if sys.version < "3" else m.__self__.__class__
def safe_as_int(val, atol=1e-3):
"""
Attempt to safely cast values to integer format.
Parameters
----------
val : scalar or iterable of scalars
Number or container of numbers which are intended to be interpreted as
integers, e.g., for indexing purposes, but which may not carry integer
type.
atol : float
Absolute tolerance away from nearest integer to consider values in
``val`` functionally integers.
Returns
-------
val_int : NumPy scalar or ndarray of dtype `np.int64`
Returns the input value(s) coerced to dtype `np.int64` assuming all
were within ``atol`` of the nearest integer.
Notes
-----
This operation calculates ``val`` modulo 1, which returns the mantissa of
all values. Then all mantissas greater than 0.5 are subtracted from one.
Finally, the absolute tolerance from zero is calculated. If it is less
than ``atol`` for all value(s) in ``val``, they are rounded and returned
in an integer array. Or, if ``val`` was a scalar, a NumPy scalar type is
returned.
If any value(s) are outside the specified tolerance, an informative error
is raised.
Examples
--------
>>> safe_as_int(7.0)
7
>>> safe_as_int([9, 4, 2.9999999999])
array([9, 4, 3])
>>> safe_as_int(53.1)
Traceback (most recent call last):
...
ValueError: Integer argument required but received 53.1, check inputs.
>>> safe_as_int(53.01, atol=0.01)
53
"""
mod = np.asarray(val) % 1 # Extract mantissa
# Check for and subtract any mod values > 0.5 from 1
if mod.ndim == 0: # Scalar input, cannot be indexed
if mod > 0.5:
mod = 1 - mod
else: # Iterable input, now ndarray
mod[mod > 0.5] = 1 - mod[mod > 0.5] # Test on each side of nearest int
try:
np.testing.assert_allclose(mod, 0, atol=atol)
except AssertionError:
raise ValueError(
f"Integer argument required but received " f"{val}, check inputs."
)
return np.round(val).astype(np.int64)
def check_shape_equality(*images):
"""Check that all images have the same shape"""
image0 = images[0]
if not all(image0.shape == image.shape for image in images[1:]):
raise ValueError("Input images must have the same dimensions.")
return
def slice_at_axis(sl, axis):
"""
Construct tuple of slices to slice an array in the given dimension.
Parameters
----------
sl : slice
The slice for the given dimension.
axis : int
The axis to which `sl` is applied. All other dimensions are left
"unsliced".
Returns
-------
sl : tuple of slices
A tuple with slices matching `shape` in length.
Examples
--------
>>> slice_at_axis(slice(None, 3, -1), 1)
(slice(None, None, None), slice(None, 3, -1), Ellipsis)
"""
return (slice(None),) * axis + (sl,) + (...,)
def reshape_nd(arr, ndim, dim):
"""Reshape a 1D array to have n dimensions, all singletons but one.
Parameters
----------
arr : array, shape (N,)
Input array
ndim : int
Number of desired dimensions of reshaped array.
dim : int
Which dimension/axis will not be singleton-sized.
Returns
-------
arr_reshaped : array, shape ([1, ...], N, [1,...])
View of `arr` reshaped to the desired shape.
Examples
--------
>>> rng = np.random.default_rng()
>>> arr = rng.random(7)
>>> reshape_nd(arr, 2, 0).shape
(7, 1)
>>> reshape_nd(arr, 3, 1).shape
(1, 7, 1)
>>> reshape_nd(arr, 4, -1).shape
(1, 1, 1, 7)
"""
if arr.ndim != 1:
raise ValueError("arr must be a 1D array")
new_shape = [1] * ndim
new_shape[dim] = -1
return np.reshape(arr, new_shape)
def check_nD(array, ndim, arg_name="image"):
"""
Verify an array meets the desired ndims and array isn't empty.
Parameters
----------
array : array-like
Input array to be validated
ndim : int or iterable of ints
Allowable ndim or ndims for the array.
arg_name : str, optional
The name of the array in the original function.
"""
array = np.asanyarray(array)
msg_incorrect_dim = "The parameter `%s` must be a %s-dimensional array"
msg_empty_array = "The parameter `%s` cannot be an empty array"
if isinstance(ndim, int):
ndim = [ndim]
if array.size == 0:
raise ValueError(msg_empty_array % (arg_name))
if array.ndim not in ndim:
raise ValueError(
msg_incorrect_dim % (arg_name, "-or-".join([str(n) for n in ndim]))
)
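# Usage sketch (illustrative only):
#
#   check_nD(np.zeros((4, 5)), 2)           # passes silently
#   check_nD(np.zeros((4, 5, 3)), [2, 3])   # passes: ndim 3 is allowed
#   check_nD(np.zeros((4,)), 2)             # raises ValueError (wrong ndim)
#   check_nD(np.empty((0, 0)), 2)           # raises ValueError (empty array)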
def convert_to_float(image, preserve_range):
"""Convert input image to float image with the appropriate range.
Parameters
----------
image : ndarray
Input image.
preserve_range : bool
Determines if the range of the image should be kept or transformed
using img_as_float. Also see
https://scikit-image.org/docs/dev/user_guide/data_types.html
Notes
-----
* Input images with `float32` data type are not upcast.
Returns
-------
image : ndarray
Transformed version of the input.
"""
if image.dtype == np.float16:
return image.astype(np.float32)
if preserve_range:
# Convert image to double only if it is not single or double
# precision float
if image.dtype.char not in "df":
image = image.astype(float)
else:
image = img_as_float(image)
return image
def _validate_interpolation_order(image_dtype, order):
"""Validate and return spline interpolation's order.
Parameters
----------
image_dtype : dtype
Image dtype.
order : int, optional
The order of the spline interpolation. The order has to be in
the range 0-5. See `skimage.transform.warp` for detail.
Returns
-------
order : int
if input order is None, returns 0 if image_dtype is bool and 1
otherwise. Otherwise, image_dtype is checked and input order
is validated accordingly (order > 0 is not supported for bool
image dtype)
"""
if order is None:
return 0 if image_dtype == bool else 1
if order < 0 or order > 5:
raise ValueError("Spline interpolation order has to be in the " "range 0-5.")
if image_dtype == bool and order != 0:
raise ValueError(
"Input image dtype is bool. Interpolation is not defined "
"with bool data type. Please set order to 0 or explicitly "
"cast input image to another data type."
)
return order
def _to_np_mode(mode):
"""Convert padding modes from `ndi.correlate` to `np.pad`."""
mode_translation_dict = dict(nearest="edge", reflect="symmetric", mirror="reflect")
if mode in mode_translation_dict:
mode = mode_translation_dict[mode]
return mode
def _to_ndimage_mode(mode):
"""Convert from `numpy.pad` mode name to the corresponding ndimage mode."""
mode_translation_dict = dict(
constant="constant",
edge="nearest",
symmetric="reflect",
reflect="mirror",
wrap="wrap",
)
if mode not in mode_translation_dict:
raise ValueError(
f"Unknown mode: '{mode}', or cannot translate mode. The "
f"mode should be one of 'constant', 'edge', 'symmetric', "
f"'reflect', or 'wrap'. See the documentation of numpy.pad for "
f"more info."
)
return _fix_ndimage_mode(mode_translation_dict[mode])
def _fix_ndimage_mode(mode):
# SciPy 1.6.0 introduced grid variants of constant and wrap which
# have less surprising behavior for images. Use these when available
grid_modes = {"constant": "grid-constant", "wrap": "grid-wrap"}
return grid_modes.get(mode, mode)
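# Illustrative mapping (not exhaustive) of how padding-mode names translate
# between the ndimage and numpy.pad conventions used above:
#
#   _to_np_mode("nearest")    -> "edge"        # ndi.correlate -> np.pad
#   _to_np_mode("reflect")    -> "symmetric"
#   _to_ndimage_mode("edge")  -> "nearest"     # np.pad -> ndimage
#   _to_ndimage_mode("wrap")  -> "grid-wrap"   # via _fix_ndimage_mode
#   _to_ndimage_mode("zeros") -> ValueError    # unknown np.pad mode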
new_float_type = {
# preserved types
np.float32().dtype.char: np.float32,
np.float64().dtype.char: np.float64,
np.complex64().dtype.char: np.complex64,
np.complex128().dtype.char: np.complex128,
# altered types
np.float16().dtype.char: np.float32,
"g": np.float64, # np.float128 ; doesn't exist on windows
"G": np.complex128, # np.complex256 ; doesn't exist on windows
}
def _supported_float_type(input_dtype, allow_complex=False):
"""Return an appropriate floating-point dtype for a given dtype.
float32, float64, complex64, complex128 are preserved.
float16 is promoted to float32.
complex256 is demoted to complex128.
Other types are cast to float64.
Parameters
----------
input_dtype : np.dtype or Iterable of np.dtype
The input dtype. If a sequence of multiple dtypes is provided, each
dtype is first converted to a supported floating point type and the
final dtype is then determined by applying `np.result_type` on the
sequence of supported floating point types.
allow_complex : bool, optional
If False, raise a ValueError on complex-valued inputs.
Returns
-------
float_type : dtype
Floating-point dtype for the image.
"""
if isinstance(input_dtype, Iterable) and not isinstance(input_dtype, str):
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
input_dtype = np.dtype(input_dtype)
if not allow_complex and input_dtype.kind == "c":
raise ValueError("complex valued input is not supported")
return new_float_type.get(input_dtype.char, np.float64)
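# Illustrative expected results, following the `new_float_type` table above:
#
#   _supported_float_type(np.uint8)     -> np.float64   # non-float types upcast
#   _supported_float_type(np.float16)   -> np.float32   # half precision promoted
#   _supported_float_type(np.float32)   -> np.float32   # preserved
#   _supported_float_type([np.float32, np.float64]) -> np.float64  # np.result_type
#   _supported_float_type(np.complex64) -> ValueError unless allow_complex=True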
def identity(image, *args, **kwargs):
"""Returns the first argument unmodified."""
return image
def as_binary_ndarray(array, *, variable_name):
"""Return `array` as a numpy.ndarray of dtype bool.
Raises
------
ValueError:
An error including the given `variable_name` if `array` can not be
safely cast to a boolean array.
"""
array = np.asarray(array)
if array.dtype != bool:
if np.any((array != 1) & (array != 0)):
raise ValueError(
f"{variable_name} array is not of dtype boolean or "
f"contains values other than 0 and 1 so cannot be "
f"safely cast to boolean array."
)
return np.asarray(array, dtype=bool)
def _window_sum_2d(image, window_shape):
window_sum = np.cumsum(image, axis=0)
window_sum = window_sum[window_shape[0] : -1] - window_sum[: -window_shape[0] - 1]
window_sum = np.cumsum(window_sum, axis=1)
window_sum = (
window_sum[:, window_shape[1] : -1] - window_sum[:, : -window_shape[1] - 1]
)
return window_sum
def _window_sum_3d(image, window_shape):
window_sum = _window_sum_2d(image, window_shape)
window_sum = np.cumsum(window_sum, axis=2)
window_sum = (
window_sum[:, :, window_shape[2] : -1]
- window_sum[:, :, : -window_shape[2] - 1]
)
return window_sum
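# Sanity-check sketch (illustrative only): the cumulative-sum trick above
# returns the sum of every `window_shape` window whose top-left corner is at
# (i + 1, j + 1), so the output is `image.shape - window_shape - 1` per axis
# (match_template pads the image first, so no border data is lost). A
# brute-force comparison:
#
#   img = np.random.rand(8, 9)
#   m, n = 3, 4
#   fast = _window_sum_2d(img, (m, n))                  # shape (4, 4)
#   slow = np.array([[img[i + 1:i + 1 + m, j + 1:j + 1 + n].sum()
#                     for j in range(img.shape[1] - n - 1)]
#                    for i in range(img.shape[0] - m - 1)])
#   # np.allclose(fast, slow) -> True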
def match_template(
image,
template,
pad_input=False,
mode="constant",
constant_values=0,
):
"""Match a template to a 2-D or 3-D image using normalized correlation.
The output is an array with values between -1.0 and 1.0. The value at a
given position corresponds to the correlation coefficient between the image
and the template.
For `pad_input=True` matches correspond to the center and otherwise to the
top-left corner of the template. To find the best match you must search for
peaks in the response (output) image.
Parameters
----------
image : (M, N[, D]) array
2-D or 3-D input image.
template : (m, n[, d]) array
Template to locate. It must be `(m <= M, n <= N[, d <= D])`.
pad_input : bool
If True, pad `image` so that output is the same size as the image, and
output values correspond to the template center. Otherwise, the output
is an array with shape `(M - m + 1, N - n + 1)` for an `(M, N)` image
and an `(m, n)` template, and matches correspond to origin
(top-left corner) of the template.
mode : see `numpy.pad`, optional
Padding mode.
constant_values : see `numpy.pad`, optional
Constant values used in conjunction with ``mode='constant'``.
Returns
-------
output : array
Response image with correlation coefficients.
Notes
-----
Details on the cross-correlation are presented in [1]_. This implementation
uses FFT convolutions of the image and the template. Reference [2]_
presents similar derivations but the approximation presented in this
reference is not used in our implementation.
References
----------
.. [1] J. P. Lewis, "Fast Normalized Cross-Correlation", Industrial Light
and Magic.
.. [2] Briechle and Hanebeck, "Template Matching using Fast Normalized
Cross Correlation", Proceedings of the SPIE (2001).
:DOI:`10.1117/12.421129`
Examples
--------
>>> template = np.zeros((3, 3))
>>> template[1, 1] = 1
>>> template
array([[0., 0., 0.],
[0., 1., 0.],
[0., 0., 0.]])
>>> image = np.zeros((6, 6))
>>> image[1, 1] = 1
>>> image[4, 4] = -1
>>> image
array([[ 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., -1., 0.],
[ 0., 0., 0., 0., 0., 0.]])
>>> result = match_template(image, template)
>>> np.round(result, 3)
array([[ 1. , -0.125, 0. , 0. ],
[-0.125, -0.125, 0. , 0. ],
[ 0. , 0. , 0.125, 0.125],
[ 0. , 0. , 0.125, -1. ]])
>>> result = match_template(image, template, pad_input=True)
>>> np.round(result, 3)
array([[-0.125, -0.125, -0.125, 0. , 0. , 0. ],
[-0.125, 1. , -0.125, 0. , 0. , 0. ],
[-0.125, -0.125, -0.125, 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0.125, 0.125, 0.125],
[ 0. , 0. , 0. , 0.125, -1. , 0.125],
[ 0. , 0. , 0. , 0.125, 0.125, 0.125]])
"""
image_shape = image.shape
pad_width = tuple((width, width) for width in template.shape)
if mode == "constant":
image = np.pad(
image, pad_width=pad_width, mode=mode, constant_values=constant_values
)
else:
image = np.pad(image, pad_width=pad_width, mode=mode)
nuim = numexpr.evaluate("image ** 2", global_dict={}, local_dict={"image": image})
if image.ndim == 2:
image_window_sum = _window_sum_2d(image, template.shape)
image_window_sum2 = _window_sum_2d(nuim, template.shape)
elif image.ndim == 3:
image_window_sum = _window_sum_3d(image, template.shape)
image_window_sum2 = _window_sum_3d(nuim, template.shape)
template_mean = template.mean()
template_volume = math.prod(template.shape)
# template_ssd = np.sum((template - template_mean) ** 2)
template_ssd = numexpr.evaluate(
"sum((template - template_mean) ** 2)",
global_dict={},
local_dict={"template": template, "template_mean": template_mean},
)
if not usecupy:
if image.ndim == 2:
xcorr = tempmodule.allftconvolve["normal"](
image, template[::-1, ::-1], mode="valid"
)[1:-1, 1:-1]
elif image.ndim == 3:
xcorr = tempmodule.allftconvolve["normal"](
image, template[::-1, ::-1, ::-1], mode="valid"
)[1:-1, 1:-1, 1:-1]
else:
if image.ndim == 2:
xcorr2 = cp.array(
tempmodule.allftconvolve["cp"](
cp.array(image), cp.array(template[::-1, ::-1]), mode="valid"
)[1:-1, 1:-1]
)
elif image.ndim == 3:
xcorr2 = cp.array(
tempmodule.allftconvolve["cp"](
cp.array(image), cp.array(template[::-1, ::-1, ::-1]), mode="valid"
)[1:-1, 1:-1, 1:-1]
)
xcorr = xcorr2.get()
numerator = numexpr.evaluate(
"xcorr - image_window_sum * template_mean",
global_dict={},
local_dict={
"xcorr": xcorr,
"image_window_sum": image_window_sum,
"template_mean": template_mean,
},
)
denominator = image_window_sum2
numexpr.evaluate(
"image_window_sum * image_window_sum",
out=image_window_sum,
global_dict={},
local_dict={
"image_window_sum": image_window_sum,
},
)
numexpr.evaluate(
"image_window_sum / template_volume",
out=image_window_sum,
global_dict={},
local_dict={
"image_window_sum": image_window_sum,
"template_volume": template_volume,
},