Skip to content

Commit

Permalink
Improve Plotting Experience (#650)
Browse files Browse the repository at this point in the history
* unify figure types

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Sep 16, 2024
1 parent 84cb793 commit 7dd970d
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 28 deletions.
11 changes: 6 additions & 5 deletions app/src/components/plotting.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export const Plotting = ({
availablePlots={availablePlots}
setAvailablePlots={setAvailablePlots}
plotData={plotData}
// current plot data
setDisplayedCards={setDisplayedCards}
setStep={setStep}
setSelectedFrames={setSelectedFrames}
Expand Down Expand Up @@ -125,11 +126,11 @@ const PlotsCard = ({

// once plot data updates and selectedOption == "" set selectedOption to first available plot
// TODO: this part is still very buggy!
useEffect(() => {
if (availablePlots.length > 0 && selectedOption === "") {
setSelectedOption(availablePlots[0]);
}
}, [availablePlots, selectedOption]);
// useEffect(() => {
// if (availablePlots.length > 0 && selectedOption === "") {
// setSelectedOption(availablePlots[0]);
// }
// }, [availablePlots, selectedOption]);

useEffect(() => {
if (plotData[selectedOption]) {
Expand Down
10 changes: 4 additions & 6 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import plotly.io as pio

from zndraw import ZnDraw


Expand All @@ -11,7 +9,7 @@ def test_run_analysis_distance(server, s22_energy_forces):

vis.socket.emit("analysis:run", {"method": {"discriminator": "Distance"}})
vis.socket.sleep(7)
fig = pio.from_json(vis.figures["Distance"])
fig = vis.figures["Distance"]
# assert that the x-axis label is "step"
assert fig.layout.xaxis.title.text == "step"

Expand All @@ -24,7 +22,7 @@ def test_run_analysis_Properties1D(server, s22_energy_forces):
"analysis:run", {"method": {"discriminator": "Properties1D", "value": "energy"}}
)
vis.socket.sleep(7)
fig = pio.from_json(vis.figures["Properties1D"])
fig = vis.figures["Properties1D"]
# assert that the x-axis label is "step"
assert fig.layout.xaxis.title.text == "step"
assert fig.layout.yaxis.title.text == "energy"
Expand All @@ -46,7 +44,7 @@ def test_run_analysis_Properties2D(server, s22_energy_forces):
},
)
vis.socket.sleep(7)
fig = pio.from_json(vis.figures["Properties2D"])
fig = vis.figures["Properties2D"]
# assert that the x-axis label is "step"
assert fig.layout.yaxis.title.text == "step"
assert fig.layout.xaxis.title.text == "energy"
Expand All @@ -60,6 +58,6 @@ def test_run_analysis_DihedralAngle(server, s22_energy_forces):

vis.socket.emit("analysis:run", {"method": {"discriminator": "DihedralAngle"}})
vis.socket.sleep(7)
fig = pio.from_json(vis.figures["DihedralAngle"])
fig = vis.figures["DihedralAngle"]
# assert that the x-axis label is "step"
assert fig.layout.xaxis.title.text == "step"
10 changes: 5 additions & 5 deletions zndraw/analyse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def run(self, vis):
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"DihedralAngle": fig.to_json()}
vis.figures = vis.figures | {"DihedralAngle": fig}


class Distance(Extension):
Expand Down Expand Up @@ -111,7 +111,7 @@ def run(self, vis):
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"Distance": fig.to_json()}
vis.figures = vis.figures | {"Distance": fig}


class Properties2D(Extension):
Expand Down Expand Up @@ -182,7 +182,7 @@ def run(self, vis):
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"Properties2D": fig.to_json()}
vis.figures = vis.figures | {"Properties2D": fig}


class ForceCorrelation(Extension):
Expand Down Expand Up @@ -234,7 +234,7 @@ def run(self, vis):
fig = px.scatter(df, x=self.x_data, y=self.y_data, render_mode="svg")
fig.update_traces(customdata=np.stack([meta_step, meta_idx], axis=-1))

vis.figures = vis.figures | {"ForceCorrelation": fig.to_json()}
vis.figures = vis.figures | {"ForceCorrelation": fig}


class Properties1D(Extension):
Expand Down Expand Up @@ -294,7 +294,7 @@ def run(self, vis):
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"Properties1D": fig.to_json()}
vis.figures = vis.figures | {"Properties1D": fig}


methods = t.Union[
Expand Down
4 changes: 2 additions & 2 deletions zndraw/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from zndraw.base import FileIO
from zndraw.bonds import ASEComputeBonds
from zndraw.exceptions import RoomLockedError
from zndraw.utils import ASEConverter, load_plots_to_json
from zndraw.utils import ASEConverter, load_plots_to_dict

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -285,7 +285,7 @@ def read_plots(paths: list[str], remote, rev) -> None:
token="default",
)

vis.figures = load_plots_to_json(paths, remote, rev)
vis.figures = load_plots_to_dict(paths, remote, rev)

vis.socket.sleep(1)
vis.socket.disconnect()
4 changes: 2 additions & 2 deletions zndraw/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from zndraw import ZnDraw

from .tasks import FileIO, get_generator_from_filename
from .utils import load_plots_to_json
from .utils import load_plots_to_dict

cli = typer.Typer()

Expand Down Expand Up @@ -50,4 +50,4 @@ def upload(
vis.extend(frames[1:])

figures = vis.figures
vis.figures = load_plots_to_json(plots, fileio.remote, fileio.rev) | figures
vis.figures = load_plots_to_dict(plots, fileio.remote, fileio.rev) | figures
11 changes: 7 additions & 4 deletions zndraw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ase
import datamodel_code_generator
import numpy as np
import plotly.graph_objects as go
import plotly.graph_objs
import socketio.exceptions
import znjson
Expand Down Expand Up @@ -399,7 +400,9 @@ def get_plots_from_zntrack(path: str, remote: str | None, rev: str | None):
) from err


def load_plots_to_json(paths: list[str], remote: str | None, rev: str | None):
def load_plots_to_dict(
paths: list[str], remote: str | None, rev: str | None
) -> dict[str, go.Figure]:
data = {}
for path in paths:
if not pathlib.Path(path).exists():
Expand All @@ -410,15 +413,15 @@ def load_plots_to_json(paths: list[str], remote: str | None, rev: str | None):
else:
plots = znjson.loads(pathlib.Path(path).read_text())
if isinstance(plots, plotly.graph_objs.Figure):
data[path] = plots.to_json()
data[path] = plots
elif isinstance(plots, dict):
if not all(isinstance(v, plotly.graph_objs.Figure) for v in plots.values()):
raise ValueError("All values in the plots dict must be plotly.graph_objs")
data.update({f"{path}_{k}": v.to_json() for k, v in plots.items()})
data.update({f"{path}_{k}": v for k, v in plots.items()})
elif isinstance(plots, list):
if not all(isinstance(v, plotly.graph_objs.Figure) for v in plots):
raise ValueError("All values in the plots list must be plotly.graph_objs")
data.update({f"{path}_{i}": v.to_json() for i, v in enumerate(plots)})
data.update({f"{path}_{i}": v for i, v in enumerate(plots)})
else:
raise ValueError("The plots must be a dict, list or Figure")

Expand Down
12 changes: 8 additions & 4 deletions zndraw/zndraw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

import ase
import numpy as np
import plotly.graph_objects as go
import requests
import socketio.exceptions
import tqdm
import znjson
import znsocket
from plotly.io import from_json as ploty_from_json
from redis import Redis

from zndraw.base import Extension, ZnDrawBase
Expand Down Expand Up @@ -381,24 +383,26 @@ def step(self, value: int):
)

@property
def figures(self) -> dict[str, dict | t.Any]:
def figures(self) -> dict[str, go.Figure]:
# TODO: znjson.loads
return call_with_retry(
data = call_with_retry(
self.socket,
"analysis:figure:get",
retries=self.timeout["call_retries"],
)
return {k: ploty_from_json(v) for k, v in data.items()}

@figures.setter
def figures(self, fig: dict[str, dict | t.Any]) -> None:
def figures(self, data: dict[str, go.Figure]) -> None:
"""Update the figures on the remote."""
# TODO: can you use znsocket.Dict
# to update the data an avoid
# sending duplicates?
data = {k: v.to_json() for k, v in data.items()}
emit_with_retry(
self.socket,
"analysis:figure:set",
fig,
data,
retries=self.timeout["emit_retries"],
)

Expand Down

0 comments on commit 7dd970d

Please sign in to comment.