diff --git a/euphonic/writers/phonon_website.py b/euphonic/writers/phonon_website.py index 89293eec..50741343 100644 --- a/euphonic/writers/phonon_website.py +++ b/euphonic/writers/phonon_website.py @@ -1,8 +1,12 @@ -"""Export to JSON for phonon visualisation website""" +"""Export to JSON for phonon visualisation website: + +- https://henriquemiranda.github.io/phononwebsite/index.html +""" from collections import Counter from itertools import pairwise import json +from pathlib import Path from typing import Any, TypedDict import numpy as np @@ -43,7 +47,7 @@ class PhononWebsiteData(TypedDict): def write_phonon_website_json( modes: QpointPhononModes, - output_file: str = "phonons.json", + output_file: str | Path = "phonons.json", name: str = "Euphonic export", x_tick_labels: XTickLabels | None = None, ) -> None: @@ -126,8 +130,7 @@ def _remove_breaks(distances: np.ndarray, btol: float = 10.) -> list[int]: breakpoints = np.where((diff / median) > btol)[0] + 1 for breakpoint in reversed(breakpoints): - distances[breakpoint:] -= (distances[breakpoint] - - distances[breakpoint - 1]) + distances[breakpoint:] -= diff[breakpoint - 1] return breakpoints.tolist() @@ -140,7 +143,7 @@ def _find_duplicates(distances: np.ndarray) -> list[int]: """ diff = np.diff(distances) - duplicates = np.where(diff == 0.)[0] + 1 + duplicates = np.where(np.isclose(diff, 0.))[0] + 1 return duplicates.tolist() @@ -189,7 +192,7 @@ def _modes_to_phonon_website_dict(modes: QpointPhononModes, breakpoints = _remove_breaks(abscissa) breakpoints = sorted(set([0] + duplicates + breakpoints + [len(abscissa)])) - line_breaks = [(start, end) for start, end in pairwise(breakpoints)] + line_breaks = list(pairwise(breakpoints)) if x_tick_labels is None: x_tick_labels = get_qpoint_labels(qpts,