diff --git a/README.md b/README.md index 2fd9d076..c9fba7a3 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,8 @@ ZnDraw enables you to analyze your data and generate plots using [Plotly](https: ![ZnDraw UI](https://raw.githubusercontent.com/zincware/ZnDraw/main/misc/darkmode/analysis.png#gh-dark-mode-only "ZnDraw Analysis") ![ZnDraw UI](https://raw.githubusercontent.com/zincware/ZnDraw/main/misc/lightmode/analysis.png#gh-light-mode-only "ZnDraw Analysis") +ZnDraw will look for the `step` and `atom` index in the [customdata](https://plotly.com/python/reference/scatter/#scatter-customdata)`[0]` and `[1]` respectively to highlight the steps and atoms. + ## Writing Extensions Make your tools accessible via the ZnDraw UI by writing an extension: diff --git a/app/src/App.tsx b/app/src/App.tsx index 08f00384..3d79cbdb 100644 --- a/app/src/App.tsx +++ b/app/src/App.tsx @@ -848,6 +848,8 @@ export default function App() { setStep={setStep} setSelectedFrames={setSelectedFrames} addPlotsWindow={addPlotsWindow} + setSelectedIds={setSelectedIds} + step={step} /> {showParticleInfo && ( <> diff --git a/app/src/components/particles.tsx b/app/src/components/particles.tsx index abf4ec32..0a2702a9 100644 --- a/app/src/components/particles.tsx +++ b/app/src/components/particles.tsx @@ -579,13 +579,30 @@ export const PerParticleVectors: React.FC = ({ }, [vectors, arrowsConfig.normalize, arrowsConfig.colorrange]); useEffect(() => { - if (!frame || !frame.calc || !frame.calc[property]) { - console.log(`Property ${property} not found in frame.calc`); + if (!frame) { setVectors([]); return; } else { + let frameData; + if (property in frame.calc) { + frameData = frame.calc[property]; + } else if (property in frame.arrays) { + frameData = frame.arrays[property]; + } else { + console.error(`Property ${property} not found in frame`); + setVectors([]); + return; + } + if (frameData.length !== frame.positions.length) { + console.error( + `Length of property ${property} does not match the number of particles`, + ); + setVectors([]); + return; + } + console.log(`Property ${property} found in frame.calc`); - const calculatedVectors = frame.calc[property].map((vector, i) => { + const calculatedVectors = frameData.map((vector, i) => { const start = frame.positions[i]; const end = start .clone() diff --git a/app/src/components/plotting.tsx b/app/src/components/plotting.tsx index d7915d2a..1c2237f3 100644 --- a/app/src/components/plotting.tsx +++ b/app/src/components/plotting.tsx @@ -11,14 +11,18 @@ import { IndicesState } from "./utils"; interface PlottingProps { setStep: (step: number) => void; + step: number; setSelectedFrames: (selectedFrames: IndicesState) => void; addPlotsWindow: number; + setSelectedIds: (selectedIds: Set) => void; } export const Plotting = ({ setStep, + step, setSelectedFrames, addPlotsWindow, + setSelectedIds, }: PlottingProps) => { const [availablePlots, setAvailablePlots] = useState([]); const [plotData, setPlotData] = useState<{ [key: string]: any }>({}); @@ -72,6 +76,8 @@ export const Plotting = ({ setDisplayedCards={setDisplayedCards} setStep={setStep} setSelectedFrames={setSelectedFrames} + setSelectedIds={setSelectedIds} + step={step} /> ))} @@ -85,7 +91,9 @@ interface PlotsCardProps { plotData: { [key: string]: any }; setDisplayedCards: (displayedCards: number[]) => void; setStep: (step: number) => void; + step: number; setSelectedFrames: (selectedFrames: IndicesState) => void; + setSelectedIds: (selectedIds: Set) => void; } const PlotsCard = ({ @@ -95,7 +103,9 @@ const PlotsCard = ({ plotData, setDisplayedCards, setStep, + step, setSelectedFrames, + setSelectedIds, }: PlotsCardProps) => { const cardRef = useRef(null); const [selectedOption, setSelectedOption] = useState(""); @@ -130,8 +140,9 @@ const PlotsCard = ({ const onPlotClick = ({ points }: { points: any[] }) => { if (points[0]?.customdata) { setStep(points[0].customdata[0]); - } else { - setStep(points[0].pointIndex); + if (points[0].customdata[1]) { + setSelectedIds(new Set([points[0].customdata[1]])); + } } }; @@ -140,6 +151,16 @@ const PlotsCard = ({ const selectedFrames = event.points.map((point: any) => point.customdata ? point.customdata[0] : point.pointIndex, ); + // for all points.customdata[0] == step collect the points.customdata[1] and set selectedIds if customdata[1] is available + const selectedIds = new Set( + event.points + .filter((point: any) => point.customdata && point.customdata[1]) + .map((point: any) => point.customdata[1]), + ); + if (selectedIds.size > 0) { + setSelectedIds(selectedIds); + } + setSelectedFrames({ active: true, indices: new Set(selectedFrames), diff --git a/zndraw/analyse/__init__.py b/zndraw/analyse/__init__.py index 31fc023e..94d1efac 100644 --- a/zndraw/analyse/__init__.py +++ b/zndraw/analyse/__init__.py @@ -2,6 +2,7 @@ import logging import typing as t +import ase import numpy as np import pandas as pd import plotly.express as px @@ -23,6 +24,19 @@ def _schema_from_atoms(schema, cls): return cls.model_json_schema_from_atoms(schema) +def _get_data_from_frames(key, frames: list[ase.Atoms]): + if frames[0].calc is not None and key in frames[0].calc.results: + data = np.array([x.calc.results[key] for x in frames]) + elif key in frames[0].arrays: + data = np.array([x.arrays[key] for x in frames]) + elif key in frames[0].info: + data = np.array([x.info[key] for x in frames]) + else: + raise ValueError(f"Property '{key}' not found in atoms") + + return data + + class DihedralAngle(Extension): def run(self, vis): atoms_lst = list(vis) @@ -38,11 +52,29 @@ def run(self, vis): {"step": list(range(len(atoms_lst))), "dihedral": dihedral_angles} ) fig = px.line(df, x="step", y="dihedral", render_mode="svg") + + meta_step = np.arange(len(atoms_lst)) + # meta_idx = np.full_like(meta_step, np.nan) + + fig.update_traces( + customdata=np.stack([meta_step], axis=-1), + ) + vis.figures = vis.figures | {"DihedralAngle": fig.to_json()} class Distance(Extension): smooth: bool = False + mic: bool = True + + model_config = ConfigDict(json_schema_extra=_schema_from_atoms) + + @staticmethod + def model_json_schema_from_atoms(schema: dict) -> dict: + schema["properties"]["smooth"]["format"] = "checkbox" + schema["properties"]["mic"]["format"] = "checkbox" + + return schema def run(self, vis): atoms_lst, ids = list(vis), vis.selection @@ -50,10 +82,9 @@ def run(self, vis): for x in itertools.combinations(ids, 2): distances[f"{tuple(x)}"] = [] for atoms in atoms_lst: - positions = atoms.get_positions() for x in itertools.combinations(ids, 2): distances[f"{tuple(x)}"].append( - np.linalg.norm(positions[x[0]] - positions[x[1]]) + atoms.get_distance(x[0], x[1], mic=self.mic) ) df = pd.DataFrame({"step": list(range(len(atoms_lst)))} | distances) @@ -73,6 +104,13 @@ def run(self, vis): fig.add_scatter( x=smooth_df["step"], y=smooth_df[col], name=f"smooth_{col}" ) + meta_step = np.arange(len(atoms_lst)) + # meta_idx = np.full_like(meta_step, np.nan) + + fig.update_traces( + customdata=np.stack([meta_step], axis=-1), + ) + vis.figures = vis.figures | {"Distance": fig.to_json()} @@ -89,12 +127,19 @@ def model_json_schema_from_atoms(cls, schema: dict) -> dict: ATOMS = cls.get_atoms() log.debug(f"GATHERING PROPERTIES FROM {ATOMS=}") try: - available_properties = list(ATOMS.calc.results) - available_properties += list(ATOMS.arrays) + available_properties = list(ATOMS.arrays.keys()) + available_properties += list(ATOMS.info.keys()) + if ATOMS.calc is not None: + available_properties += list( + ATOMS.calc.results.keys() + ) # global ATOMS object + available_properties += ["step"] schema["properties"]["x_data"]["enum"] = available_properties schema["properties"]["y_data"]["enum"] = available_properties schema["properties"]["color"]["enum"] = available_properties + schema["properties"]["fix_aspect_ratio"]["format"] = "checkbox" + except AttributeError: pass return schema @@ -106,43 +151,92 @@ def run(self, vis): if self.x_data == "step": x_data = list(range(len(atoms_lst))) else: - try: - x_data = [x.calc.results[self.x_data] for x in atoms_lst] - except KeyError: - x_data = [x.arrays[self.x_data] for x in atoms_lst] + x_data = _get_data_from_frames(self.x_data, atoms_lst) if self.y_data == "step": y_data = list(range(len(atoms_lst))) else: - try: - y_data = [x.calc.results[self.y_data] for x in atoms_lst] - except KeyError: - y_data = [x.arrays[self.y_data] for x in atoms_lst] + y_data = _get_data_from_frames(self.y_data, atoms_lst) if self.color == "step": color = list(range(len(atoms_lst))) else: - try: - color = [x.calc.results[self.color] for x in atoms_lst] - except KeyError: - color = [x.arrays[self.color] for x in atoms_lst] + color = _get_data_from_frames(self.color, atoms_lst) y_data = np.array(y_data).reshape(-1) x_data = np.array(x_data).reshape(-1) color = np.array(color).reshape(-1) df = pd.DataFrame({self.x_data: x_data, self.y_data: y_data, self.color: color}) - fig = px.scatter( - df, x=self.x_data, y=self.y_data, color=self.color, render_mode="svg" - ) + fig = px.scatter(df, x=self.x_data, y=self.y_data, color=self.color) if self.fix_aspect_ratio: fig.update_yaxes( scaleanchor="x", scaleratio=1, ) + + meta_step = np.arange(len(atoms_lst)) + # meta_idx = np.full_like(meta_step, np.nan) + + fig.update_traces( + customdata=np.stack([meta_step], axis=-1), + ) + vis.figures = vis.figures | {"Properties2D": fig.to_json()} +class ForceCorrelation(Extension): + """Compute the correlation between two properties for the current frame.""" + + x_data: str + y_data: str + + model_config = ConfigDict(json_schema_extra=_schema_from_atoms) + + @classmethod + def model_json_schema_from_atoms(cls, schema: dict) -> dict: + ATOMS = cls.get_atoms() + try: + available_properties = list(ATOMS.arrays.keys()) + available_properties += list(ATOMS.info.keys()) + if ATOMS.calc is not None: + available_properties += list(ATOMS.calc.results.keys()) + schema["properties"]["x_data"]["enum"] = available_properties + schema["properties"]["y_data"]["enum"] = available_properties + except AttributeError: + pass + return schema + + def run(self, vis): + atoms = vis.atoms + x_data = _get_data_from_frames(self.x_data, [atoms]) + y_data = _get_data_from_frames(self.y_data, [atoms]) + + x_data = np.linalg.norm(x_data, axis=-1) + y_data = np.linalg.norm(y_data, axis=-1) + + vis.log(f"x_data: {x_data.shape}, y_data: {y_data.shape}") + + x_data = x_data.reshape(-1) + y_data = y_data.reshape(-1) + + current_step = vis.step + meta_step = [current_step for _ in range(len(x_data))] + meta_idx = list(range(len(x_data))) + + df = pd.DataFrame( + { + self.x_data: x_data, + self.y_data: y_data, + } + ) + + fig = px.scatter(df, x=self.x_data, y=self.y_data, render_mode="svg") + fig.update_traces(customdata=np.stack([meta_step, meta_idx], axis=-1)) + + vis.figures = vis.figures | {"ForceCorrelation": fig.to_json()} + + class Properties1D(Extension): value: str smooth: bool = False @@ -157,8 +251,11 @@ class Properties1D(Extension): def model_json_schema_from_atoms(cls, schema: dict) -> dict: ATOMS = cls.get_atoms() try: - available_properties = list(ATOMS.calc.results.keys()) # global ATOMS object - log.debug(f"AVAILABLE PROPERTIES: {available_properties=}") + available_properties = list(ATOMS.arrays.keys()) + available_properties += list(ATOMS.info.keys()) + if ATOMS.calc is not None: + available_properties += list(ATOMS.calc.results.keys()) + log.critical(f"AVAILABLE PROPERTIES: {available_properties=}") schema["properties"]["value"]["enum"] = available_properties except AttributeError: print(f"{ATOMS=}") @@ -167,7 +264,7 @@ def model_json_schema_from_atoms(cls, schema: dict) -> dict: def run(self, vis): vis.log("Downloading data...") atoms_lst = list(vis) - data = np.array([x.calc.results[self.value] for x in atoms_lst]) + data = _get_data_from_frames(self.value, atoms_lst) if data.ndim > 1: axis = tuple(range(1, data.ndim)) @@ -190,6 +287,13 @@ def run(self, vis): x=smooth_df["step"], y=smooth_df[col], name=f"smooth_{col}" ) + meta_step = np.arange(len(atoms_lst)) + # meta_idx = np.full_like(meta_step, np.nan) + + fig.update_traces( + customdata=np.stack([meta_step], axis=-1), + ) + vis.figures = vis.figures | {"Properties1D": fig.to_json()} @@ -198,6 +302,7 @@ def run(self, vis): DihedralAngle, Distance, Properties2D, + ForceCorrelation, ] diff --git a/zndraw/scene.py b/zndraw/scene.py index b7a0ed3a..76bc7cc6 100644 --- a/zndraw/scene.py +++ b/zndraw/scene.py @@ -2,6 +2,7 @@ import typing as t import ase +import numpy as np import znjson import znsocket from flask import current_app, session @@ -110,9 +111,21 @@ def get_updated_schema(cls) -> dict: schema = cls.model_json_schema() atoms = cls._get_atoms() - if atoms.calc is not None and "forces" in atoms.calc.results: - schema["properties"]["vectors"]["enum"] = ["", "forces"] - schema["properties"]["vectors"]["default"] = "" + array_props = [""] + for key in atoms.calc.results.keys(): + if ( + np.array(atoms.calc.results[key]).ndim == 2 + and np.array(atoms.calc.results[key]).shape[1] == 3 + ): + array_props.append(key) + for key in atoms.arrays.keys(): + if ( + np.array(atoms.arrays[key]).ndim == 2 + and np.array(atoms.arrays[key]).shape[1] == 3 + ): + array_props.append(key) + schema["properties"]["vectors"]["enum"] = array_props + schema["properties"]["vectors"]["default"] = "" # schema["properties"]["wireframe"]["format"] = "checkbox" schema["properties"]["Animation Loop"]["format"] = "checkbox" diff --git a/zndraw/utils.py b/zndraw/utils.py index 34ad3fa8..d6b87169 100644 --- a/zndraw/utils.py +++ b/zndraw/utils.py @@ -145,6 +145,14 @@ def encode(self, obj: ase.Atoms) -> ASEDict: else obj.arrays["radii"] ) + for key in obj.arrays: + if key in ["colors", "radii"]: + continue + if isinstance(obj.arrays[key], np.ndarray): + arrays[key] = obj.arrays[key].tolist() + else: + arrays[key] = obj.arrays[key] + if hasattr(obj, "connectivity") and obj.connectivity is not None: connectivity = ( obj.connectivity.tolist() @@ -178,10 +186,8 @@ def decode(self, value: ASEDict) -> ase.Atoms: if connectivity := value.get("connectivity"): # or do we want this to be nx.Graph? atoms.connectivity = np.array(connectivity) - if "colors" in value["arrays"]: - atoms.arrays["colors"] = np.array(value["arrays"]["colors"]) - if "radii" in value["arrays"]: - atoms.arrays["radii"] = np.array(value["arrays"]["radii"]) + for key, val in value["arrays"].items(): + atoms.arrays[key] = np.array(val) if calc := value.get("calc"): atoms.calc = SinglePointCalculator(atoms) atoms.calc.results.update(calc)