Skip to content

Commit

Permalink
add electronCount for single frame
Browse files Browse the repository at this point in the history
add individual frame counting

- no support for row-dark currently
- had to change ctor for vectorToPyArray, see here: pybind/pybind11#1042 (comment) . This may be a numpy 2 thing
- outputs will be a SparseArray with one frame/scan shape = (1, 1) and no metadata. This is to take advantage of methods in SparseArray
  • Loading branch information
swelborn committed Sep 12, 2024
1 parent 84d8e39 commit cb97047
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 2 deletions.
47 changes: 46 additions & 1 deletion python/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ py::array_t<T> vectorToPyArray(std::vector<T>&& v)
auto deleter = [](void* v) { delete reinterpret_cast<std::vector<T>*>(v); };
auto* ptr = new std::vector<T>(std::move(v));
auto capsule = py::capsule(ptr, deleter);
return py::array(ptr->size(), ptr->data(), capsule);
py::array_t<T> arr({ ptr->size() }, { sizeof(T) }, ptr->data(), capsule);
return arr;
}

struct ElectronCountedDataPyArray
Expand Down Expand Up @@ -214,6 +215,44 @@ ElectronCountedDataPyArray electronCount(Reader* reader,
return electronCount(reader, options.toCpp());
}

// Function to process individual frames
template <typename FrameType>
py::array_t<uint32_t> electronCount(
py::array_t<FrameType>& frame, Dimensions2D frameDimensions,
const ElectronCountOptionsClassicPy& options)
{
py::buffer_info frameBufferInfo = frame.request();

if (frameBufferInfo.ndim != 2 ||
frameBufferInfo.format != py::format_descriptor<FrameType>::format()) {
throw std::runtime_error(
"Input frame must be a 2D array of the correct type.");
}

const ElectronCountOptionsClassic cppOptions = options.toCpp();

// Convert the buffer to a std::vector
// TODO: is there a way to avoid this copy? For e.g span impl:
// https://github.com/pybind/pybind11/issues/1042#issuecomment-663154709
std::vector<FrameType> frameVec(static_cast<FrameType*>(frameBufferInfo.ptr),
static_cast<FrameType*>(frameBufferInfo.ptr) +
frameBufferInfo.size);

// Call the electronCount function with the std::vector
if (!cppOptions.darkReference) {
return vectorToPyArray(
electronCount<FrameType, false>(frameVec, frameDimensions, cppOptions));
} else {
return vectorToPyArray(
electronCount<FrameType, true>(frameVec, frameDimensions, cppOptions));
}

std::vector<uint32_t> result =
electronCount<FrameType, true>(frameVec, frameDimensions, cppOptions);

return vectorToPyArray(std::move(result));
}

// Explicitly instantiate version for py::array_t
template std::vector<STEMImage> createSTEMImages(
const std::vector<std::vector<py::array_t<uint32_t>>>& sparseData,
Expand Down Expand Up @@ -490,6 +529,12 @@ PYBIND11_MODULE(_image, m)
electronCount,
py::call_guard<py::gil_scoped_release>());

// Count individual frame
m.def("electron_count_frame",
(py::array_t<uint32_t>(*)(py::array_t<uint16_t>&, Dimensions2D,
const ElectronCountOptionsClassicPy&)) &
electronCount);

