diff --git a/environment.yml b/environment.yml index e400131..0c1b91e 100644 --- a/environment.yml +++ b/environment.yml @@ -12,4 +12,5 @@ dependencies: - vispy - shapely - trimesh - - "imageio>=2.19" \ No newline at end of file + - "imageio>=2.19" + - allantools \ No newline at end of file diff --git a/examples/eit_system_performance_simulated.py b/examples/eit_system_performance_simulated.py new file mode 100644 index 0000000..7cdfeeb --- /dev/null +++ b/examples/eit_system_performance_simulated.py @@ -0,0 +1,207 @@ +# coding: utf-8 +from __future__ import absolute_import, division, print_function + +import matplotlib.pyplot as plt +import numpy as np +import pyeit.eit.jac as jac +import pyeit.mesh as mesh +from pyeit.eit.fem import EITForward +import pyeit.eit.protocol as protocol +from pyeit.mesh.wrapper import PyEITAnomaly_Circle +from pyeit.mesh.external import place_electrodes_equal_spacing +from numpy.random import default_rng +from pyeit.quality.eit_system import ( + calc_signal_to_noise_ratio, + calc_accuracy, + calc_drift, + calc_detectability, +) +from pyeit.eit.render import render_2d_mesh +from pyeit.visual.plot import colorbar + + +def main(): + # Configuration + # ------------------------------------------------------------------------------------------------------------------ + n_el = 16 + render_resolution = (64, 64) + background_value = 1 + anomaly_value = 2 + noise_magnitude = 2e-4 + drift_rate = 2e-7 # Per frame + + n_background_measurements = 10 + n_drift_measurements = 1800 + measurement_frequency = 1 + + detectability_r = 0.5 + distinguishability_r = 0.35 + + # Initialization + # ------------------------------------------------------------------------------------------------------------------ + drift_period_hours = n_drift_measurements / (60 * 60 * measurement_frequency) + conductive_target = True if anomaly_value - background_value > 0 else False + det_center_range = np.arange(0, 0.667, 0.067) + dist_distance_range = np.arange(0.35, 1.2, 0.09) + rng = default_rng(0) + + # Problem setup + # ------------------------------------------------------------------------------------------------------------------ + sim_mesh = mesh.create(n_el, h0=0.05) + electrode_nodes = place_electrodes_equal_spacing(sim_mesh, n_electrodes=n_el) + sim_mesh.el_pos = np.array(electrode_nodes) + protocol_obj = protocol.create(n_el, dist_exc=1, step_meas=1, parser_meas="std") + fwd = EITForward(sim_mesh, protocol_obj) + + recon_mesh = mesh.create(n_el, h0=0.1) + electrode_nodes = place_electrodes_equal_spacing(recon_mesh, n_electrodes=n_el) + recon_mesh.el_pos = np.array(electrode_nodes) + eit = jac.JAC(recon_mesh, protocol_obj) + eit.setup( + p=0.5, lamb=0.03, method="kotre", perm=background_value, jac_normalized=True + ) + + # Simulate background + # ------------------------------------------------------------------------------------------------------------------ + v0 = fwd.solve_eit(perm=background_value) + d1 = np.array(range(n_drift_measurements)) * drift_rate # Create drift + n1 = noise_magnitude * rng.standard_normal( + (n_drift_measurements, len(v0)) + ) # Create noise + + v0_dn = np.tile(v0, (n_drift_measurements, 1)) + np.tile(d1, (len(v0), 1)).T + n1 + + # Calculate background performance measures + # ------------------------------------------------------------------------------------------------------------------ + snr = calc_signal_to_noise_ratio(v0_dn[:n_background_measurements], method="db") + accuracy = calc_accuracy(v0_dn[:n_background_measurements], v0, method="EIDORS") + t2, adevs = calc_drift(v0_dn, method="Allan") + + drifts_delta = calc_drift(v0_dn, sample_period=10, method="Delta") + start = np.average(v0_dn[0:10], axis=0) + drifts_percent = 100 * drifts_delta / start + + # Simulate detectability test + # ------------------------------------------------------------------------------------------------------------------ + detectabilities = [] + detectability_renders = [] + for c in det_center_range: + anomaly = PyEITAnomaly_Circle( + center=[c, 0], r=detectability_r, perm=anomaly_value + ) + sim_mesh_new = mesh.set_perm( + sim_mesh, anomaly=anomaly, background=background_value + ) + v1 = fwd.solve_eit(perm=sim_mesh_new.perm) + n = noise_magnitude * rng.standard_normal(len(v1)) + v1_n = v1 + n + ds = eit.solve(v1_n, v0_dn[0], normalize=True) + solution = np.real(ds) + image = render_2d_mesh(recon_mesh, solution, resolution=render_resolution) + detectability = calc_detectability( + image, conductive_target=conductive_target, method="db" + ) + detectabilities.append(detectability) + detectability_renders.append(image) + + # Simulate distinguishability test + # ------------------------------------------------------------------------------------------------------------------ + anomaly = PyEITAnomaly_Circle(center=[0, 0], r=detectability_r, perm=anomaly_value) + sim_mesh_new = mesh.set_perm(sim_mesh, anomaly=anomaly, background=background_value) + dist_v0 = fwd.solve_eit(perm=sim_mesh_new.perm) + dist_v0n = dist_v0 + noise_magnitude * rng.standard_normal(len(dist_v0)) + + distinguishabilities = [] + distinguishabilitiy_renders = [] + for d in dist_distance_range: + a1 = PyEITAnomaly_Circle( + center=[d / 2, 0], r=distinguishability_r, perm=anomaly_value + ) + a2 = PyEITAnomaly_Circle( + center=[-d / 2, 0], r=distinguishability_r, perm=anomaly_value + ) + sim_mesh_new = mesh.set_perm( + sim_mesh, anomaly=[a1, a2], background=background_value + ) + v1 = fwd.solve_eit(perm=sim_mesh_new.perm) + v1_n = v1 + noise_magnitude * rng.standard_normal(len(v1)) + ds = eit.solve(v1_n, dist_v0n, normalize=True) + solution = np.real(ds) + image = render_2d_mesh(recon_mesh, solution, resolution=render_resolution) + # Distinguishability is detectability but with a target as the background. + distinguishability = calc_detectability( + image, conductive_target=conductive_target, method="db" + ) + distinguishabilities.append(distinguishability) + distinguishabilitiy_renders.append(image) + + # Plot results + # ------------------------------------------------------------------------------------------------------------------ + fig, axs = plt.subplots(2, 2) + + axs[0, 0].plot(snr) + axs[0, 0].set_xlabel("Channel Number") + axs[0, 0].set_ylabel("Signal to Noise Ratio\n(dB)") + axs[0, 0].title.set_text(f"Signal to Noise Ratio for {len(snr)} channels") + + axs[0, 1].plot(accuracy) + axs[0, 1].set_xlabel("Channel Number") + axs[0, 1].set_ylabel("Accuracy") + axs[0, 1].title.set_text(f"Accuracy for {len(snr)} channels") + + axs[1, 0].set_xlabel("Averaging Window (s)") + axs[1, 0].set_ylabel("Allan Deviation") + axs[1, 0].title.set_text(f"Allan Deviation for {len(snr)} channels") + for adev in adevs: + axs[1, 0].plot(t2, adev) + + axs[1, 1].plot(drifts_percent) + axs[1, 1].title.set_text( + f"Drift percentage on all channels.\nDrift period (hours): {drift_period_hours}" + ) + axs[1, 1].set_xlabel("Channel number") + axs[1, 1].set_ylabel("Drift (% of starting value)") + + fig.tight_layout() + fig.set_size_inches((10, 6)) + + fig, axs = plt.subplots(1, 2) + axs[0].plot(det_center_range, detectabilities, ".-", label="-x axis") + axs[0].legend() + axs[0].set_xlabel("Target position (radius fraction)") + axs[0].set_ylabel("Detectability (dB)") + axs[0].title.set_text("Detectability vs radial position") + + axs[1].plot(dist_distance_range, distinguishabilities) + axs[1].set_xlabel("Separation distance (radius fraction)") + axs[1].set_ylabel("Distinguishability (dB)") + axs[1].title.set_text("Distinguishability vs separation distance") + + fig.set_size_inches((10, 4)) + fig.tight_layout() + + fig, axs = plt.subplots(1, len(det_center_range)) + for i, c in enumerate(det_center_range): + axs[i].imshow(detectability_renders[i]) + axs[i].xaxis.set_ticks([]) + axs[i].yaxis.set_ticks([]) + + fig.set_size_inches((14, 2)) + fig.suptitle("Detectability Renders") + fig.tight_layout() + + fig, axs = plt.subplots(1, len(dist_distance_range)) + for i, d in enumerate(dist_distance_range): + img = axs[i].imshow(distinguishabilitiy_renders[i]) + axs[i].xaxis.set_ticks([]) + axs[i].yaxis.set_ticks([]) + colorbar(img) + + fig.set_size_inches((18, 2)) + fig.suptitle("Distinguishability Renders") + fig.tight_layout() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/examples/figures_of_merit_range.py b/examples/figures_of_merit_range.py index f9b0d79..f2e22cc 100644 --- a/examples/figures_of_merit_range.py +++ b/examples/figures_of_merit_range.py @@ -88,15 +88,14 @@ def main(): axs[i], solution, recon_mesh, - ax_kwargs={"title": f"Target pos: {plot_c[i]:.2f}/r"}, + ax_kwargs={"title": f"Target Pos: {plot_c[i]:.2f}/r"}, ) - fig.set_size_inches(15, 2) + fig.set_size_inches(12, 2) fig.tight_layout() figs_list = np.array(figs_list) - fig, axs = plt.subplots(5, 1, sharex=True) - axs[4].set_xlabel("Target pos/r") + fig, axs = plt.subplots(1, 5) titles = [ "Average Amplitude", "Position Error", @@ -106,9 +105,12 @@ def main(): ] for i in range(5): axs[i].plot(c_range, figs_list[:, i]) - axs[i].set_title(titles[i], size="small") + axs[i].set_title(f"{titles[i]}\nvs Target Pos") + axs[i].set_xlabel("Target Pos/r") + axs[i].set_ylabel(titles[i]) - plt.tight_layout() + fig.set_size_inches(15, 3) + fig.tight_layout() plt.show() diff --git a/examples/figures_of_merit_single.py b/examples/figures_of_merit_single.py index a005086..46e39ce 100644 --- a/examples/figures_of_merit_single.py +++ b/examples/figures_of_merit_single.py @@ -69,20 +69,20 @@ def main(): # Print figures of merit print("") print(f"Amplitude: Average pixel value in reconstruction image is {figs[0]:.4f}") - print(f"Position Error: {100*figs[1]:.2f}% of widest axis") + print(f"Position Error: {100 * figs[1]:.2f}% of widest axis") print( - f"Resolution: Reconstructed point radius {100*figs[2]:.2f}% of image equivalent radius" + f"Resolution: Reconstructed point radius {100 * figs[2]:.2f}% of image equivalent radius" ) print( - f"Shape Deformation: {100*figs[3]:.2f}% of pixels in the thresholded image are outside the equivalent circle" + f"Shape Deformation: {100 * figs[3]:.2f}% of pixels in the thresholded image are outside the equivalent circle" ) print( - f"Ringing: Ringing pixel amplitude is {100*figs[4]:.2f}% of image amplitude in thresholded region" + f"Ringing: Ringing pixel amplitude is {100 * figs[4]:.2f}% of image amplitude in thresholded region" ) # Create mesh plots fig, axs = plt.subplots(1, 2) - create_mesh_plot(axs[0], sim_mesh, ax_kwargs={"title": "Sim mesh"}) + create_mesh_plot(axs[0], sim_mesh_new, ax_kwargs={"title": "Sim mesh"}) create_mesh_plot(axs[1], recon_mesh, ax_kwargs={"title": "Recon mesh"}) fig.set_size_inches(10, 4) diff --git a/pyeit/io/oeit.py b/pyeit/io/oeit.py new file mode 100644 index 0000000..a41edb4 --- /dev/null +++ b/pyeit/io/oeit.py @@ -0,0 +1,35 @@ +import numpy as np +from scipy import stats + + +def load_oeit_data(file_name): + with open(file_name, "r") as f: + lines = f.readlines() + + data = [] + for line in lines: + eit = parse_oeit_line(line) + if eit is not None: + data.append(eit) + + mode_len = stats.mode([len(item) for item in data], keepdims=False) + data = [item for item in data if len(item) == mode_len.mode] + + return np.array(data) + + +def parse_oeit_line(line): + try: + _, data = line.split(":", 1) + except (ValueError, AttributeError): + return None + items = [] + for item in data.split(","): + item = item.strip() + if not item: + continue + try: + items.append(float(item)) + except ValueError: + return None + return np.array(items) diff --git a/pyeit/quality/eit_system.py b/pyeit/quality/eit_system.py new file mode 100644 index 0000000..80b063b --- /dev/null +++ b/pyeit/quality/eit_system.py @@ -0,0 +1,292 @@ +import numpy as np +from numpy.typing import NDArray +import allantools +from pyeit.eit.protocol import PyEITProtocol +from pyeit.quality.merit import calc_fractional_amplitude_set + +""" +eit_system.py contains calculation for the performance of EIT hardware systems based on measured data. These performance +measures are defined in "Evaluation of EIT system performance" by Mamatjan Yasin et. al. (With Andy Adler) +doi:10.1088/0967-3334/32/7/S09 +""" + + +def calc_signal_to_noise_ratio(measurements: NDArray, method="ratio") -> NDArray: + """ + Signal to noise ratio calculates the mean measurement divided by the standard deviation of measurements for each + channel. (For this calculation, a channel is defined as a unique combination of stimulation and measurement + electrodes) + + The measurements array must contain at least two sets of measurements + + Parameters + ---------- + measurements + NDArray containing a number of repeated EIT measurements + + Returns + ------- + snr + NDArray of signal to noise ratio for each individual channel in the EIT measurements + + """ + # Flatten measurements in case they are in a 2d array + measurements = measurements.reshape((measurements.shape[0], -1)) + + stdev = np.std(measurements, axis=0) + average = np.average(measurements, axis=0) + + snr = average / stdev + + if method == "ratio": + return snr + elif method == "db": + return ( + np.log10(np.abs(snr)) * 20 + ) # Convert to decibels as a root-power quantity + else: + raise ValueError("Invalid method specified (must be ratio or db)") + + +def calc_accuracy( + measurements: NDArray, reference_measurements: NDArray, method="Ratio" +) -> NDArray: + """ + Accuracy measures the closeness of measured quantities to a "true" reference value. In this case simulated EIT + measurements are used as the reference + + The measurements array must contain at least two sets of measurements + + Parameters + ---------- + measurements + reference_measurements + method + Options: EIDORS, Ratio + + Returns + ------- + accuracy + """ + # Flatten measurements in case they are in a 2d array + measurements = measurements.reshape((measurements.shape[0], -1)) + reference_measurements = reference_measurements.reshape(-1) + + average = np.average(measurements, axis=0) + + if method == "EIDORS": + # Normalize measurement sets individually (This is like calibrating by scaling and offsetting, so we only see + # the difference between channels. But it doesn't necessarily give you true information about any particular + # channel. So maybe it would be better as a range?) + average = (average - np.min(average)) / (np.max(average) - np.min(average)) + reference_measurements = ( + reference_measurements - np.min(reference_measurements) + ) / (np.max(reference_measurements) - np.min(reference_measurements)) + accuracy = 1 - np.abs(average - reference_measurements) + + elif method == "Ratio": + # This is as described in Gagnon 2010 - "A Resistive Mesh Phantom for Assissing the Performance of EIT Systems" + accuracy = 1 - np.abs( + (average - reference_measurements) / reference_measurements + ) + else: + raise ValueError("Invalid method selected for accuracy") + + return accuracy + + +def calc_drift( + measurements: NDArray, sampling_rate: float = 1, sample_period=None, method="Allan" +): + """ + Drift is a measure of the change in average value of measurements over time. There are two methods for calculating + this. The EIDORS method uses the Allan variance, and the Delta method calculates the difference between two + samples taken from the start and end of the total list of measurements. + + Returns + ------- + method: "Allan" + t2: the set of sampling periods used + adevs: the list of allan deviations calculated for each channel + + method: "Delta" + drifts: drifts calculated for each channel + + """ + # Flatten measurements in case they are in a 2d array + measurements = measurements.reshape((measurements.shape[0], -1)) + + if method == "Allan": + # Iterate through each channel + adevs = [] + for channel_measurements in measurements.T: + (t2, ad, ade, adn) = allantools.oadev( + channel_measurements, rate=sampling_rate, data_type="freq", taus="all" + ) + adevs.append(ad) + + adevs = np.array(adevs) + return t2, adevs + + elif method == "Delta": + drifts = [] + for channel_measurements in measurements.T: + start = np.average(channel_measurements[0 : sampling_rate * sample_period]) + end = np.average( + np.flip(channel_measurements)[0 : sampling_rate * sample_period] + ) + drifts.append(end - start) + drifts = np.array(drifts) + return drifts + + +def calc_reciprocity_accuracy( + measurements: NDArray, protocol: PyEITProtocol +) -> NDArray: + """ + Tests the closeness of reciprocal measurements to each other. This is in accordance with the principle in + "Reciprocity Applied to Volume Conductors and the ECG" (Plonsey 1963). The interpretation of this in + "Evaluation of EIT system performance" is as follows: "EIT measurements from a stimulation–measurement pair should + not change if the current stimulation and voltage measurement electrodes are swapped." + + The measurements array must contain at least two sets of measurements + + Parameters + ---------- + measurements + Array of EIT measurements + protocol + PyEITProtocol object listing the excitation and measurement electrodes for each row in the EIT measurement + array + + Returns + ------- + reciprocal_accuracies + Array of accuracy calculations for reciprocal pairs + + """ + combined_mat = np.hstack( + (protocol.meas_mat[:, :2], protocol.ex_mat[protocol.meas_mat[:, 2]]) + ) + reciprocals = find_reciprocals(combined_mat) + + # Flatten measurements in case they are in a 2d array + measurements = measurements.reshape((measurements.shape[0], -1)) + average = np.average(measurements, axis=0) + + reciprocal_accuracies = [] + for reciprocal in reciprocals: + v = average[reciprocal[0]] + vr = average[reciprocal[1]] + + reciprocal_accuracy = 1 - np.abs(v - vr) / np.abs(v) + reciprocal_accuracies.append(reciprocal_accuracy) + + return np.array(reciprocal_accuracies) + + +def find_reciprocals(combined_mat: NDArray) -> NDArray: + """ + Find reciprocal rows in an Nx4 array. Reciprocals are defined as a pair of rows where the two pairs of elements + are swapped, i.e., element 0 and 1 in row a are the same as element 2 and 3 in row b, and element 2 and 3 in row a + are the same as element 0 and 1 in row b. Order does not matter. + + If any row has no reciprocal, a ValueError will be raised. + + Parameters + ---------- + combined_mat + Nx4 array where all rows have a reciprocal + + Returns + ------- + reciprocals + Nx2 array with indices of reciprocal rows. (Duplicates are removed) + + """ + reciprocals = set() + for i, row in enumerate(combined_mat): + reciprocal = find_reciprocal(row, combined_mat) + reciprocals.add( + frozenset((i, reciprocal)) + ) # Append as inner set to outer set to filter out duplicates + + reciprocals = np.array( + [list(r) for r in list(reciprocals)] + ) # Convert back to array + + return reciprocals + + +def find_reciprocal(row: NDArray, combined_mat: NDArray) -> int: + """ + Auxilliary function to find_reciprocals. Finds the reciprocal of a single row. Raises ValueError if one is not found. + + Parameters + ---------- + row + Row to find reciprocal of + combined_mat + Array to search + + Returns + ------- + i + Index of reciprocal row + + """ + # Reciprocal pairs can be in any order (because the signal is AC). So we compare rows as sets + reciprocal_row = np.array(({*row[2:]}, {*row[0:2]})) + + for i, compare_row in enumerate(combined_mat): + if np.array_equal( + reciprocal_row, np.array(({*compare_row[0:2]}, {*compare_row[2:]})) + ): + return i + + raise ValueError("No reciprocal found") + + +def calc_detectability( + image, + conductive_target: bool = True, + fraction=0.25, + fraction_method="GREIT", + method="ratio", +): + """ + See Adler et. al. 2010 "Distinguishability in EIT using a hypothesis-testing model". This creates a z statistic + so how do we calculate probability of null hypothesis? + + Parameters + ---------- + image + conductive_target + fraction + fraction_method + + Returns + ------- + + """ + + fractional_image = calc_fractional_amplitude_set( + image, + conductive_target=conductive_target, + fraction=fraction, + method=fraction_method, + ) + + mean = np.abs(np.mean(image[fractional_image == 1])) + std = np.std(image[fractional_image == 1]) + + detectability = mean / std + + if method == "ratio": + return detectability + elif method == "db": + return ( + np.log10(np.abs(detectability)) * 20 + ) # Convert to decibels as a root-power quantity + else: + raise ValueError("Invalid method specified (must be ratio or db)") diff --git a/pyeit/visual/plot.py b/pyeit/visual/plot.py index 1ec4d21..e9c8bae 100644 --- a/pyeit/visual/plot.py +++ b/pyeit/visual/plot.py @@ -15,6 +15,8 @@ patches as mpatches, axes as mpl_axes, ) +from mpl_toolkits.axes_grid1 import make_axes_locatable +import matplotlib.colorbar def ts_plot(ts, figsize=(6, 4), ylabel="ATI (Ohm)", ylim=None, xdate_format=True): @@ -100,7 +102,8 @@ def create_mesh_plot( # Add mesh to ax ax.add_collection(pc) - ax.figure.colorbar(pc, ax=ax, label="Element Value") + cb = colorbar(pc) + cb.set_label("Element Value") ax.autoscale() ax.set_xticks([], labels=None) ax.set_yticks([], labels=None) @@ -344,7 +347,7 @@ def create_plot( tripcolor_keys_map = {"vmin": vmin, "vmax": vmax} tripcolor_kwargs = {k: v for k, v in tripcolor_keys_map.items() if v is not None} plot_image = ax.tripcolor(x, y, elements, eit_image, **tripcolor_kwargs) - ax.figure.colorbar(plot_image) + colorbar(plot_image) ax.set_xticks([], labels=None) ax.set_yticks([], labels=None) ax.set(**ax_kwargs) @@ -395,7 +398,7 @@ def create_image_plot( ax.set_xbound(img_bounds[2] - margin, img_bounds[3] + margin) ax.set_title(title) - ax.figure.colorbar(im) + colorbar(im) return im @@ -494,3 +497,26 @@ def get_img_bounds(img, background=np.nan): ymax = j - 1 return xmin, xmax, ymin, ymax + + +def colorbar(mappable: matplotlib.cm.ScalarMappable) -> matplotlib.colorbar.Colorbar: + """ + Add a colorbar that matches the height of its corresponding image + + Parameters + ---------- + mappable + + Returns + ------- + cbar + + """ + last_axes = plt.gca() + ax = mappable.axes + fig = ax.figure + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.1) + cbar = fig.colorbar(mappable, cax=cax) + plt.sca(last_axes) + return cbar diff --git a/setup.cfg b/setup.cfg index 83828e7..bf7aeb4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,7 @@ install_requires = shapely trimesh imageio + allantools test_suite = tests tests_require = diff --git a/tests/test_eit_system.py b/tests/test_eit_system.py new file mode 100644 index 0000000..b4d3d58 --- /dev/null +++ b/tests/test_eit_system.py @@ -0,0 +1,279 @@ +import numpy as np +from pyeit.quality.eit_system import ( + calc_signal_to_noise_ratio, + calc_accuracy, + calc_drift, + find_reciprocal, + find_reciprocals, + calc_reciprocity_accuracy, +) +from numpy.testing import assert_array_equal +from numpy.random import default_rng +import matplotlib.pyplot as plt +import pyeit.eit.protocol as protocol + + +def test_snr(): + measurements = np.array([[[1, 2, 3], [4, 5, 6]], [[2, 3, 4], [5, 6, 7]]]) + snr = calc_signal_to_noise_ratio(measurements) + snr_with_flat = calc_signal_to_noise_ratio(measurements.reshape((2, -1))) + + correct_snr = np.array([3, 5, 7, 9, 11, 13]) + + assert_array_equal(snr, correct_snr) + assert_array_equal(snr_with_flat, correct_snr) + + +def test_accuracy(): + measurements = np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + reference_measurements = np.array([[1.25, 2.5, 3.75], [5, 6.25, 7.5]]) + + accuracy_ratio = calc_accuracy(measurements, reference_measurements, method="Ratio") + correct_accuracy_ratio = np.array([0.8, 0.8, 0.8, 0.8, 0.8, 0.8]) + + accuracy_eidors = calc_accuracy( + measurements, reference_measurements, method="EIDORS" + ) + correct_accuracy_eidors = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + + assert_array_equal(accuracy_ratio, correct_accuracy_ratio) + assert_array_equal(accuracy_eidors, correct_accuracy_eidors) + + +def test_drift_allan(): + measurements = np.arange(1, 11) + rng = default_rng(0) + noise = rng.normal(0, 0.5, (100, 10)) + measurements = measurements + noise + + correct_t2 = np.arange(1, 50) + correct_adevs_0 = np.array( + [ + 0.50144798, + 0.34189729, + 0.26004536, + 0.21503377, + 0.19580965, + 0.17977422, + 0.16560979, + 0.15510351, + 0.13790106, + 0.12184227, + 0.11378077, + 0.1028532, + 0.09888532, + 0.09849245, + 0.09000506, + 0.0825436, + 0.0737621, + 0.06756625, + 0.06567533, + 0.0643697, + 0.05971835, + 0.05578692, + 0.05155231, + 0.04734716, + 0.04230214, + 0.03907241, + 0.03439275, + 0.03255361, + 0.03208408, + 0.0327237, + 0.03250187, + 0.03394281, + 0.03573658, + 0.03737299, + 0.03849052, + 0.03712749, + 0.03684584, + 0.03460951, + 0.03283292, + 0.02997242, + 0.02945563, + 0.02765693, + 0.02602859, + 0.02549113, + 0.0244943, + 0.02455454, + 0.01994969, + 0.0135816, + 0.00945514, + ] + ) + + t2, adevs = calc_drift(measurements, sampling_rate=1, method="Allan") + + # fig, ax = plt.subplots() + # for adev in adevs: + # ax.plot(t2, adev, ".") + # ax.set_title("Allan Deviation, 10 Channels") + # plt.show() + + np.testing.assert_array_equal(t2, correct_t2) + np.testing.assert_array_almost_equal(adevs[0], correct_adevs_0) + + +def test_drift_allan_with_drift(): + measurements = np.arange(1, 11) + rng = default_rng(0) + noise = rng.normal(0, 0.5, (100, 10)) + drift = (np.arange(1, 101) / 100).reshape((1, -1)) + measurements = measurements + noise + drift.T + + t2, adevs = calc_drift(measurements, sampling_rate=1, method="Allan") + + correct_t2 = np.arange(1, 50) + correct_adevs_0 = np.array( + [ + 0.50151151, + 0.34240022, + 0.26094362, + 0.21688356, + 0.19928049, + 0.18489464, + 0.17249367, + 0.16404771, + 0.15007923, + 0.13812585, + 0.13544855, + 0.13172051, + 0.13419322, + 0.13919843, + 0.13771382, + 0.13777519, + 0.13701817, + 0.13801215, + 0.14205622, + 0.14714887, + 0.14925671, + 0.15213755, + 0.1549156, + 0.15761515, + 0.16228652, + 0.1684516, + 0.17457164, + 0.17906281, + 0.18619088, + 0.19295449, + 0.19888729, + 0.20556761, + 0.21200783, + 0.21823937, + 0.22585266, + 0.23350849, + 0.24107197, + 0.249568, + 0.2585531, + 0.26629471, + 0.27500569, + 0.28331606, + 0.29105864, + 0.29524983, + 0.30406007, + 0.31541238, + 0.32955649, + 0.34293692, + 0.34950794, + ] + ) + + # fig, ax = plt.subplots() + # for adev in adevs: + # ax.plot(t2, adev, ".") + # ax.set_title("Allan Deviation, 10 Channels") + # plt.show() + + np.testing.assert_array_equal(t2, correct_t2) + np.testing.assert_array_almost_equal(adevs[0], correct_adevs_0) + + +def test_drift_delta(): + measurements = np.arange(1, 11) + rng = default_rng(0) + noise = rng.normal(0, 0.5, (100, 10)) + measurements = measurements + noise + + correct_drifts = np.array( + [ + 0.04459828, + -0.52595789, + -0.33285465, + 0.0782566, + -0.02511503, + -0.09873106, + -0.27501217, + -0.28731926, + -0.27797152, + -0.02059609, + ] + ) + drifts = calc_drift(measurements, sampling_rate=1, sample_period=10, method="Delta") + + # print("\n") + # for i, drift in enumerate(drifts): + # print(f"Channel {i + 1} drift: {drift:.4f}") + + np.testing.assert_array_almost_equal(drifts, correct_drifts) + + +def test_drift_delta_with_drift(): + measurements = np.arange(1, 11) + rng = default_rng(0) + noise = rng.normal(0, 0.5, (100, 10)) + drift = (np.arange(1, 101) / 100).reshape((1, -1)) # 1/100 drift per second + measurements = measurements + noise + drift.T + + correct_drifts = np.array( + [ + 0.94459828, + 0.37404211, + 0.56714535, + 0.9782566, + 0.87488497, + 0.80126894, + 0.62498783, + 0.61268074, + 0.62202848, + 0.87940391, + ] + ) + drifts = calc_drift(measurements, sampling_rate=1, sample_period=10, method="Delta") + + # total_period = 1*100 # Sampling rate times n_samples + # sampling_time = 1*10 # Sampling rate times sample period + # period_between_samples = total_period - 2*sampling_time + # drifts_per_second = drifts/period_between_samples + # print("\n") + # for i, (drift, drift_per_second) in enumerate(zip(drifts, drifts_per_second)): + # print(f"Channel {i+1} drift: {drift:.4f}, drift per second: {drift_per_second:.4f}") + # print(f"Average drift per second: {np.average(drifts_per_second):.4f}") + + np.testing.assert_array_almost_equal(drifts, correct_drifts) + + +def test_find_reciprocal(): + arr = np.array([[1, 2, 3, 4], [3, 4, 2, 1], [4, 5, 6, 7]]) + correct_reciprocal = 1 + reciprocal = find_reciprocal(arr[0], arr) + + assert reciprocal == correct_reciprocal + + +def test_find_reciprocals(): + arr = np.array([[1, 2, 3, 4], [3, 4, 2, 1], [4, 5, 6, 7], [6, 7, 4, 5]]) + correct_reciprocals = np.array([[0, 1], [2, 3]]) + + reciprocals = find_reciprocals(arr) + + np.testing.assert_array_equal(reciprocals, correct_reciprocals) + + +def test_calc_reciprocity_accuracy(): + protocol_obj = protocol.create(4, dist_exc=1, step_meas=1, parser_meas="std") + data = np.array([[1, 1, 0.9, 0.8], [1, 1, 0.9, 0.8]]) + + correct_reciprocity_accuracy = np.array([0.9, 0.8]) + + reciprocity_accuracy = calc_reciprocity_accuracy(data, protocol_obj) + + np.testing.assert_array_equal(correct_reciprocity_accuracy, reciprocity_accuracy)