From 67dda664394db185c35b99d7e99a9053341a5265 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 14 Oct 2024 15:26:14 +0100 Subject: [PATCH] Move phonon-web-json exporter to separate file; implement --title This is a fair chunk of code with a specific duty, and while it _can_ be used purely with QpointPhononModes as input is better characterised as a function that uses both modes and a corresponding set of tick labels. Move to a new "writers" module. Use the existing CLI --title option to set title when using euphonic-dispersion to generate JSON Introduce a named type for XTickLabels; this helps legibility. There is now some inconsistency between use of generic list/tuple and classes from Typing: don't worry about it for now, this can be tidied up afterwards. --- euphonic/cli/dispersion.py | 9 +- euphonic/cli/utils.py | 6 + euphonic/qpoint_phonon_modes.py | 202 +-------------------------- euphonic/spectra.py | 18 +-- euphonic/writers/__init__.py | 0 euphonic/writers/phonon_website.py | 213 +++++++++++++++++++++++++++++ 6 files changed, 237 insertions(+), 211 deletions(-) create mode 100644 euphonic/writers/__init__.py create mode 100644 euphonic/writers/phonon_website.py diff --git a/euphonic/cli/dispersion.py b/euphonic/cli/dispersion.py index 96894ddd8..4194e6c11 100644 --- a/euphonic/cli/dispersion.py +++ b/euphonic/cli/dispersion.py @@ -5,10 +5,12 @@ from euphonic.plot import plot_1d from euphonic.styles import base_style +from euphonic.writers.phonon_website import write_phonon_website_json from euphonic import Spectrum1D, ForceConstants, QpointFrequencies from .utils import (load_data_from_file, get_args, _bands_from_force_constants, _compose_style, _get_q_distance, matplotlib_save_or_show, _get_cli_parser, + _get_title, _calc_modes_kwargs, _plot_label_kwargs) @@ -36,9 +38,10 @@ def main(params: Optional[List[str]] = None) -> None: x_tick_labels = None if args.save_web_json is not None: - bands.write_phonon_website_json(output_file=args.save_web_json, - x_tick_labels=x_tick_labels) - + write_phonon_website_json(modes=bands, + name=_get_title(args.filename, args.title), + output_file=args.save_web_json, + x_tick_labels=x_tick_labels) bands.frequencies_unit = args.energy_unit diff --git a/euphonic/cli/utils.py b/euphonic/cli/utils.py index 4dd3628ec..0d96555d2 100644 --- a/euphonic/cli/utils.py +++ b/euphonic/cli/utils.py @@ -934,3 +934,9 @@ def _compose_style( style.append(explicit_args) return style + +def _get_title(filename: str, title: str = '') -> str: + """Get a plot title: either user-provided string, or from filename""" + if title: + return title + return pathlib.Path(filename).stem diff --git a/euphonic/qpoint_phonon_modes.py b/euphonic/qpoint_phonon_modes.py index e4eb2ccdb..79eebeaff 100644 --- a/euphonic/qpoint_phonon_modes.py +++ b/euphonic/qpoint_phonon_modes.py @@ -1,6 +1,7 @@ -import json +"""Data container (with methods) for phonon frequencies and eigenvectors""" + import math -from typing import Any, Dict, Optional, Union, Type, TypedDict, TypeVar +from typing import Any, Dict, Optional, Union, Type, TypeVar from collections.abc import Mapping import numpy as np @@ -15,35 +16,6 @@ StructureFactor, Spectrum1DCollection) -complex_pair = tuple[float, float] - - -class PhononWebsiteData(TypedDict): - """Data container for export to phonon visualisation website - - Specification: https://henriquemiranda.github.io/phononwebsite/index.html - - line_breaks are currently not implemented - - """ - - name: str - natoms: int - lattice: list[list[float]] - atom_types: list[str] - atom_numbers: list[int] - formula: str - repetitions: list[int] - atom_pos_car: list[list[float]] - atom_pos_red: list[list[float]] - highsym_qpts: list[tuple[int, str]] - qpoints: list[list[float]] - distances: list[float] # Cumulative distance from first q-point - eigenvalues: list[float] # in cm-1 - vectors: list[list[list[tuple[complex_pair, complex_pair, complex_pair]]]] - line_breaks: list[tuple[int, int]] - - class QpointPhononModes(QpointFrequencies): """ A class to read and store vibrational data from model (e.g. CASTEP) @@ -746,171 +718,3 @@ def from_phonopy(cls: Type[T], path: str = '.', path=path, phonon_name=phonon_name, phonon_format=phonon_format, summary_name=summary_name) return cls.from_dict(data) - - def write_phonon_website_json( - self, - output_file: str = "phonons.json", - name: str = "Euphonic export", - x_tick_labels: list[tuple[int, str]] | None = None, - ) -> None: - - """Dump to .json for use with phonon website visualiser - - Use with javascript application at - https://henriquemiranda.github.io/phononwebsite - - Parameters - ---------- - output_file - Path to output file - name - Set "name" metadata - x_tick_labels - index and label for high symmetry labels (if known) - - """ - - with open(output_file, 'w') as fd: - json.dump(self._to_phonon_website_dict(name=name, - x_tick_labels=x_tick_labels), - fd) - - @staticmethod - def _crystal_website_data(crystal: Crystal) -> dict[str, Any]: - elements = [ - '_', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', - 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', - 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', - 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', - 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', - 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', - 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', - 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', - 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', - 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', - 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'] - - def get_z(symbol: str) -> int: - try: - return elements.index(symbol) - except ValueError: # Symbol not found - return 0 - - def symbols_to_formula(symbols: list[str]) -> str: - from collections import Counter - symbol_counts = Counter(symbols) - - return "".join(f"{symbol}{symbol_counts[symbol]}" - for symbol in sorted(symbol_counts)) - - return dict( - natoms = len(crystal.atom_type), - lattice = crystal.cell_vectors.to("angstrom").magnitude.tolist(), - atom_types = crystal.atom_type.tolist(), - atom_numbers = list(map(get_z, crystal.atom_type)), - formula = symbols_to_formula(crystal.atom_type), - atom_pos_red = crystal.atom_r.tolist(), - atom_pos_car = (crystal.atom_r @ crystal.cell_vectors).to("angstrom").magnitude.tolist() - ) - - @staticmethod - def _remove_breaks(distances: np.ndarray, btol: float = 10.) -> list[int]: - """Collapse large breaks in cumulative-distance array - - These correspond to discontinuous regions of the x-axis: in euphonic - plotting this is usually handled by splitting the spectrum and plotting - to new axes, but Phonon Website does not handle this. - - Data is modified in-place - - A list of identified breakpoints is returned - - """ - diff = np.diff(distances) - median = np.median(diff) - breakpoints = np.where((diff / median) > btol)[0] + 1 - - for breakpoint in reversed(breakpoints): - distances[breakpoint:] -= (distances[breakpoint] - distances[breakpoint - 1]) - - return breakpoints.tolist() - - @staticmethod - def _expand_duplicates(distances: np.ndarray, pad_fraction = 0.001) -> list[int]: - diff = np.diff(distances) - pad = np.median(diff) * pad_fraction - - duplicates = np.where(diff == 0.)[0] + 1 - for duplicate in reversed(duplicates): - distances[duplicate:] += (distances[duplicate - 1] - distances[duplicate] + pad) - return duplicates.tolist() - - @staticmethod - def _combine_neighbouring_labels(x_tick_labels: list[tuple[int, str]] - ) -> list[tuple[int, str]]: - """Merge neighbouring labels in x_tick_label list - - If labels are the same, only keep one. - - If labels are different, join with | - - e.g.:: - - [(1, "X"), (2, "X"), (4, "A"), (7, "Y"), (8, "Z")] - - --> - - [(1, "X"), (4, "A"), (7, "Y|Z")] - - """ - labels = dict(x_tick_labels) - - for index in sorted(labels): - if index - 1 in labels: - if labels.get(index - 1) != labels.get(index): - labels[index - 1] = f"{labels[index - 1]}|{labels[index]}" - del labels[index] - return list(sorted(labels.items())) - - def _to_phonon_website_dict(self, - name: str = 'Euphonic export', - repetitions: tuple[int, int, int] = (2, 2, 2), - x_tick_labels: list[tuple[int, str]] | None = None, - ) -> PhononWebsiteData: - from itertools import pairwise - from euphonic.util import _calc_abscissa, get_qpoint_labels - - qpts = self.qpts - eigenvectors = self.eigenvectors - - abscissa = _calc_abscissa(self.crystal.reciprocal_cell(), qpts) - - duplicates = self._expand_duplicates(abscissa) - breakpoints = self._remove_breaks(abscissa) - - breakpoints = sorted(set([0] + duplicates + breakpoints + [len(abscissa)])) - line_breaks = [(start, end) for start, end in pairwise(breakpoints)] - - if x_tick_labels is None: - x_tick_labels = get_qpoint_labels(qpts, - cell=self.crystal.to_spglib_cell()) - - x_tick_labels = [(int(key) + 1, str(value)) for key, value in x_tick_labels] - x_tick_labels = self._combine_neighbouring_labels(x_tick_labels) - - vectors = eigenvectors / np.sqrt(self.crystal.atom_mass)[None, None, :, None] - vectors = vectors.view(float).reshape(*eigenvectors.shape[:-1], 3, 2) - - dat = PhononWebsiteData( - name=name, - **self._crystal_website_data(self.crystal), - highsym_qpts=x_tick_labels, - distances=abscissa.magnitude.tolist(), - qpoints=self.qpts.tolist(), - eigenvalues=self.frequencies.to("1/cm").magnitude.tolist(), - vectors=vectors.tolist(), - repetitions=repetitions, - line_breaks=line_breaks - ) - - return dat diff --git a/euphonic/spectra.py b/euphonic/spectra.py index e26092799..4ec2c1bb8 100644 --- a/euphonic/spectra.py +++ b/euphonic/spectra.py @@ -31,6 +31,7 @@ CallableQuantity = Callable[[Quantity], Quantity] +XTickLabels = list[tuple[int, str]] class Spectrum(ABC): @@ -84,11 +85,11 @@ def copy(self: T) -> T: ... @property - def x_tick_labels(self) -> List[Tuple[int, str]]: + def x_tick_labels(self) -> XTickLabels: return self._x_tick_labels @x_tick_labels.setter - def x_tick_labels(self, value: Sequence[Tuple[int, str]]) -> None: + def x_tick_labels(self, value: XTickLabels) -> None: err_msg = ('x_tick_labels should be of type ' 'Sequence[Tuple[int, str]] e.g. ' '[(0, "label1"), (5, "label2")]') @@ -171,10 +172,9 @@ def _ranges_from_indices(indices: Union[Sequence[int], np.ndarray] return ranges @staticmethod - def _cut_x_ticks(x_tick_labels: Union[Sequence[Tuple[int, str]], None], + def _cut_x_ticks(x_tick_labels: XTickLabels | None, x0: int, - x1: Union[int, None]) -> Union[List[Tuple[int, str]], - None]: + x1: int | None) -> XTickLabels | None: """Crop and shift x labels to new x range""" if x_tick_labels is None: return None @@ -419,7 +419,7 @@ class Spectrum1D(Spectrum): T = TypeVar('T', bound='Spectrum1D') def __init__(self, x_data: Quantity, y_data: Quantity, - x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, + x_tick_labels: Optional[XTickLabels] = None, metadata: Optional[Dict[str, Union[int, str]]] = None ) -> None: """ @@ -1147,7 +1147,7 @@ class Spectrum1DCollection(SpectrumCollectionMixin, def __init__( self, x_data: Quantity, y_data: Quantity, - x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, + x_tick_labels: Optional[XTickLabels] = None, metadata: Optional[Dict[str, Union[str, int, LineData]]] = None ) -> None: """ @@ -1451,7 +1451,7 @@ class Spectrum2D(Spectrum): def __init__(self, x_data: Quantity, y_data: Quantity, z_data: Quantity, - x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, + x_tick_labels: Optional[XTickLabels] = None, metadata: Optional[Metadata] = None ) -> None: """ @@ -1875,7 +1875,7 @@ class Spectrum2DCollection(SpectrumCollectionMixin, def __init__( self, x_data: Quantity, y_data: Quantity, z_data: Quantity, - x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None, + x_tick_labels: Optional[XTickLabels] = None, metadata: Optional[Metadata] = None ) -> None: _check_constructor_inputs( diff --git a/euphonic/writers/__init__.py b/euphonic/writers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/euphonic/writers/phonon_website.py b/euphonic/writers/phonon_website.py new file mode 100644 index 000000000..ba7d02416 --- /dev/null +++ b/euphonic/writers/phonon_website.py @@ -0,0 +1,213 @@ +"""Export to JSON for phonon visualisation website""" + +from collections import Counter +from itertools import pairwise +import json +from typing import Any, TypedDict + +import numpy as np + +from euphonic.crystal import Crystal +from euphonic.qpoint_phonon_modes import QpointPhononModes +from euphonic.spectra import XTickLabels +from euphonic.util import _calc_abscissa, get_qpoint_labels + + +ComplexPair = tuple[float, float] + + +class PhononWebsiteData(TypedDict): + """Data container for export to phonon visualisation website + + Specification: https://henriquemiranda.github.io/phononwebsite/index.html + + line_breaks are currently not implemented + + """ + name: str + natoms: int + lattice: list[list[float]] + atom_types: list[str] + atom_numbers: list[int] + formula: str + repetitions: list[int] + atom_pos_car: list[list[float]] + atom_pos_red: list[list[float]] + highsym_qpts: list[tuple[int, str]] + qpoints: list[list[float]] + distances: list[float] # Cumulative distance from first q-point + eigenvalues: list[float] # in cm-1 + vectors: list[list[list[tuple[ComplexPair, ComplexPair, ComplexPair]]]] + line_breaks: list[tuple[int, int]] + + +def write_phonon_website_json( + modes: QpointPhononModes, + output_file: str = "phonons.json", + name: str = "Euphonic export", + x_tick_labels: XTickLabels | None = None, +) -> None: + + """Dump to .json for use with phonon website visualiser + + Use with javascript application at + https://henriquemiranda.github.io/phononwebsite + + Parameters + ---------- + output_file + Path to output file + name + Set "name" metadata, to be used as figure title + x_tick_labels + index and label for high symmetry labels (if known) + + """ + + with open(output_file, 'w') as fd: + json.dump(_modes_to_phonon_website_dict(modes=modes, + name=name, + x_tick_labels=x_tick_labels), + fd) + + +def _crystal_website_data(crystal: Crystal) -> dict[str, Any]: + elements = [ + '_', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', + 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', + 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', + 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', + 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', + 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', + 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', + 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', + 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', + 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', + 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'] + + def get_z(symbol: str) -> int: + try: + return elements.index(symbol) + except ValueError: # Symbol not found + return 0 + + def symbols_to_formula(symbols: list[str]) -> str: + symbol_counts = Counter(symbols) + + return "".join(f"{symbol}{symbol_counts[symbol]}" + for symbol in sorted(symbol_counts)) + + return dict( + natoms=len(crystal.atom_type), + lattice=crystal.cell_vectors.to("angstrom").magnitude.tolist(), + atom_types=crystal.atom_type.tolist(), + atom_numbers=list(map(get_z, crystal.atom_type)), + formula=symbols_to_formula(crystal.atom_type), + atom_pos_red=crystal.atom_r.tolist(), + atom_pos_car=(crystal.atom_r @ crystal.cell_vectors + ).to("angstrom").magnitude.tolist() + ) + + +def _remove_breaks(distances: np.ndarray, btol: float = 10.) -> list[int]: + """Collapse large breaks in cumulative-distance array + + These correspond to discontinuous regions of the x-axis: in euphonic + plotting this is usually handled by splitting the spectrum and plotting + to new axes, but Phonon Website does not handle this. + + Data is modified in-place + + A list of identified breakpoints is returned + + """ + diff = np.diff(distances) + median = np.median(diff) + breakpoints = np.where((diff / median) > btol)[0] + 1 + + for breakpoint in reversed(breakpoints): + distances[breakpoint:] -= (distances[breakpoint] + - distances[breakpoint - 1]) + + return breakpoints.tolist() + + +def _expand_duplicates( + distances: np.ndarray, pad_fraction: float = 0.001) -> list[int]: + diff = np.diff(distances) + pad = np.median(diff) * pad_fraction + + duplicates = np.where(diff == 0.)[0] + 1 + for duplicate in reversed(duplicates): + distances[duplicate:] += ( + distances[duplicate - 1] - distances[duplicate] + pad) + return duplicates.tolist() + + +def _combine_neighbouring_labels(x_tick_labels: XTickLabels) -> XTickLabels: + """Merge neighbouring labels in x_tick_label list + + If labels are the same, only keep one. + + If labels are different, join with | + + e.g.:: + + [(1, "X"), (2, "X"), (4, "A"), (7, "Y"), (8, "Z")] + + --> + + [(1, "X"), (4, "A"), (7, "Y|Z")] + + """ + labels = dict(x_tick_labels) + + for index in sorted(labels): + if index - 1 in labels: + if labels.get(index - 1) != labels.get(index): + labels[index - 1] = f"{labels[index - 1]}|{labels[index]}" + del labels[index] + return list(sorted(labels.items())) + + +def _modes_to_phonon_website_dict(modes: QpointPhononModes, + name: str = 'Euphonic export', + repetitions: tuple[int, int, int] = (2, 2, 2), + x_tick_labels: XTickLabels | None = None, + ) -> PhononWebsiteData: + qpts = modes.qpts + eigenvectors = modes.eigenvectors + + abscissa = _calc_abscissa(modes.crystal.reciprocal_cell(), qpts) + + duplicates = _expand_duplicates(abscissa) + breakpoints = _remove_breaks(abscissa) + + breakpoints = sorted(set([0] + duplicates + breakpoints + [len(abscissa)])) + line_breaks = [(start, end) for start, end in pairwise(breakpoints)] + + if x_tick_labels is None: + x_tick_labels = get_qpoint_labels(qpts, + cell=modes.crystal.to_spglib_cell()) + + x_tick_labels = [(int(key) + 1, str(value)) + for key, value in x_tick_labels] + x_tick_labels = _combine_neighbouring_labels(x_tick_labels) + + mass_weights = 1 / np.sqrt(modes.crystal.atom_mass) + vectors = eigenvectors * mass_weights[None, None, :, None] + vectors = vectors.view(float).reshape(*eigenvectors.shape[:-1], 3, 2) + + dat = PhononWebsiteData( + name=name, + **_crystal_website_data(modes.crystal), + highsym_qpts=x_tick_labels, + distances=abscissa.magnitude.tolist(), + qpoints=modes.qpts.tolist(), + eigenvalues=modes.frequencies.to("1/cm").magnitude.tolist(), + vectors=vectors.tolist(), + repetitions=repetitions, + line_breaks=line_breaks + ) + + return dat