diff --git a/zndraw/app.py b/zndraw/app.py index 31f230231..4df55334b 100644 --- a/zndraw/app.py +++ b/zndraw/app.py @@ -55,12 +55,6 @@ def exit_route(): return "Server shutting down..." -@io.on("atoms:request") -def atoms_request(url): - """Return the atoms.""" - emit("atoms:request", url, broadcast=True, include_self=False) - - @io.on("modifier:schema") def modifier_schema(): config = GlobalConfig.load() @@ -82,11 +76,6 @@ def modifier_schema(): ) -@io.on("modifier:run") -def modifier_run(data): - emit("modifier:run", data, broadcast=True, include_self=False) - - @io.on("analysis:schema") def analysis_schema(data): config = GlobalConfig.load() @@ -111,85 +100,6 @@ def selection_schema(): io.emit("selection:schema", get_selection_class().model_json_schema()) -@io.on("selection:run") -def selection_run(data): - emit("selection:run", data, broadcast=True, include_self=False) - - -@io.on("analysis:run") -def analysis_run(data): - emit("analysis:run", data, broadcast=True, include_self=False) - - -@io.on("config") -def config(data): - pass - - -@io.on("download") -def download(data): - atoms = [atoms_from_json(x) for x in data["atoms_list"].values()] - if "selection" in data: - atoms = [atoms[data["selection"]] for atoms in atoms] - import ase.io - - file = StringIO() - ase.io.write(file, atoms, format="extxyz") - file.seek(0) - return file.read() - - -@io.on("atoms:download") -def atoms_download(data): - """Return the atoms.""" - emit("atoms:download", data, broadcast=True, include_self=False) - - -@io.on("atoms:upload") -def atoms_upload(data): - """Return the atoms.""" - emit("atoms:upload", data, broadcast=True, include_self=False) - - -@io.on("view:set") -def display(data): - """Display the atoms at the given index""" - emit("view:set", data, broadcast=True, include_self=False) - - -@io.on("atoms:size") -def atoms_size(data): - """Return the atoms.""" - emit("atoms:size", data, broadcast=True, include_self=False) - - -@io.on("upload") -def upload(data): - from io import StringIO - - import ase.io - import tqdm - - # tested with small files only - - format = data["filename"].split(".")[-1] - if format == "h5": - print("H5MD format not supported for uploading yet") - # import znh5md - # stream = BytesIO(data["content"].encode("utf-8")) - # atoms = znh5md.ASEH5MD(stream).get_atoms_list() - # for idx, atoms in tqdm.tqdm(enumerate(atoms)): - # atoms_dict = atoms_to_json(atoms) - # io.emit("atoms:upload", {idx: atoms_dict}) - else: - stream = StringIO(data["content"]) - io.emit("atoms:clear", 0) - for idx, atoms in tqdm.tqdm(enumerate(ase.io.iread(stream, format=format))): - atoms_dict = atoms_to_json(atoms) - io.emit("atoms:upload", {idx: atoms_dict}) - emit("view:set", 0) - - @io.on("scene:schema") def scene_schema(): import enum @@ -254,6 +164,66 @@ class Scene(BaseModel): return schema +@io.on("atoms:request") +def atoms_request(url): + """Return the atoms.""" + emit("atoms:request", url, broadcast=True, include_self=False) + + +@io.on("modifier:run") +def modifier_run(data): + emit("modifier:run", data, broadcast=True, include_self=False) + + +@io.on("selection:run") +def selection_run(data): + emit("selection:run", data, broadcast=True, include_self=False) + + +@io.on("analysis:run") +def analysis_run(data): + emit("analysis:run", data, broadcast=True, include_self=False) + + +@io.on("download:request") +def download_request(data): + emit("download:request", data, broadcast=True, include_self=False) + + +@io.on("download:response") +def download_response(data): + emit("download:response", data, broadcast=True, include_self=False) + + +@io.on("atoms:download") +def atoms_download(data): + """Return the atoms.""" + emit("atoms:download", data, broadcast=True, include_self=False) + + +@io.on("atoms:upload") +def atoms_upload(data): + """Return the atoms.""" + emit("atoms:upload", data, broadcast=True, include_self=False) + + +@io.on("view:set") +def display(data): + """Display the atoms at the given index""" + emit("view:set", data, broadcast=True, include_self=False) + + +@io.on("atoms:size") +def atoms_size(data): + """Return the atoms.""" + emit("atoms:size", data, broadcast=True, include_self=False) + + +@io.on("upload") +def upload(data): + emit("upload", data, broadcast=True, include_self=False) + + @io.on("draw:schema") def draw_schema(): io.emit("draw:schema", Geometry.updated_schema()) diff --git a/zndraw/static/UI/UI.js b/zndraw/static/UI/UI.js index 9f151e1eb..02a4a0cde 100644 --- a/zndraw/static/UI/UI.js +++ b/zndraw/static/UI/UI.js @@ -190,36 +190,28 @@ export function setUIEvents(socket, cache, world) { fetch("/exit", { method: "GET" }); }); - document.getElementById("downloadBtn").addEventListener("click", () => { - socket.emit("download", { atoms_list: cache.getAllAtoms() }, (data) => { - const blob = new Blob([data], { type: "text/csv" }); + socket.on("download:response", (data) => { + const blob = new Blob([data], { type: "text/csv" }); const elem = window.document.createElement("a"); elem.href = window.URL.createObjectURL(blob); elem.download = "trajectory.xyz"; document.body.appendChild(elem); elem.click(); document.body.removeChild(elem); - }); + }); + + document.getElementById("downloadBtn").addEventListener("click", () => { + socket.emit("download:request", { }); }); document .getElementById("downloadSelectedBtn") .addEventListener("click", () => { socket.emit( - "download", + "download:request", { - atoms_list: cache.getAllAtoms(), selection: world.getSelection(), - }, - (data) => { - const blob = new Blob([data], { type: "text/csv" }); - const elem = window.document.createElement("a"); - elem.href = window.URL.createObjectURL(blob); - elem.download = "trajectory.xyz"; - document.body.appendChild(elem); - elem.click(); - document.body.removeChild(elem); - }, + } ); }); } diff --git a/zndraw/zndraw.py b/zndraw/zndraw.py index d3281541f..7d31f5fcd 100644 --- a/zndraw/zndraw.py +++ b/zndraw/zndraw.py @@ -5,6 +5,7 @@ import threading import time import typing as t +from io import StringIO import ase import ase.io @@ -104,6 +105,8 @@ def __post_init__(self): self.socket.on("modifier:run", self._run_modifier) self.socket.on("selection:run", self._run_selection) self.socket.on("analysis:run", self._run_analysis) + self.socket.on("download:request", self._download_file) + self.socket.on("upload", self._upload_file) self.socket.on("disconnect", lambda: self.disconnect()) @@ -331,3 +334,38 @@ def _run_analysis(self, data): fig = instance.run(atoms_list, data["selection"]) self.socket.emit("analysis:figure", fig.to_json()) + + def _download_file(self, data): + atoms = list(self) + if "selection" in data: + atoms = [atoms[data["selection"]] for atoms in atoms] + import ase.io + + file = StringIO() + ase.io.write(file, atoms, format="extxyz") + file.seek(0) + + self.socket.emit("download:response", file.read()) + + def _upload_file(self, data): + from io import StringIO + + import ase.io + import tqdm + + # tested with small files only + + format = data["filename"].split(".")[-1] + if format == "h5": + print("H5MD format not supported for uploading yet") + # import znh5md + # stream = BytesIO(data["content"].encode("utf-8")) + # atoms = znh5md.ASEH5MD(stream).get_atoms_list() + # for idx, atoms in tqdm.tqdm(enumerate(atoms)): + # atoms_dict = atoms_to_json(atoms) + # io.emit("atoms:upload", {idx: atoms_dict}) + else: + stream = StringIO(data["content"]) + del self[:] + for atoms in tqdm.tqdm(ase.io.iread(stream, format=format)): + self.append(atoms)