Skip to content

Commit

Permalink
Review feedback ( @oerc0122 )
Browse files Browse the repository at this point in the history
  • Loading branch information
ajjackson committed Oct 15, 2024
1 parent 8d8e415 commit 0112dea
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions euphonic/writers/phonon_website.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Export to JSON for phonon visualisation website"""
"""Export to JSON for phonon visualisation website:
- https://henriquemiranda.github.io/phononwebsite/index.html
"""

from collections import Counter
from itertools import pairwise
import json
from pathlib import Path
from typing import Any, TypedDict

import numpy as np
Expand Down Expand Up @@ -43,7 +47,7 @@ class PhononWebsiteData(TypedDict):

def write_phonon_website_json(
modes: QpointPhononModes,
output_file: str = "phonons.json",
output_file: str | Path = "phonons.json",
name: str = "Euphonic export",
x_tick_labels: XTickLabels | None = None,
) -> None:
Expand Down Expand Up @@ -126,8 +130,7 @@ def _remove_breaks(distances: np.ndarray, btol: float = 10.) -> list[int]:
breakpoints = np.where((diff / median) > btol)[0] + 1

for breakpoint in reversed(breakpoints):

Check warning on line 132 in euphonic/writers/phonon_website.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

euphonic/writers/phonon_website.py#L132

Redefining built-in 'breakpoint'
distances[breakpoint:] -= (distances[breakpoint]
- distances[breakpoint - 1])
distances[breakpoint:] -= diff[breakpoint - 1]

return breakpoints.tolist()

Expand All @@ -140,7 +143,7 @@ def _find_duplicates(distances: np.ndarray) -> list[int]:
"""

diff = np.diff(distances)
duplicates = np.where(diff == 0.)[0] + 1
duplicates = np.where(np.isclose(diff, 0.))[0] + 1

return duplicates.tolist()

Expand Down Expand Up @@ -189,7 +192,7 @@ def _modes_to_phonon_website_dict(modes: QpointPhononModes,
breakpoints = _remove_breaks(abscissa)

breakpoints = sorted(set([0] + duplicates + breakpoints + [len(abscissa)]))
line_breaks = [(start, end) for start, end in pairwise(breakpoints)]
line_breaks = list(pairwise(breakpoints))

if x_tick_labels is None:
x_tick_labels = get_qpoint_labels(qpts,
Expand Down

0 comments on commit 0112dea

Please sign in to comment.