// Calculate thresholds, with gain
m.def(
"calculate_thresholds",
Expand Down
68 changes: 68 additions & 0 deletions python/stempy/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,74 @@ def electron_count(reader, darkreference=None, number_of_samples=40,

return array


def electron_count_frame(
frame: np.ndarray,
options=None,
darkreference=None,
background_threshold=4.0,
xray_threshold=2000.0,
gain=None,
):
"""Generate a list of coordinates of electron hits for a single 2D numpy array.
:param frame: the frame.
:type frame: numpy.ndarray
:param options: the options to use for electron counting. If set, all other
parameters are ignored.
:type options: stempy.image.ElectronCountOptionsClassic
:param darkreference: the dark reference to subtract, potentially generated
via stempy.image.calculate_average().
:type darkreference: stempy.image.ImageArray or numpy.ndarray
:param background_threshold: the threshold for background
:type background_threshold: float
:param xray_threshold: the threshold for x-rays
:type xray_threshold: float
:param gain: the gain mask to apply. Must match the frame dimensions.
:type gain: numpy.ndarray (2D)
:return: the coordinates of the electron hits for the frame.
:rtype: numpy.ndarray
"""

if gain is not None:
# Invert, as we will multiply in C++
# It also must be a float32
gain = np.power(gain, -1)
gain = _safe_cast(gain, np.float32, "gain")

if options is None:
if isinstance(darkreference, np.ndarray):
# Must be float32 for correct conversions
darkreference = _safe_cast(darkreference, np.float32, "dark reference")

options = _image.ElectronCountOptionsClassic()

options.dark_reference = darkreference
options.background_threshold = background_threshold
options.x_ray_threshold = xray_threshold
options.gain = gain
options.apply_row_dark_subtraction = False
options.optimized_mean = 0.0
options.apply_row_dark_use_mean = False
else:
if options.apply_row_dark_subtraction:
print("Warning: apply_row_dark_subtraction is not supported "
"for single frame electron counting. Ignoring this option.")
options.apply_row_dark_subtraction = False
options.gain = gain

electron_counts = _image.electron_count_frame(frame, frame.shape, options)
np_data = np.empty((1, 1), dtype=object)
np_data[0, 0] = np.array(electron_counts, copy=False)
kwargs = {
"data": np_data,
"scan_shape": (1, 1),
"frame_shape": frame.shape,
}
return SparseArray(**kwargs)


def radial_sum(reader, center=(-1, -1), scan_dimensions=(0, 0)):
"""Generate a radial sum from which STEM images can be generated.
Expand Down
26 changes: 26 additions & 0 deletions stempy/electron.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,24 @@ std::vector<uint32_t> electronCount(
return maximalPoints<FrameType>(frame, frameDimensions);
}

template <typename FrameType, bool dark>
std::vector<uint32_t> electronCount(std::vector<FrameType>& frame,
Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options)
{
auto* darkReference = options.darkReference;
auto backgroundThreshold = options.backgroundThreshold;
auto xRayThreshold = options.xRayThreshold;
auto* gain = options.gain;
auto applyRowDarkSubtraction = options.applyRowDarkSubtraction;
auto applyRowDarkUseMean = options.applyRowDarkUseMean;
auto optimizedMean = options.optimizedMean;

return electronCount<FrameType, dark>(
frame, frameDimensions, darkReference, backgroundThreshold, xRayThreshold,
gain, applyRowDarkSubtraction, optimizedMean, applyRowDarkUseMean);
}

template <typename Reader, typename FrameType, bool dark>
ElectronCountedData electronCount(Reader* reader,
const ElectronCountOptions& options)
Expand Down Expand Up @@ -874,4 +892,12 @@ template ElectronCountedData electronCount(
SectorStreamMultiPassThreadedReader* reader,
const ElectronCountOptions& options);

template std::vector<uint32_t> electronCount<uint16_t, true>(
std::vector<uint16_t>& frame, Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options);

template std::vector<uint32_t> electronCount<uint16_t, false>(
std::vector<uint16_t>& frame, Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options);

} // end namespace stempy
5 changes: 5 additions & 0 deletions stempy/electron.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ template <typename InputIt>
ElectronCountedData electronCount(InputIt first, InputIt last,
const ElectronCountOptionsClassic& options);

template <typename FrameType, bool dark>
std::vector<uint32_t> electronCount(std::vector<FrameType>& frame,
Dimensions2D frameDimensions,
const ElectronCountOptionsClassic& options);

template <typename Reader>
ElectronCountedData electronCount(Reader* reader,
const ElectronCountOptions& options);
Expand Down
69 changes: 68 additions & 1 deletion tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from stempy.image import com_dense, com_sparse, radial_sum_sparse
from stempy.image import com_dense, com_sparse, electron_count_frame, radial_sum_sparse
from stempy.io.sparse_array import SparseArray


Expand Down Expand Up @@ -61,3 +61,70 @@ def test_com_sparse_parameters(simulate_sparse_array):
# No counts will be in the center so all positions will be np.nan
com2 = com_sparse(sp, crop_to=(10,10), init_center=(1,1))
assert np.isnan(com2[0,0,0])


def test_electron_count_frame():
# Create a synthetic 2D numpy array (frame)
frame = np.array(
[
[2000, 0, 1000, 0, 0],
[0, 0, 0, 200, 0],
[0, 0, 1000, 0, 0],
[0, 200, 0, 200, 0],
[0, 0, 1000, 0, 0],
],
dtype=np.uint16,
)

dark = np.ones_like(frame) * 100

# Define expected electron hits (coordinates)
expected_hits = np.array(
[
[
[
[1, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
]
]
]
)

# Electron
electron_hits = electron_count_frame(
frame, xray_threshold=10000, background_threshold=1, darkreference=dark
)
assert np.array_equal(
electron_hits.to_dense(), expected_hits
), f"Expected {expected_hits}, but got {electron_hits}"

# Test with no dark reference
electron_hits = electron_count_frame(
frame, xray_threshold=10000, background_threshold=1
)

# Test where dark reference removes some points
dark = np.ones_like(frame) * 1000
electron_hits = electron_count_frame(
frame, xray_threshold=10000, background_threshold=1, darkreference=dark
)
expected_hits = np.array(
[
[
[
[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]
]
]
)

assert np.array_equal(
electron_hits.to_dense(), expected_hits
), f"Expected {expected_hits}, but got {electron_hits}"

0 comments on commit cb97047

Please sign in to comment.