diff --git a/euphonic/spectra.py b/euphonic/spectra.py index 6d4a89d85..e26092799 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -3,17 +3,23 @@ from abc import ABC, abstractmethod import collections import copy +from functools import partial, reduce import itertools import math import json from numbers import Integral, Real -from typing import (Any, Callable, Dict, List, Literal, Optional, overload, - Sequence, Tuple, TypeVar, Union, Type) +from operator import contains +from typing import (Any, Callable, Dict, Generator, List, Literal, Optional, + overload, Sequence, Tuple, TypeVar, Union, Type) +from typing_extensions import Self import warnings from pint import DimensionalityError, Quantity import numpy as np from scipy.ndimage import correlate1d, gaussian_filter +from toolz.dicttoolz import keyfilter, valmap +from toolz.functoolz import complement +from toolz.itertoolz import groupby, pluck from euphonic import ureg, __version__ from euphonic.broadening import (ErrorFit, KernelShape, @@ -21,7 +27,6 @@ from euphonic.io import (_obj_to_json_file, _obj_from_json_file, _obj_to_dict, _process_dict) from euphonic.readers.castep import read_phonon_dos_data -from euphonic.util import _get_unique_elems_and_idx from euphonic.validate import _check_constructor_inputs, _check_unit_conversion @@ -564,8 +569,9 @@ def from_castep_phonon_dos(cls: Type[T], filename: str, metadata['species'] = element metadata['label'] = element - return cls(data['dos_bins']*ureg(data['dos_bins_unit']), - data['dos'][element]*ureg(data['dos_unit']), + return cls(ureg.Quantity(data["dos_bins"], + units=data["dos_bins_unit"]), + ureg.Quantity(data["dos"][element], units=data["dos_unit"]), metadata=metadata) @overload @@ -644,7 +650,8 @@ def broaden(self: T, x_width, self.y_data.magnitude, [self.get_bin_centres().magnitude], [x_width.to(self.x_data_unit).magnitude], - shape=shape, method=method) * ureg(self.y_data_unit) + shape=shape, method=method) + y_broadened = ureg.Quantity(y_broadened, units=self.y_data_unit) elif isinstance(x_width, Callable): self.assert_regular_bins(message=( @@ -669,10 +676,438 @@ def broaden(self: T, x_width, return new_spectrum -LineData = Sequence[Dict[str, Union[str, int]]] +OneLineData = Dict[str, Union[str, int]] +LineData = Sequence[OneLineData] +Metadata = Dict[str, Union[str, int, LineData]] -class Spectrum1DCollection(collections.abc.Sequence, Spectrum): +class SpectrumCollectionMixin(ABC): + """Help a collection of spectra work with "line_data" metadata file + + This is a Mixin to be inherited by Spectrum collection classes + + To avoid redundancy, spectrum collections store metadata in the form + + {"key1": value1, "key2", value2, "line_data": [{"key3": value3, ...}, + {"key4": value4, ...}...]} + + - It is not guaranteed that all "lines" carry the same keys + - No key should appear at both top-level and in line-data; any key-value + pair at top level is assumed to apply to all lines + - "lines" can actually correspond to N-D spectra, the notation was devised + for multi-line plots of Spectrum1DCollection and then applied to other + purposes. + + The _spectrum_axis class attribute determines which axis property contains + the spectral data, and should be set by subclasses (i.e. to "y" or "z" for + 1D or 2D). + """ + + # Subclasses must define which axis contains the spectral data for + # purposes of splitting, indexing, etc. + # Python doesn't support abstract class attributes so we define a default + # value, ensuring _something_ was set. + _bin_axes = ("x",) + _spectrum_axis = "y" + _item_type = Spectrum1D + + # Define some private methods which wrap this information into useful forms + @classmethod + def _spectrum_data_name(cls) -> str: + return f"{cls._spectrum_axis}_data" + + @classmethod + def _spectrum_raw_data_name(cls) -> str: + return f"_{cls._spectrum_axis}_data" + + def _get_spectrum_data(self) -> Quantity: + return getattr(self, self._spectrum_data_name()) + + def _get_raw_spectrum_data(self) -> np.ndarray: + return getattr(self, self._spectrum_raw_data_name()) + + def _set_spectrum_data(self, data: Quantity) -> None: + setattr(self, self._spectrum_data_name(), data) + + def _set_raw_spectrum_data(self, data: np.ndarray) -> None: + setattr(self, self._spectrum_raw_data_name(), data) + + def _get_spectrum_data_unit(self) -> str: + return getattr(self, f"{self._spectrum_data_name()}_unit") + + def _get_internal_spectrum_data_unit(self) -> str: + return getattr(self, f"_internal_{self._spectrum_data_name()}_unit") + + def _get_bin_kwargs(self) -> Dict[str, Quantity]: + """Get constructor args for bin axes from current data + + e.g. for Spectrum2DCollection this is + + {"x_data": self.x_data, "y_data": self.y_data} + """ + return {f"{axis}_data": getattr(self, f"{axis}_data") + for axis in self._bin_axes} + + @classmethod + def _get_item_data(cls, item: Spectrum) -> Quantity: + return getattr(item, f"{cls._spectrum_axis}_data") + + @classmethod + def _get_item_raw_data(cls, item: Spectrum) -> np.ndarray: + return getattr(item, f"_{cls._spectrum_axis}_data") + + @classmethod + def _get_item_data_unit(cls, item: Spectrum) -> str: + return getattr(item, f"{cls._spectrum_axis}_data_unit") + + def sum(self) -> Spectrum: + """ + Sum collection to a single spectrum + + Returns + ------- + summed_spectrum + A single combined spectrum from all items in collection. Any + metadata in 'line_data' not common across all spectra will be + discarded + """ + metadata = copy.deepcopy(self.metadata) + metadata.pop('line_data', None) + metadata.update(self._tidy_metadata()) + summed_s_data = ureg.Quantity( + np.sum(self._get_raw_spectrum_data(), axis=0), + units=self._get_internal_spectrum_data_unit() + ).to(self._get_spectrum_data_unit()) + return self._item_type( + **self._get_bin_kwargs(), + **{self._spectrum_data_name(): summed_s_data}, + x_tick_labels=copy.copy(self.x_tick_labels), + metadata=metadata + ) + + # Required methods + @classmethod + @abstractmethod + def from_spectra(cls, spectra: Sequence[Spectrum]) -> Self: + """Construct spectrum collection from a sequence of components""" + ... + + # Mixin methods + def __len__(self): + return self._get_raw_spectrum_data().shape[0] + + @overload + def __getitem__(self, item: int) -> Spectrum: + ... + + @overload # noqa: F811 + def __getitem__(self, item: slice) -> Self: + ... + + @overload # noqa: F811 + def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> Self: + ... + + def __getitem__( + self, item: Union[Integral, slice, Sequence[Integral], np.ndarray] + ): # noqa: F811 + self._validate_item(item) + init_kwargs = { + self._spectrum_data_name(): self._get_spectrum_data()[item, :], + "x_tick_labels": self.x_tick_labels, + "metadata": self._get_item_metadata(item) + } | self._get_bin_kwargs() + + if isinstance(item, Integral): + return self._item_type(**init_kwargs) + + return type(self)(**init_kwargs) + + def _validate_item(self, item: Integral | slice | Sequence[Integral] | np.ndarray + ) -> None: + """Raise Error if index has inappropriate typing/ranges + + Raises + ------ + IndexError + Slice is not compatible with size of collection + + TypeError + item specification does not have acceptable type; e.g. a sequence + of float or bool was provided when ints are needed. + + """ + if isinstance(item, Integral): + return + if isinstance(item, slice): + if (item.stop is not None) and (item.stop >= len(self)): + raise IndexError(f'index "{item.stop}" out of range') + return + + if not all(isinstance(i, Integral) for i in item): + raise TypeError( + f'Index "{item}" should be an integer, slice ' + f'or sequence of ints') + + @overload + def _get_item_metadata(self, item: Integral) -> OneLineData: + """Get a single metadata item with no line_data""" + + @overload + def _get_item_metadata(self, item: slice | Sequence[Integral] | np.ndarray + ) -> Metadata: # noqa: F811 + """Get a metadata collection (may include line_data)""" + + def _get_item_metadata(self, item): # noqa: F811 + """Produce appropriate metadata for __getitem__""" + metadata_lines = list(self.iter_metadata()) + + if isinstance(item, Integral): + return metadata_lines[item] + if isinstance(item, slice): + return self._combine_metadata(metadata_lines[item]) + # Item must be some kind of integer sequence + return self._combine_metadata([metadata_lines[i] for i in item]) + + def copy(self) -> Self: + """Get an independent copy of spectrum""" + return self._item_type.copy(self) + + def __add__(self, other: Self) -> Self: + """ + Appends the y_data of 2 Spectrum1DCollection objects, + creating a single Spectrum1DCollection that contains + the spectra from both objects. The two objects must + have equal x_data axes, and their y_data must + have compatible units and the same number of y_data + entries + + Any metadata key/value pairs that are common to both + spectra are retained in the top level dictionary, any + others are put in the individual 'line_data' entries + """ + return type(self).from_spectra([*self, *other]) + + def iter_metadata(self) -> Generator[OneLineData, None, None]: + """Iterate over metadata dicts of individual spectra from collection""" + common_metadata = {key: value for key, value in self.metadata.items() + if key != "line_data"} + + + line_data = self.metadata.get("line_data") + if line_data is None: + line_data = itertools.repeat({}, len(self._get_raw_spectrum_data())) + + for one_line_data in line_data: + yield common_metadata | one_line_data + + def _select_indices(self, **select_key_values) -> list[int]: + """Get indices of items that match metadata query + + The target key-value pairs are a subset of the matching data, e.g. + + self._select_indices(species="Na", weight="coherent") + + will match metadata rows + + {"species": "Na", "weight": "coherent"} + + and + + {"species": "Na", "weight": "coherent", "mass": "22.9898"} + + but not + + {"species": "Na"} + + or + + {"species": "K", "weight": "coherent"} + """ + required_metadata = select_key_values.items() + indices = [i for i, row in enumerate(self.iter_metadata()) + if required_metadata <= row.items()] + return indices + + def select(self, **select_key_values: Union[ + str, int, Sequence[str], Sequence[int]]) -> Self: + """ + Select spectra by their keys and values in metadata['line_data'] + + Parameters + ---------- + **select_key_values + Key-value/values pairs in metadata['line_data'] describing + which spectra to extract. For example, to select all spectra + where metadata['line_data']['species'] = 'Na' or 'Cl' use + spectrum.select(species=['Na', 'Cl']). To select 'Na' and + 'Cl' spectra where weighting is also coherent, use + spectrum.select(species=['Na', 'Cl'], weighting='coherent') + + Returns + ------- + selected_spectra + A Spectrum1DCollection containing the selected spectra + + Raises + ------ + ValueError + If no matching spectra are found + """ + # Convert all items to sequences of possibilities + def ensure_sequence(value: int | str | Sequence[int | str] + ) -> Sequence[int | str]: + return (value,) if isinstance(value, (int, str)) else value + + select_key_values = valmap(ensure_sequence, select_key_values) + + + # Collect indices that match each combination of values + selected_indices = [] + for value_combination in itertools.product(*select_key_values.values() + ): + selection = dict(zip(select_key_values.keys(), value_combination)) + selected_indices.extend(self._select_indices(**selection)) + + if not selected_indices: + raise ValueError(f'No spectra found with matching metadata ' + f'for {select_key_values}') + + return self[selected_indices] + + @staticmethod + def _combine_metadata(all_metadata: LineData) -> Metadata: + """ + From a sequence of metadata dictionaries, combines all common + key/value pairs into the top level of a metadata dictionary, + all unmatching key/value pairs are put into the 'line_data' + key, which is a list of metadata dicts for each element in + all_metadata + """ + # This is for combining multiple separate spectrum metadata, + # they shouldn't have line_data + for metadata in all_metadata: + assert 'line_data' not in metadata + + # Combine key-value pairs common to *all* metadata lines into new dict + common_metadata = dict( + reduce(set.intersection, + (set(metadata.items()) for metadata in all_metadata))) + + # Put all other per-spectrum metadata in line_data + is_common = partial(contains, common_metadata) + line_data = [keyfilter(complement(is_common), one_line_data) + for one_line_data in all_metadata] + + if any(line_data): + return common_metadata | {'line_data': line_data} + + return common_metadata + + def _tidy_metadata(self) -> Metadata: + """ + For a metadata dictionary, combines all common key/value + pairs in 'line_data' and puts them in a top-level dictionary. + """ + line_data = self.metadata.get("line_data", [{}] * len(self)) + combined_line_data = self._combine_metadata(line_data) + combined_line_data.pop("line_data", None) + return combined_line_data + + def _check_metadata(self) -> None: + """Check self.metadata['line_data'] is consistent with collection size + + Raises + ------ + ValueError + Metadata contains 'line_data' of incorrect length + + """ + if 'line_data' in self.metadata: + collection_size = len(self._get_raw_spectrum_data()) + n_lines = len(self.metadata['line_data']) + + if n_lines != collection_size: + raise ValueError( + f'{self._spectrum_data_name()} contains {collection_size} ' + f'spectra, but metadata["line_data"] contains ' + f'{n_lines} entries') + + def group_by(self, *line_data_keys: str) -> Self: + """ + Group and sum elements of spectral data according to the values + mapped to the specified keys in metadata['line_data'] + + Parameters + ---------- + line_data_keys + The key(s) to group by. If only one line_data_key is + supplied, if the value mapped to a key is the same for + multiple spectra, they are placed in the same group and + summed. If multiple line_data_keys are supplied, the values + must be the same for all specified keys for them to be + placed in the same group + + Returns + ------- + grouped_spectrum + A new Spectrum1DCollection with one line for each group. Any + metadata in 'line_data' not common across all spectra in a + group will be discarded + """ + def get_key_items(enumerated_metadata: tuple[int, OneLineData] + ) -> tuple[str | int, ...]: + """Get sort keys from an item of enumerated input to groupby + + e.g. with line_data_keys=("a", "b") + + (0, {"a": 4, "d": 5}) --> (4, None) + """ + return tuple(enumerated_metadata[1].get(item, None) + for item in line_data_keys) + + # First element of each tuple is the index + indices = partial(pluck, 0) + + groups = groupby(get_key_items, enumerate(self.iter_metadata())) + + return self.from_spectra([self[list(indices(group))].sum() + for group in groups.values()]) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert to a dictionary consistent with from_dict() + + Returns + ------- + dict + """ + attrs = [*self._get_bin_kwargs().keys(), + self._spectrum_data_name(), + 'x_tick_labels', + 'metadata'] + + return _obj_to_dict(self, attrs) + + @classmethod + def from_dict(cls: Self, d: dict) -> Self: + """Initialise a Spectrum Collection object from dict""" + data_keys = [f"{dim}_data" for dim in cls._bin_axes] + data_keys.append(cls._spectrum_data_name()) + + d = _process_dict(d, + quantities=data_keys, + optional=['x_tick_labels', 'metadata']) + + data_args = [d[key] for key in data_keys] + return cls(*data_args, + x_tick_labels=d['x_tick_labels'], + metadata=d['metadata']) + + +class Spectrum1DCollection(SpectrumCollectionMixin, + Spectrum, + collections.abc.Sequence): """A collection of Spectrum1D with common x_data and x_tick_labels Intended for convenient storage of band structures, projected DOS @@ -706,6 +1141,10 @@ class Spectrum1DCollection(collections.abc.Sequence, Spectrum): """ T = TypeVar('T', bound='Spectrum1DCollection') + # Private attributes used by SpectrumCollectionMixin + _spectrum_axis = "y" + _item_type = Spectrum1D + def __init__( self, x_data: Quantity, y_data: Quantity, x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, @@ -749,32 +1188,13 @@ def __init__( self._set_data(x_data, 'x') self._set_data(y_data, 'y') self.x_tick_labels = x_tick_labels - if metadata and 'line_data' in metadata.keys(): - if len(metadata['line_data']) != len(y_data): - raise ValueError( - f'y_data contains {len(y_data)} spectra, but ' - f'metadata["line_data"] contains ' - f'{len(metadata["line_data"])} entries') - self.metadata = {} if metadata is None else metadata - - def __add__(self: T, other: T) -> T: - """ - Appends the y_data of 2 Spectrum1DCollection objects, - creating a single Spectrum1DCollection that contains - the spectra from both objects. The two objects must - have equal x_data axes, and their y_data must - have compatible units and the same number of y_data - entries - Any metadata key/value pairs that are common to both - spectra are retained in the top level dictionary, any - others are put in the individual 'line_data' entries - """ - return type(self).from_spectra([*self, *other]) + self.metadata = metadata if metadata is not None else {} + self._check_metadata() def _split_by_indices(self, indices: Union[Sequence[int], np.ndarray] - ) -> List[T]: + ) -> List[Self]: """Split data along x-axis at given indices""" ranges = self._ranges_from_indices(indices) @@ -785,52 +1205,6 @@ def _split_by_indices(self, metadata=self.metadata) for x0, x1 in ranges] - def __len__(self): - return self.y_data.shape[0] - - @overload - def __getitem__(self, item: int) -> Spectrum1D: - ... - - @overload # noqa: F811 - def __getitem__(self, item: slice) -> T: - ... - - @overload # noqa: F811 - def __getitem__(self, item: Union[Sequence[int], np.ndarray]) -> T: - ... - - def __getitem__(self, item: Union[int, slice, Sequence[int], np.ndarray] - ): # noqa: F811 - new_metadata = copy.deepcopy(self.metadata) - line_metadata = new_metadata.pop('line_data', - [{} for _ in self._y_data]) - if isinstance(item, Integral): - new_metadata.update(line_metadata[item]) - return Spectrum1D(self.x_data, - self.y_data[item, :], - x_tick_labels=self.x_tick_labels, - metadata=new_metadata) - - if isinstance(item, slice): - if (item.stop is not None) and (item.stop >= len(self)): - raise IndexError(f'index "{item.stop}" out of range') - new_metadata.update(self._combine_metadata(line_metadata[item])) - else: - try: - item = list(item) - if not all([isinstance(i, Integral) for i in item]): - raise TypeError - except TypeError: - raise TypeError(f'Index "{item}" should be an integer, slice ' - f'or sequence of ints') - new_metadata.update(self._combine_metadata( - [line_metadata[i] for i in item])) - return type(self)(self.x_data, - self.y_data[item, :], - x_tick_labels=self.x_tick_labels, - metadata=new_metadata) - @classmethod def from_spectra(cls: Type[T], spectra: Sequence[Spectrum1D]) -> T: if len(spectra) < 1: @@ -862,93 +1236,6 @@ def _type_check(spectrum): return cls(x_data, y_data, x_tick_labels=x_tick_labels, metadata=metadata) - @staticmethod - def _combine_metadata(all_metadata: Sequence[Dict[str, Union[int, str]]] - ) -> Dict[str, Union[int, str, LineData]]: - """ - From a sequence of metadata dictionaries, combines all common - key/value pairs into the top level of a metadata dictionary, - all unmatching key/value pairs are put into the 'line_data' - key, which is a list of metadata dicts for each element in - all_metadata - """ - # This is for combining multiple separate spectrum metadata, - # they shouldn't have line_data - for metadata in all_metadata: - assert 'line_data' not in metadata.keys() - # Combine all common key/value pairs - combined_metadata = dict( - set(all_metadata[0].items()).intersection( - *[metadata.items() for metadata in all_metadata[1:]])) - # Put all other per-spectrum metadata in line_data - line_data = [] - for i, metadata in enumerate(all_metadata): - sdata = copy.deepcopy(metadata) - for key in combined_metadata.keys(): - sdata.pop(key) - line_data.append(sdata) - if any(line_data): - combined_metadata['line_data'] = line_data - return combined_metadata - - def _combine_line_metadata(self, indices: Optional[Sequence[int]] = None - ) -> Dict[str, Any]: - """ - For a metadata dictionary, combines all common key/value - pairs in 'line_data' and puts them in a top-level dictionary. - If indices is supplied, only those indices in 'line_data' are - combined. Unmatching key/value pairs are discarded - """ - line_data = self.metadata.get('line_data', [{}]*len(self)) - if indices is not None: - line_data = [line_data[idx] for idx in indices] - combined_line_data = self._combine_metadata(line_data) - combined_line_data.pop('line_data', None) - return combined_line_data - - def _get_line_data_vals(self, *line_data_keys: str) -> np.ndarray: - """ - Get value of the key(s) for each element in - metadata['line_data']. Returns a 1D array of tuples, where each - tuple contains the value(s) for each key in line_data_keys, for - a single element in metadata['line_data']. This allows easy - grouping/selecting by specific keys - - For example, if we have a Spectrum1DCollection with the following - metadata: - {'desc': 'Quartz', 'line_data': [ - {'inst': 'LET', 'sample': 0, 'index': 1}, - {'inst': 'MAPS', 'sample': 1, 'index': 2}, - {'inst': 'MARI', 'sample': 1, 'index': 1}, - ]} - Then: - _get_line_data_vals('inst', 'sample') = [('LET', 0), - ('MAPS', 1), - ('MARI', 1)] - - Raises a KeyError if 'line_data' or the key doesn't exist - """ - line_data = self.metadata['line_data'] - line_data_vals = np.empty(len(line_data), dtype=object) - for i, data in enumerate(line_data): - line_data_vals[i] = tuple([data[key] for key in line_data_keys]) - return line_data_vals - - def copy(self: T) -> T: - """Get an independent copy of spectrum""" - return Spectrum1D.copy(self) - - def to_dict(self) -> Dict[str, Any]: - """ - Convert to a dictionary consistent with from_dict() - - Returns - ------- - dict - """ - return _obj_to_dict(self, ['x_data', 'y_data', 'x_tick_labels', - 'metadata']) - def to_text_file(self, filename: str, fmt: Optional[Union[str, Sequence[str]]] = None) -> None: """ @@ -982,35 +1269,6 @@ def to_text_file(self, filename: str, kwargs['fmt'] = fmt np.savetxt(filename, out_data, **kwargs) - @classmethod - def from_dict(cls: Type[T], d) -> T: - """ - Convert a dictionary to a Spectrum1DCollection object - - Parameters - ---------- - d : dict - A dictionary with the following keys/values: - - - 'x_data': (n_x_data,) or (n_x_data + 1,) float ndarray - - 'x_data_unit': str - - 'y_data': (n_x_data,) float ndarray - - 'y_data_unit': str - - There are also the following optional keys: - - - 'x_tick_labels': list of (int, string) tuples - - 'metadata': dict - - Returns - ------- - spectrum_collection - """ - d = _process_dict(d, quantities=['x_data', 'y_data'], - optional=['x_tick_labels', 'metadata']) - return cls(d['x_data'], d['y_data'], x_tick_labels=d['x_tick_labels'], - metadata=d['metadata']) - @classmethod def from_castep_phonon_dos(cls: Type[T], filename: str) -> T: """ @@ -1033,8 +1291,8 @@ def from_castep_phonon_dos(cls: Type[T], filename: str) -> T: metadata['line_data'][i]['species'] = species metadata['line_data'][i]['label'] = species return Spectrum1DCollection( - data['dos_bins']*ureg(data['dos_bins_unit']), - y_data*ureg(data['dos_unit']), + ureg.Quantity(data['dos_bins'], units=data['dos_bins_unit']), + ureg.Quantity(y_data, units=data['dos_unit']), metadata=metadata) @overload @@ -1120,7 +1378,7 @@ def broaden(self: T, method=method) new_spectrum = self.copy() - new_spectrum.y_data = y_broadened * ureg(self.y_data_unit) + new_spectrum.y_data = ureg.Quantity(y_broadened, units=self.y_data_unit) return new_spectrum elif isinstance(x_width, Callable): @@ -1138,116 +1396,31 @@ def broaden(self: T, else: raise TypeError("x_width must be a Quantity or Callable") - def group_by(self, *line_data_keys: str) -> T: + @classmethod + def from_dict(cls: Self, d: dict) -> Self: """ - Group and sum y_data for each spectrum according to the values - mapped to the specified keys in metadata['line_data'] + Convert a dictionary to a Spectrum Collection object Parameters ---------- - line_data_keys - The key(s) to group by. If only one line_data_key is - supplied, if the value mapped to a key is the same for - multiple spectra, they are placed in the same group and - summed. If multiple line_data_keys are supplied, the values - must be the same for all specified keys for them to be - placed in the same group - - Returns - ------- - grouped_spectrum - A new Spectrum1DCollection with one line for each group. Any - metadata in 'line_data' not common across all spectra in a - group will be discarded - """ - grouping_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*line_data_keys)) - - new_y_data = np.zeros((len(grouping_dict), self._y_data.shape[-1])) - group_metadata = copy.deepcopy(self.metadata) - group_metadata['line_data'] = [{}]*len(grouping_dict) - for i, idxs in enumerate(grouping_dict.values()): - # Look for any common key/values in grouped metadata - group_i_metadata = self._combine_line_metadata(idxs) - group_metadata['line_data'][i] = group_i_metadata - new_y_data[i] = np.sum(self._y_data[idxs], axis=0) - new_y_data = new_y_data*ureg(self._internal_y_data_unit).to( - self.y_data_unit) - - new_data = self.copy() - new_data.y_data = new_y_data - new_data.metadata = group_metadata - - return new_data - - def sum(self) -> Spectrum1D: - """ - Sum y_data over all spectra + d : dict + A dictionary with the following keys/values: - Returns - ------- - summed_spectrum - A Spectrum1D created from the summed y_data. Any metadata - in 'line_data' not common across all spectra will be - discarded - """ - metadata = copy.deepcopy(self.metadata) - metadata.pop('line_data', None) - metadata.update(self._combine_line_metadata()) - summed_y_data = np.sum(self._y_data, axis=0)*ureg( - self._internal_y_data_unit).to(self.y_data_unit) - return Spectrum1D(np.copy(self.x_data), - summed_y_data, - x_tick_labels=copy.copy(self.x_tick_labels), - metadata=copy.deepcopy(metadata)) + - 'x_data': (n_x_data,) or (n_x_data + 1,) float ndarray + - 'x_data_unit': str + - 'y_data': (n_x_data,) float ndarray + - 'y_data_unit': str - def select(self, **select_key_values: Union[ - str, int, Sequence[str], Sequence[int]]) -> T: - """ - Select spectra by their keys and values in metadata['line_data'] + There are also the following optional keys: - Parameters - ---------- - **select_key_values - Key-value/values pairs in metadata['line_data'] describing - which spectra to extract. For example, to select all spectra - where metadata['line_data']['species'] = 'Na' or 'Cl' use - spectrum.select(species=['Na', 'Cl']). To select 'Na' and - 'Cl' spectra where weighting is also coherent, use - spectrum.select(species=['Na', 'Cl'], weighting='coherent') + - 'x_tick_labels': list of (int, string) tuples + - 'metadata': dict Returns ------- - selected_spectra - A Spectrum1DCollection containing the selected spectra - - Raises - ------ - ValueError - If no matching spectra are found + spectrum_collection """ - select_val_dict = _get_unique_elems_and_idx( - self._get_line_data_vals(*select_key_values.keys())) - for key, value in select_key_values.items(): - if isinstance(value, (int, str)): - select_key_values[key] = [value] - value_combinations = itertools.product(*select_key_values.values()) - select_idx = np.array([], dtype=np.int32) - for value_combo in value_combinations: - try: - idx = select_val_dict[value_combo] - # Don't require every combination to match e.g. - # spec.select(sample=[0, 2], inst=['MAPS', 'MARI']) - # we don't want to error simply because there are no - # inst='MAPS' and sample=2 combinations - except KeyError: - continue - select_idx = np.append(select_idx, idx) - if len(select_idx) == 0: - raise ValueError(f'No spectra found with matching metadata ' - f'for {select_key_values}') - return self[select_idx] - + return super().from_dict(d) class Spectrum2D(Spectrum): """ @@ -1279,7 +1452,7 @@ class Spectrum2D(Spectrum): def __init__(self, x_data: Quantity, y_data: Quantity, z_data: Quantity, x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, - metadata: Optional[Dict[str, Union[int, str]]] = None + metadata: Optional[Metadata] = None ) -> None: """ Parameters @@ -1442,7 +1615,7 @@ def broaden(self: T, method=method) spectrum = Spectrum2D(np.copy(self.x_data), np.copy(self.y_data), - z_broadened*ureg(self.z_data_unit), + ureg.Quantity(z_broadened, units=self.z_data_unit), copy.copy(self.x_tick_labels), copy.deepcopy(self.metadata)) else: @@ -1656,6 +1829,145 @@ def from_dict(cls: Type[T], d: Dict[str, Any]) -> T: metadata=d['metadata']) +class Spectrum2DCollection(SpectrumCollectionMixin, + Spectrum, + collections.abc.Sequence): + """A collection of Spectrum2D with common x_data, y_data and x_tick_labels + + Intended for convenient storage of contributions to spectral maps such as + S(Q,w). This object can be indexed or iterated to obtain individual + Spectrum2D. + + Attributes + ---------- + x_data + Shape (n_x_data,) or (n_x_data + 1,) float Quantity. The x_data + points (if size == (n_x_data,)) or x_data bin edges (if size + == (n_x_data + 1,)) + y_data + Shape (n_y_data,) or (n_y_data + 1,) float Quantity. The y_data + points (if size == (n_y_data,)) or y_data bin edges (if size + == (n_y_data + 1,)) + z_data + Shape (n_entries, n_x_data, n_y_data) float Quantity. The spectral data + in x and y, indexed over components + x_tick_labels + Sequence[Tuple[int, str]] or None. Special tick labels e.g. for + high-symmetry points. The int refers to the index in x_data the + label should be applied to + metadata + Dict[str, Union[int, str, LineData]] or None. Contains metadata + about the spectra. Keys should be strings and values should be + strings or integers. + There are some functional keys: + + - 'line_data' : LineData + This is a Sequence[Dict[str, Union[int, str]], + it contains metadata for each spectrum in + the collection, and must be of length + n_entries + """ + + # Private attributes used by SpectrumCollectionMixin + _bin_axes = ("x", "y") + _spectrum_axis = "z" + _item_type = Spectrum2D + + def __init__( + self, x_data: Quantity, y_data: Quantity, z_data: Quantity, + x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, + metadata: Optional[Metadata] = None + ) -> None: + _check_constructor_inputs( + [z_data, x_tick_labels, metadata], + [Quantity, [list, type(None)], [dict, type(None)]], + [(-1, -1, -1), (), ()], + ['z_data', 'x_tick_labels', 'metadata']) + # First axis corresponds to spectra in collection + _, nx, ny = z_data.shape + _check_constructor_inputs( + [x_data, y_data], + [Quantity, Quantity], + [[(nx,), (nx + 1,)], [(ny,), (ny + 1,)]], + ['x_data', 'y_data']) + + self._set_data(x_data, 'x') + self._set_data(y_data, 'y') + self.x_tick_labels = x_tick_labels + self._set_data(z_data, 'z') + + self.metadata = metadata if metadata is not None else {} + self._check_metadata() + + def _split_by_indices(self, indices: Sequence[int] | np.ndarray + ) -> List[Self]: + """Split data along x axis at given indices""" + ranges = self._ranges_from_indices(indices) + return [type(self)(self.x_data[x0:x1], + self.y_data, + self.z_data[:, x0:x1, :], + x_tick_labels=self._cut_x_ticks( + self.x_tick_labels, x0, x1), + metadata=self.metadata) + for x0, x1 in ranges] + + @property + def z_data(self) -> Quantity: + return ureg.Quantity( + self._z_data, self._internal_z_data_unit + ).to(self.z_data_unit, "reciprocal_spectroscopy") + + @z_data.setter + def z_data(self, value: Quantity) -> None: + self.z_data_unit = str(value.units) + self._z_data = value.to(self._internal_z_data_unit).magnitude + + @classmethod + def from_spectra(cls, spectra: Sequence[Spectrum2D]) -> Self: + if len(spectra) < 1: + raise IndexError("At least one spectrum is needed for collection") + + def _type_check(spectrum): + if not isinstance(spectrum, Spectrum2D): + raise TypeError( + "from_spectra() requires a sequence of Spectrum2D") + + _type_check(spectra[0]) + bins_data = { + f"{ax}_data": getattr(spectra[0], f"{ax}_data") + for ax in cls._bin_axes + } + x_tick_labels = spectra[0].x_tick_labels + + spectrum_0_data = cls._get_item_data(spectra[0]) + spectrum_data_shape = spectrum_0_data.shape + spectrum_data_magnitude = np.empty( + (len(spectra), *spectrum_data_shape)) + spectrum_data_magnitude[0, :, :] = spectrum_0_data.magnitude + spectrum_data_units = spectrum_0_data.units + + for i, spectrum in enumerate(spectra[1:]): + _type_check(spectrum) + spectrum_i_raw_data = cls._get_item_raw_data(spectrum) + spectrum_i_data_units = cls._get_item_data_unit(spectrum) + assert spectrum_i_data_units == spectrum_data_units + + for key, ref_bins in bins_data.items(): + item_bins = getattr(spectrum, key) + assert np.allclose(item_bins.magnitude, ref_bins.magnitude) + assert item_bins.units == ref_bins.units + + assert spectrum.x_tick_labels == x_tick_labels + spectrum_data_magnitude[i + 1, :, :] = spectrum_i_raw_data + + metadata = cls._combine_metadata([spec.metadata for spec in spectra]) + spectrum_data = Quantity(spectrum_data_magnitude, spectrum_data_units) + return cls(**bins_data, + **{f"{cls._spectrum_axis}_data": spectrum_data}, + x_tick_labels=x_tick_labels, + metadata=metadata) + + def apply_kinematic_constraints(spectrum: Spectrum2D, e_i: Quantity = None, e_f: Quantity = None, diff --git a/setup.py b/setup.py index c7b9250cb..02f28dae1 100644 --- a/setup.py +++ b/setup.py @@ -144,7 +144,8 @@ def run_setup(): 'seekpath>=1.1.0', 'spglib>=1.9.4', 'pint>=0.22', - 'threadpoolctl>=3.0.0' + 'threadpoolctl>=3.0.0', + 'toolz>=0.12.1', ], extras_require={ 'matplotlib': ['matplotlib>=3.8.0'], diff --git a/tests_and_analysis/minimum_euphonic_requirements.txt b/tests_and_analysis/minimum_euphonic_requirements.txt index 065a89b61..dbddec3df 100644 --- a/tests_and_analysis/minimum_euphonic_requirements.txt +++ b/tests_and_analysis/minimum_euphonic_requirements.txt @@ -7,3 +7,4 @@ matplotlib==3.8 h5py==3.6 PyYAML==6.0 threadpoolctl==3.0.0 +toolz==0.12.1 diff --git a/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json new file mode 100644 index 000000000..015da5952 --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_0.json @@ -0,0 +1,267 @@ +{ + "__euphonic_class__": "Spectrum2D", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "direction": 0 + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json new file mode 100644 index 000000000..d2c5fd013 --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_1.json @@ -0,0 +1,267 @@ +{ + "__euphonic_class__": "Spectrum2D", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "direction": 1 + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555557, + 0.6666666666666666, + 0.7777777777777778, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.06333333333333332, + 0.0, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json new file mode 100644 index 000000000..12106d100 --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2d/quartz_fuzzy_map_2.json @@ -0,0 +1,267 @@ +{ + "__euphonic_class__": "Spectrum2D", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "direction": 2 + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555557, + 0.6666666666666666, + 0.7777777777777778, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.06333333333333335 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.06333333333333332, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.042222222222222223, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json b/tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json new file mode 100644 index 000000000..2240424be --- /dev/null +++ b/tests_and_analysis/test/data/spectrum2dcollection/quartz_fuzzy_map.json @@ -0,0 +1,703 @@ +{ + "__euphonic_class__": "Spectrum2DCollection", + "__euphonic_version__": "1.3.2+33.gd8680c2.dirty", + "metadata": { + "common": "yes", + "line_data": [ + { + "direction": 0 + }, + { + "direction": 1 + }, + { + "direction": 2 + } + ] + }, + "x_data": [ + 0.0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777777, + 0.8888888888888888, + 1.0 + ], + "x_data_unit": "1 / angstrom", + "x_tick_labels": [ + [ + 0, + "$\\Gamma$" + ], + [ + 9, + "" + ] + ], + "y_data": [ + 0.0, + 5.2631578947368425, + 10.526315789473685, + 15.789473684210527, + 21.05263157894737, + 26.315789473684212, + 31.578947368421055, + 36.8421052631579, + 42.10526315789474, + 47.36842105263158, + 52.631578947368425, + 57.89473684210527, + 63.15789473684211, + 68.42105263157896, + 73.6842105263158, + 78.94736842105263, + 84.21052631578948, + 89.47368421052633, + 94.73684210526316, + 100.0 + ], + "y_data_unit": "millielectron_volt", + "z_data": [ + [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.021111111111111074, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.02111111111111112 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ], + [ + 0.0, + 0.0, + 0.0422222222222222, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.021111111111111112, + 0.042222222222222223, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.04222222222222224, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.04222222222222224 + ] + ], + [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.06333333333333335, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.04222222222222224 + ], + [ + 0.0, + 0.06333333333333331, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.06333333333333332, + 0.0, + 0.021111111111111094, + 0.02111111111111112, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222219, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0211111111111111, + 0.042222222222222223, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.021111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.04222222222222219, + 0.0, + 0.06333333333333328, + 0.02111111111111112, + 0.04222222222222215, + 0.0, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.02111111111111112 + ] + ], + [ + [ + 0.0, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.021111111111111094, + 0.04222222222222224, + 0.0, + 0.08444444444444447, + 0.0, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.04222222222222215, + 0.0, + 0.06333333333333335 + ], + [ + 0.06333333333333331, + 0.0, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.04222222222222224, + 0.021111111111111094, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.021111111111111074, + 0.0, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.021111111111111112, + 0.06333333333333332, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.042222222222222223, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.08444444444444438, + 0.04222222222222224, + 0.0, + 0.0, + 0.04222222222222224, + 0.0, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0, + 0.0211111111111111, + 0.0422222222222222, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.0, + 0.08444444444444447 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.06333333333333328, + 0.04222222222222224, + 0.0, + 0.02111111111111112, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ], + [ + 0.0211111111111111, + 0.0422222222222222, + 0.0, + 0.042222222222222223, + 0.0, + 0.04222222222222219, + 0.02111111111111112, + 0.021111111111111094, + 0.02111111111111112, + 0.04222222222222219, + 0.06333333333333335, + 0.0, + 0.02111111111111112, + 0.0, + 0.0, + 0.021111111111111074, + 0.0, + 0.021111111111111164, + 0.06333333333333335 + ] + ] + ], + "z_data_unit": "1 / millielectron_volt" +} \ No newline at end of file diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py index 24aedbd5e..cca15fe5d 100644 --- a/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py +++ b/tests_and_analysis/test/euphonic_test/test_spectrum1dcollection.py @@ -623,7 +623,9 @@ def test_select(self, spectrum_file, select_kwargs, [3, 5]), ('La2Zr2O7_666_coh_incoh_species_append_pdos.json', {'weighting': 'incoherent', 'species': 'O'}, - [3]) + [3]), + ('methane_pdos.json', + {'desc': 'Methane PDOS', 'label': 'H3'}, [2]), ]) def test_select_same_as_indexing(self, spectrum_file, select_kwargs, expected_indices): diff --git a/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py new file mode 100644 index 000000000..96029bc8e --- /dev/null +++ b/tests_and_analysis/test/euphonic_test/test_spectrum2dcollection.py @@ -0,0 +1,214 @@ +"""Unit tests for Spectrum2DCollection""" + +# Stop the linter from complaining when pytest fixtures are used idiomatically +# pylint: disable=redefined-outer-name + +from typing import Optional + +import numpy as np +import pytest + +from euphonic import Quantity, ureg +from euphonic.spectra import OneLineData, Spectrum2D, Spectrum2DCollection + +from tests_and_analysis.test.utils import get_data_path +from .test_spectrum2d import check_spectrum2d, get_spectrum2d + + +def get_spectrum2dcollection_path(*subpaths): + """Get Spectrum2DCollection reference data path""" + return get_data_path('spectrum2dcollection', *subpaths) + + +def get_spectrum2dcollection(json_filename): + """Get Spectrum2DCollection reference data object""" + return Spectrum2DCollection.from_json_file( + get_spectrum2dcollection_path(json_filename)) + + +@pytest.fixture +def quartz_fuzzy_collection() -> Spectrum2DCollection: + """Coarsely sampled quartz bands in a few directions""" + return get_spectrum2dcollection("quartz_fuzzy_map.json") + + +@pytest.fixture +def quartz_fuzzy_items() -> list[Spectrum2D]: + """Individual spectra corresponding to quartz_fuzzy_collection""" + return [get_spectrum2d(f"quartz_fuzzy_map_{i}.json") for i in range(3)] + +@pytest.fixture +def inconsistent_x_item() -> Spectrum2D: + """Spectrum with different x values""" + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item._x_data *= 2. + return item + +@pytest.fixture +def inconsistent_x_units_item(): + """Spectrum with different x units""" + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item.x_data_unit = "1/bohr" + return item + +@pytest.fixture +def inconsistent_x_length_item(): + """Spectrum with different number of x values""" + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item.x_data = item.x_data[:-2] + item.z_data = item.z_data[:-2, :] + return item + +@pytest.fixture +def inconsistent_y_item(): + """Spectrum with different y values""" + item = get_spectrum2d("quartz_fuzzy_map_0.json") + item.y_data = item.y_data * 2. + return item + +def rand_spectrum2d(seed: int = 1, + x_bins: Optional[Quantity] = None, + y_bins: Optional[Quantity] = None, + metadata: Optional[OneLineData] = None) -> Spectrum2D: + """Generate a Spectrum2D with random axis lengths, ranges, and metadata""" + rng = np.random.default_rng(seed=seed) + + if x_bins is None: + x_bins = np.linspace(*sorted([rng.random(), rng.random()]), + rng.integers(3, 10), + ) * ureg("1 / angstrom") + if y_bins is None: + y_bins = np.linspace(*sorted([rng.random(), rng.random()]), + rng.integers(3, 10)) * ureg("meV") + if metadata is None: + metadata = {"index": rng.integers(10), + "value": rng.random(), + "tag": "common"} + + spectrum = Spectrum2D(x_data=x_bins, + y_data=y_bins, + z_data=rng.random([len(x_bins) - 1, len(y_bins) - 1] + ) * ureg("millibarn / meV"), + metadata=metadata) + return spectrum + + +class TestSpectrum2DCollectionCreation: + """Unit tests for Spectrum2DCollection constructors""" + def test_init_from_numbers(self): + """Construct Spectrum2DCollection with __init__()""" + n_x = 10 + n_y = 20 + n_z = 5 + + x_data = ureg.Quantity(np.linspace(0, 100, n_x), "1 / angstrom") + y_data = ureg.Quantity(np.linspace(0, 2000, n_y), "meV") + z_data = ureg.Quantity(np.random.random((n_z, n_x, n_y)), "1 / meV") + + metadata = {"flavour": "chocolate", + "line_data": [{"index": i} for i in range(n_z)]} + + x_tick_labels = [(0, "Start"), (n_x - 1, "END")] + + spectrum = Spectrum2DCollection( + x_data, y_data, z_data, + x_tick_labels=x_tick_labels, metadata=metadata) + + for attr, data in [("x_data", x_data), + ("y_data", y_data), + ("z_data", z_data)]: + assert getattr(spectrum, attr).units == data.units + np.testing.assert_allclose(getattr(spectrum, attr).magnitude, + data.magnitude) + + assert spectrum.metadata == metadata + + def test_from_spectra(self, quartz_fuzzy_collection, quartz_fuzzy_items): + """Use alternate constructor Spectrum2DCollection.from_spectra()""" + collection = Spectrum2DCollection.from_spectra(quartz_fuzzy_items) + ref_collection = quartz_fuzzy_collection + + for attr in ("x_data", "y_data", "z_data"): + new, ref = getattr(collection, attr), getattr(ref_collection, attr) + assert new.units == ref.units + np.testing.assert_allclose(new.magnitude, ref.magnitude) + + if ref_collection.metadata is None: + assert collection.metadata is None + else: + assert ref_collection.metadata == collection.metadata + + # pylint: disable=R0913 # These fixtures are "too many arguments" + def test_from_bad_spectra( + self, + quartz_fuzzy_items, + inconsistent_x_item, + inconsistent_x_length_item, + inconsistent_x_units_item, + inconsistent_y_item): + """Spectrum2DCollection.from_spectra with inconsistent input""" + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_item] + ) + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_units_item] + ) + + with pytest.raises(ValueError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_x_length_item] + ) + + with pytest.raises(AssertionError): + Spectrum2DCollection.from_spectra( + quartz_fuzzy_items + [inconsistent_y_item] + ) + +class TestSpectrum2DCollectionFunctionality: + """Unit test indexing and methods of Spectrum2DCollection""" + + def test_indexing(self, quartz_fuzzy_collection, quartz_fuzzy_items): + """Check indexing an element, slice and iteration + + - Individual index should yield corresponding Spectrum2D + - A slice should yield a new Spectrum2DCollection + - Iteration should yield a series of Spectrum2D + + """ + item_1 = quartz_fuzzy_collection[1] + assert isinstance(item_1, Spectrum2D) + check_spectrum2d(item_1, quartz_fuzzy_items[1]) + + item_1_to_end = quartz_fuzzy_collection[1:] + assert isinstance(item_1_to_end, Spectrum2DCollection) + assert item_1_to_end != quartz_fuzzy_collection + + for item, ref in zip(item_1_to_end, quartz_fuzzy_items[1:]): + assert isinstance(item, Spectrum2D) + check_spectrum2d(item, ref) + + def test_collection_methods(self, quartz_fuzzy_collection): + """Check methods from SpectrumCollectionMixin + + These are checked thoroughly for Spectrum1DCollection, but here we + try to ensure the generic implementation works correctly in 2-D + + """ + + total = quartz_fuzzy_collection.sum() + assert isinstance(total, Spectrum2D) + assert total.z_data[3, 3] == sum(spec.z_data[3, 3] + for spec in quartz_fuzzy_collection) + + extended = quartz_fuzzy_collection + quartz_fuzzy_collection + assert len(extended) == 2 * len(quartz_fuzzy_collection) + np.testing.assert_allclose(extended.sum().z_data.magnitude, + total.z_data.magnitude * 2) + + selection = quartz_fuzzy_collection.select(direction=2, common="yes") + ref_item_2 = get_spectrum2d("quartz_fuzzy_map_2.json") + check_spectrum2d(selection.sum(), ref_item_2)