Skip to content

Commit

Permalink
show all atoms data (#639)
Browse files Browse the repository at this point in the history
* show all available data in analysis and when plotting vectors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* improve plotting: highlight atoms in the current frame if available

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update README.md

* update customdata for all plots

* use more checkboxes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Aug 28, 2024
1 parent 684e781 commit 84cb793
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 34 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ ZnDraw enables you to analyze your data and generate plots using [Plotly](https:
![ZnDraw UI](https://raw.githubusercontent.com/zincware/ZnDraw/main/misc/darkmode/analysis.png#gh-dark-mode-only "ZnDraw Analysis")
![ZnDraw UI](https://raw.githubusercontent.com/zincware/ZnDraw/main/misc/lightmode/analysis.png#gh-light-mode-only "ZnDraw Analysis")

ZnDraw will look for the `step` and `atom` index in the [customdata](https://plotly.com/python/reference/scatter/#scatter-customdata)`[0]` and `[1]` respectively to highlight the steps and atoms.

## Writing Extensions

Make your tools accessible via the ZnDraw UI by writing an extension:
Expand Down
2 changes: 2 additions & 0 deletions app/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,8 @@ export default function App() {
setStep={setStep}
setSelectedFrames={setSelectedFrames}
addPlotsWindow={addPlotsWindow}
setSelectedIds={setSelectedIds}
step={step}
/>
{showParticleInfo && (
<>
Expand Down
23 changes: 20 additions & 3 deletions app/src/components/particles.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,30 @@ export const PerParticleVectors: React.FC<PerParticleVectorsProps> = ({
}, [vectors, arrowsConfig.normalize, arrowsConfig.colorrange]);

useEffect(() => {
if (!frame || !frame.calc || !frame.calc[property]) {
console.log(`Property ${property} not found in frame.calc`);
if (!frame) {
setVectors([]);
return;
} else {
let frameData;
if (property in frame.calc) {
frameData = frame.calc[property];
} else if (property in frame.arrays) {
frameData = frame.arrays[property];
} else {
console.error(`Property ${property} not found in frame`);
setVectors([]);
return;
}
if (frameData.length !== frame.positions.length) {
console.error(
`Length of property ${property} does not match the number of particles`,
);
setVectors([]);
return;
}

console.log(`Property ${property} found in frame.calc`);
const calculatedVectors = frame.calc[property].map((vector, i) => {
const calculatedVectors = frameData.map((vector, i) => {
const start = frame.positions[i];
const end = start
.clone()
Expand Down
25 changes: 23 additions & 2 deletions app/src/components/plotting.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@ import { IndicesState } from "./utils";

interface PlottingProps {
setStep: (step: number) => void;
step: number;
setSelectedFrames: (selectedFrames: IndicesState) => void;
addPlotsWindow: number;
setSelectedIds: (selectedIds: Set<number>) => void;
}

export const Plotting = ({
setStep,
step,
setSelectedFrames,
addPlotsWindow,
setSelectedIds,
}: PlottingProps) => {
const [availablePlots, setAvailablePlots] = useState<string[]>([]);
const [plotData, setPlotData] = useState<{ [key: string]: any }>({});
Expand Down Expand Up @@ -72,6 +76,8 @@ export const Plotting = ({
setDisplayedCards={setDisplayedCards}
setStep={setStep}
setSelectedFrames={setSelectedFrames}
setSelectedIds={setSelectedIds}
step={step}
/>
))}
</>
Expand All @@ -85,7 +91,9 @@ interface PlotsCardProps {
plotData: { [key: string]: any };
setDisplayedCards: (displayedCards: number[]) => void;
setStep: (step: number) => void;
step: number;
setSelectedFrames: (selectedFrames: IndicesState) => void;
setSelectedIds: (selectedIds: Set<number>) => void;
}

const PlotsCard = ({
Expand All @@ -95,7 +103,9 @@ const PlotsCard = ({
plotData,
setDisplayedCards,
setStep,
step,
setSelectedFrames,
setSelectedIds,
}: PlotsCardProps) => {
const cardRef = useRef<any>(null);
const [selectedOption, setSelectedOption] = useState<string>("");
Expand Down Expand Up @@ -130,8 +140,9 @@ const PlotsCard = ({
const onPlotClick = ({ points }: { points: any[] }) => {
if (points[0]?.customdata) {
setStep(points[0].customdata[0]);
} else {
setStep(points[0].pointIndex);
if (points[0].customdata[1]) {
setSelectedIds(new Set([points[0].customdata[1]]));
}
}
};

Expand All @@ -140,6 +151,16 @@ const PlotsCard = ({
const selectedFrames = event.points.map((point: any) =>
point.customdata ? point.customdata[0] : point.pointIndex,
);
// for all points.customdata[0] == step collect the points.customdata[1] and set selectedIds if customdata[1] is available
const selectedIds = new Set<number>(
event.points
.filter((point: any) => point.customdata && point.customdata[1])
.map((point: any) => point.customdata[1]),
);
if (selectedIds.size > 0) {
setSelectedIds(selectedIds);
}

setSelectedFrames({
active: true,
indices: new Set(selectedFrames),
Expand Down
149 changes: 127 additions & 22 deletions zndraw/analyse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import typing as t

import ase
import numpy as np
import pandas as pd
import plotly.express as px
Expand All @@ -23,6 +24,19 @@ def _schema_from_atoms(schema, cls):
return cls.model_json_schema_from_atoms(schema)


def _get_data_from_frames(key, frames: list[ase.Atoms]):
if frames[0].calc is not None and key in frames[0].calc.results:
data = np.array([x.calc.results[key] for x in frames])
elif key in frames[0].arrays:
data = np.array([x.arrays[key] for x in frames])
elif key in frames[0].info:
data = np.array([x.info[key] for x in frames])
else:
raise ValueError(f"Property '{key}' not found in atoms")

return data


class DihedralAngle(Extension):
def run(self, vis):
atoms_lst = list(vis)
Expand All @@ -38,22 +52,39 @@ def run(self, vis):
{"step": list(range(len(atoms_lst))), "dihedral": dihedral_angles}
)
fig = px.line(df, x="step", y="dihedral", render_mode="svg")

meta_step = np.arange(len(atoms_lst))
# meta_idx = np.full_like(meta_step, np.nan)

fig.update_traces(
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"DihedralAngle": fig.to_json()}


class Distance(Extension):
smooth: bool = False
mic: bool = True

model_config = ConfigDict(json_schema_extra=_schema_from_atoms)

@staticmethod
def model_json_schema_from_atoms(schema: dict) -> dict:
schema["properties"]["smooth"]["format"] = "checkbox"
schema["properties"]["mic"]["format"] = "checkbox"

return schema

def run(self, vis):
atoms_lst, ids = list(vis), vis.selection
distances = {}
for x in itertools.combinations(ids, 2):
distances[f"{tuple(x)}"] = []
for atoms in atoms_lst:
positions = atoms.get_positions()
for x in itertools.combinations(ids, 2):
distances[f"{tuple(x)}"].append(
np.linalg.norm(positions[x[0]] - positions[x[1]])
atoms.get_distance(x[0], x[1], mic=self.mic)
)

df = pd.DataFrame({"step": list(range(len(atoms_lst)))} | distances)
Expand All @@ -73,6 +104,13 @@ def run(self, vis):
fig.add_scatter(
x=smooth_df["step"], y=smooth_df[col], name=f"smooth_{col}"
)
meta_step = np.arange(len(atoms_lst))
# meta_idx = np.full_like(meta_step, np.nan)

fig.update_traces(
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"Distance": fig.to_json()}


Expand All @@ -89,12 +127,19 @@ def model_json_schema_from_atoms(cls, schema: dict) -> dict:
ATOMS = cls.get_atoms()
log.debug(f"GATHERING PROPERTIES FROM {ATOMS=}")
try:
available_properties = list(ATOMS.calc.results)
available_properties += list(ATOMS.arrays)
available_properties = list(ATOMS.arrays.keys())
available_properties += list(ATOMS.info.keys())
if ATOMS.calc is not None:
available_properties += list(
ATOMS.calc.results.keys()
) # global ATOMS object

available_properties += ["step"]
schema["properties"]["x_data"]["enum"] = available_properties
schema["properties"]["y_data"]["enum"] = available_properties
schema["properties"]["color"]["enum"] = available_properties
schema["properties"]["fix_aspect_ratio"]["format"] = "checkbox"

except AttributeError:
pass
return schema
Expand All @@ -106,43 +151,92 @@ def run(self, vis):
if self.x_data == "step":
x_data = list(range(len(atoms_lst)))
else:
try:
x_data = [x.calc.results[self.x_data] for x in atoms_lst]
except KeyError:
x_data = [x.arrays[self.x_data] for x in atoms_lst]
x_data = _get_data_from_frames(self.x_data, atoms_lst)

if self.y_data == "step":
y_data = list(range(len(atoms_lst)))
else:
try:
y_data = [x.calc.results[self.y_data] for x in atoms_lst]
except KeyError:
y_data = [x.arrays[self.y_data] for x in atoms_lst]
y_data = _get_data_from_frames(self.y_data, atoms_lst)

if self.color == "step":
color = list(range(len(atoms_lst)))
else:
try:
color = [x.calc.results[self.color] for x in atoms_lst]
except KeyError:
color = [x.arrays[self.color] for x in atoms_lst]
color = _get_data_from_frames(self.color, atoms_lst)

y_data = np.array(y_data).reshape(-1)
x_data = np.array(x_data).reshape(-1)
color = np.array(color).reshape(-1)

df = pd.DataFrame({self.x_data: x_data, self.y_data: y_data, self.color: color})
fig = px.scatter(
df, x=self.x_data, y=self.y_data, color=self.color, render_mode="svg"
)
fig = px.scatter(df, x=self.x_data, y=self.y_data, color=self.color)
if self.fix_aspect_ratio:
fig.update_yaxes(
scaleanchor="x",
scaleratio=1,
)

meta_step = np.arange(len(atoms_lst))
# meta_idx = np.full_like(meta_step, np.nan)

fig.update_traces(
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"Properties2D": fig.to_json()}


class ForceCorrelation(Extension):
"""Compute the correlation between two properties for the current frame."""

x_data: str
y_data: str

model_config = ConfigDict(json_schema_extra=_schema_from_atoms)

@classmethod
def model_json_schema_from_atoms(cls, schema: dict) -> dict:
ATOMS = cls.get_atoms()
try:
available_properties = list(ATOMS.arrays.keys())
available_properties += list(ATOMS.info.keys())
if ATOMS.calc is not None:
available_properties += list(ATOMS.calc.results.keys())
schema["properties"]["x_data"]["enum"] = available_properties
schema["properties"]["y_data"]["enum"] = available_properties
except AttributeError:
pass
return schema

def run(self, vis):
atoms = vis.atoms
x_data = _get_data_from_frames(self.x_data, [atoms])
y_data = _get_data_from_frames(self.y_data, [atoms])

x_data = np.linalg.norm(x_data, axis=-1)
y_data = np.linalg.norm(y_data, axis=-1)

vis.log(f"x_data: {x_data.shape}, y_data: {y_data.shape}")

x_data = x_data.reshape(-1)
y_data = y_data.reshape(-1)

current_step = vis.step
meta_step = [current_step for _ in range(len(x_data))]
meta_idx = list(range(len(x_data)))

df = pd.DataFrame(
{
self.x_data: x_data,
self.y_data: y_data,
}
)

fig = px.scatter(df, x=self.x_data, y=self.y_data, render_mode="svg")
fig.update_traces(customdata=np.stack([meta_step, meta_idx], axis=-1))

vis.figures = vis.figures | {"ForceCorrelation": fig.to_json()}


class Properties1D(Extension):
value: str
smooth: bool = False
Expand All @@ -157,8 +251,11 @@ class Properties1D(Extension):
def model_json_schema_from_atoms(cls, schema: dict) -> dict:
ATOMS = cls.get_atoms()
try:
available_properties = list(ATOMS.calc.results.keys()) # global ATOMS object
log.debug(f"AVAILABLE PROPERTIES: {available_properties=}")
available_properties = list(ATOMS.arrays.keys())
available_properties += list(ATOMS.info.keys())
if ATOMS.calc is not None:
available_properties += list(ATOMS.calc.results.keys())
log.critical(f"AVAILABLE PROPERTIES: {available_properties=}")
schema["properties"]["value"]["enum"] = available_properties
except AttributeError:
print(f"{ATOMS=}")
Expand All @@ -167,7 +264,7 @@ def model_json_schema_from_atoms(cls, schema: dict) -> dict:
def run(self, vis):
vis.log("Downloading data...")
atoms_lst = list(vis)
data = np.array([x.calc.results[self.value] for x in atoms_lst])
data = _get_data_from_frames(self.value, atoms_lst)

if data.ndim > 1:
axis = tuple(range(1, data.ndim))
Expand All @@ -190,6 +287,13 @@ def run(self, vis):
x=smooth_df["step"], y=smooth_df[col], name=f"smooth_{col}"
)

meta_step = np.arange(len(atoms_lst))
# meta_idx = np.full_like(meta_step, np.nan)

fig.update_traces(
customdata=np.stack([meta_step], axis=-1),
)

vis.figures = vis.figures | {"Properties1D": fig.to_json()}


Expand All @@ -198,6 +302,7 @@ def run(self, vis):
DihedralAngle,
Distance,
Properties2D,
ForceCorrelation,
]


Expand Down
Loading

0 comments on commit 84cb793

Please sign in to comment.