Skip to content

Commit

Permalink
Move phonon-web-json exporter to separate file; implement --title
Browse files Browse the repository at this point in the history
This is a fair chunk of code with a specific duty, and while it _can_
be used purely with QpointPhononModes as input is better characterised
as a function that uses both modes and a corresponding set of tick
labels. Move to a new "writers" module.

Use the existing CLI --title option to set title when using
euphonic-dispersion to generate JSON

Introduce a named type for XTickLabels; this helps legibility. There
is now some inconsistency between use of generic list/tuple and
classes from Typing: don't worry about it for now, this can be tidied
up afterwards.
  • Loading branch information
ajjackson committed Oct 14, 2024
1 parent e53b23d commit 67dda66
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 211 deletions.
9 changes: 6 additions & 3 deletions euphonic/cli/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from euphonic.plot import plot_1d
from euphonic.styles import base_style
from euphonic.writers.phonon_website import write_phonon_website_json
from euphonic import Spectrum1D, ForceConstants, QpointFrequencies
from .utils import (load_data_from_file, get_args, _bands_from_force_constants,
_compose_style,
_get_q_distance, matplotlib_save_or_show, _get_cli_parser,
_get_title,
_calc_modes_kwargs, _plot_label_kwargs)


Expand Down Expand Up @@ -36,9 +38,10 @@ def main(params: Optional[List[str]] = None) -> None:
x_tick_labels = None

if args.save_web_json is not None:
bands.write_phonon_website_json(output_file=args.save_web_json,
x_tick_labels=x_tick_labels)

write_phonon_website_json(modes=bands,
name=_get_title(args.filename, args.title),
output_file=args.save_web_json,
x_tick_labels=x_tick_labels)

bands.frequencies_unit = args.energy_unit

Expand Down
6 changes: 6 additions & 0 deletions euphonic/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,3 +934,9 @@ def _compose_style(

style.append(explicit_args)
return style

def _get_title(filename: str, title: str = '') -> str:
"""Get a plot title: either user-provided string, or from filename"""
if title:
return title
return pathlib.Path(filename).stem
202 changes: 3 additions & 199 deletions euphonic/qpoint_phonon_modes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
"""Data container (with methods) for phonon frequencies and eigenvectors"""

import math
from typing import Any, Dict, Optional, Union, Type, TypedDict, TypeVar
from typing import Any, Dict, Optional, Union, Type, TypeVar
from collections.abc import Mapping

import numpy as np
Expand All @@ -15,35 +16,6 @@
StructureFactor, Spectrum1DCollection)


complex_pair = tuple[float, float]


class PhononWebsiteData(TypedDict):
"""Data container for export to phonon visualisation website
Specification: https://henriquemiranda.github.io/phononwebsite/index.html
line_breaks are currently not implemented
"""

name: str
natoms: int
lattice: list[list[float]]
atom_types: list[str]
atom_numbers: list[int]
formula: str
repetitions: list[int]
atom_pos_car: list[list[float]]
atom_pos_red: list[list[float]]
highsym_qpts: list[tuple[int, str]]
qpoints: list[list[float]]
distances: list[float] # Cumulative distance from first q-point
eigenvalues: list[float] # in cm-1
vectors: list[list[list[tuple[complex_pair, complex_pair, complex_pair]]]]
line_breaks: list[tuple[int, int]]


class QpointPhononModes(QpointFrequencies):
"""
A class to read and store vibrational data from model (e.g. CASTEP)
Expand Down Expand Up @@ -746,171 +718,3 @@ def from_phonopy(cls: Type[T], path: str = '.',
path=path, phonon_name=phonon_name, phonon_format=phonon_format,
summary_name=summary_name)
return cls.from_dict(data)

def write_phonon_website_json(
self,
output_file: str = "phonons.json",
name: str = "Euphonic export",
x_tick_labels: list[tuple[int, str]] | None = None,
) -> None:

"""Dump to .json for use with phonon website visualiser
Use with javascript application at
https://henriquemiranda.github.io/phononwebsite
Parameters
----------
output_file
Path to output file
name
Set "name" metadata
x_tick_labels
index and label for high symmetry labels (if known)
"""

with open(output_file, 'w') as fd:
json.dump(self._to_phonon_website_dict(name=name,
x_tick_labels=x_tick_labels),
fd)

@staticmethod
def _crystal_website_data(crystal: Crystal) -> dict[str, Any]:
elements = [
'_', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na',
'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V',
'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se',
'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh',
'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba',
'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho',
'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt',
'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac',
'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm',
'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg',
'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og']

def get_z(symbol: str) -> int:
try:
return elements.index(symbol)
except ValueError: # Symbol not found
return 0

def symbols_to_formula(symbols: list[str]) -> str:
from collections import Counter
symbol_counts = Counter(symbols)

return "".join(f"{symbol}{symbol_counts[symbol]}"
for symbol in sorted(symbol_counts))

return dict(
natoms = len(crystal.atom_type),
lattice = crystal.cell_vectors.to("angstrom").magnitude.tolist(),
atom_types = crystal.atom_type.tolist(),
atom_numbers = list(map(get_z, crystal.atom_type)),
formula = symbols_to_formula(crystal.atom_type),
atom_pos_red = crystal.atom_r.tolist(),
atom_pos_car = (crystal.atom_r @ crystal.cell_vectors).to("angstrom").magnitude.tolist()
)

@staticmethod
def _remove_breaks(distances: np.ndarray, btol: float = 10.) -> list[int]:
"""Collapse large breaks in cumulative-distance array
These correspond to discontinuous regions of the x-axis: in euphonic
plotting this is usually handled by splitting the spectrum and plotting
to new axes, but Phonon Website does not handle this.
Data is modified in-place
A list of identified breakpoints is returned
"""
diff = np.diff(distances)
median = np.median(diff)
breakpoints = np.where((diff / median) > btol)[0] + 1

for breakpoint in reversed(breakpoints):
distances[breakpoint:] -= (distances[breakpoint] - distances[breakpoint - 1])

return breakpoints.tolist()

@staticmethod
def _expand_duplicates(distances: np.ndarray, pad_fraction = 0.001) -> list[int]:
diff = np.diff(distances)
pad = np.median(diff) * pad_fraction

duplicates = np.where(diff == 0.)[0] + 1
for duplicate in reversed(duplicates):
distances[duplicate:] += (distances[duplicate - 1] - distances[duplicate] + pad)
return duplicates.tolist()

@staticmethod
def _combine_neighbouring_labels(x_tick_labels: list[tuple[int, str]]
) -> list[tuple[int, str]]:
"""Merge neighbouring labels in x_tick_label list
If labels are the same, only keep one.
If labels are different, join with |
e.g.::
[(1, "X"), (2, "X"), (4, "A"), (7, "Y"), (8, "Z")]
-->
[(1, "X"), (4, "A"), (7, "Y|Z")]
"""
labels = dict(x_tick_labels)

for index in sorted(labels):
if index - 1 in labels:
if labels.get(index - 1) != labels.get(index):
labels[index - 1] = f"{labels[index - 1]}|{labels[index]}"
del labels[index]
return list(sorted(labels.items()))

def _to_phonon_website_dict(self,
name: str = 'Euphonic export',
repetitions: tuple[int, int, int] = (2, 2, 2),
x_tick_labels: list[tuple[int, str]] | None = None,
) -> PhononWebsiteData:
from itertools import pairwise
from euphonic.util import _calc_abscissa, get_qpoint_labels

qpts = self.qpts
eigenvectors = self.eigenvectors

abscissa = _calc_abscissa(self.crystal.reciprocal_cell(), qpts)

duplicates = self._expand_duplicates(abscissa)
breakpoints = self._remove_breaks(abscissa)

breakpoints = sorted(set([0] + duplicates + breakpoints + [len(abscissa)]))
line_breaks = [(start, end) for start, end in pairwise(breakpoints)]

if x_tick_labels is None:
x_tick_labels = get_qpoint_labels(qpts,
cell=self.crystal.to_spglib_cell())

x_tick_labels = [(int(key) + 1, str(value)) for key, value in x_tick_labels]
x_tick_labels = self._combine_neighbouring_labels(x_tick_labels)

vectors = eigenvectors / np.sqrt(self.crystal.atom_mass)[None, None, :, None]
vectors = vectors.view(float).reshape(*eigenvectors.shape[:-1], 3, 2)

dat = PhononWebsiteData(
name=name,
**self._crystal_website_data(self.crystal),
highsym_qpts=x_tick_labels,
distances=abscissa.magnitude.tolist(),
qpoints=self.qpts.tolist(),
eigenvalues=self.frequencies.to("1/cm").magnitude.tolist(),
vectors=vectors.tolist(),
repetitions=repetitions,
line_breaks=line_breaks
)

return dat
18 changes: 9 additions & 9 deletions euphonic/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@


CallableQuantity = Callable[[Quantity], Quantity]
XTickLabels = list[tuple[int, str]]


class Spectrum(ABC):
Expand Down Expand Up @@ -84,11 +85,11 @@ def copy(self: T) -> T:
...

@property
def x_tick_labels(self) -> List[Tuple[int, str]]:
def x_tick_labels(self) -> XTickLabels:

Check notice on line 88 in euphonic/spectra.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

euphonic/spectra.py#L88

Missing function or method docstring
return self._x_tick_labels

@x_tick_labels.setter
def x_tick_labels(self, value: Sequence[Tuple[int, str]]) -> None:
def x_tick_labels(self, value: XTickLabels) -> None:
err_msg = ('x_tick_labels should be of type '
'Sequence[Tuple[int, str]] e.g. '
'[(0, "label1"), (5, "label2")]')
Expand Down Expand Up @@ -171,10 +172,9 @@ def _ranges_from_indices(indices: Union[Sequence[int], np.ndarray]
return ranges

@staticmethod
def _cut_x_ticks(x_tick_labels: Union[Sequence[Tuple[int, str]], None],
def _cut_x_ticks(x_tick_labels: XTickLabels | None,
x0: int,
x1: Union[int, None]) -> Union[List[Tuple[int, str]],
None]:
x1: int | None) -> XTickLabels | None:
"""Crop and shift x labels to new x range"""
if x_tick_labels is None:
return None
Expand Down Expand Up @@ -419,7 +419,7 @@ class Spectrum1D(Spectrum):
T = TypeVar('T', bound='Spectrum1D')

def __init__(self, x_data: Quantity, y_data: Quantity,
x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None,
x_tick_labels: Optional[XTickLabels] = None,
metadata: Optional[Dict[str, Union[int, str]]] = None
) -> None:
"""
Expand Down Expand Up @@ -1147,7 +1147,7 @@ class Spectrum1DCollection(SpectrumCollectionMixin,

def __init__(
self, x_data: Quantity, y_data: Quantity,
x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None,
x_tick_labels: Optional[XTickLabels] = None,
metadata: Optional[Dict[str, Union[str, int, LineData]]] = None
) -> None:
"""
Expand Down Expand Up @@ -1451,7 +1451,7 @@ class Spectrum2D(Spectrum):

def __init__(self, x_data: Quantity, y_data: Quantity,
z_data: Quantity,
x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None,
x_tick_labels: Optional[XTickLabels] = None,
metadata: Optional[Metadata] = None
) -> None:
"""
Expand Down Expand Up @@ -1875,7 +1875,7 @@ class Spectrum2DCollection(SpectrumCollectionMixin,

def __init__(
self, x_data: Quantity, y_data: Quantity, z_data: Quantity,
x_tick_labels: Optional[Sequence[Tuple[int, str]]] = None,
x_tick_labels: Optional[XTickLabels] = None,
metadata: Optional[Metadata] = None
) -> None:
_check_constructor_inputs(
Expand Down
Empty file added euphonic/writers/__init__.py
Empty file.
Loading

0 comments on commit 67dda66

Please sign in to comment.