diff --git a/znvis/__init__.py b/znvis/__init__.py index ba99d48..5306b99 100644 --- a/znvis/__init__.py +++ b/znvis/__init__.py @@ -28,12 +28,16 @@ from znvis.mesh.custom import CustomMesh from znvis.mesh.cylinder import Cylinder from znvis.mesh.sphere import Sphere +from znvis.mesh.arrow import Arrow from znvis.particle.particle import Particle +from znvis.particle.vector_field import VectorField from znvis.visualizer.visualizer import Visualizer __all__ = [ Particle.__name__, Sphere.__name__, + Arrow.__name__, + VectorField.__name__, Visualizer.__name__, Cylinder.__name__, CustomMesh.__name__, diff --git a/znvis/mesh/arrow.py b/znvis/mesh/arrow.py new file mode 100644 index 0000000..5f9382c --- /dev/null +++ b/znvis/mesh/arrow.py @@ -0,0 +1,89 @@ +""" +ZnVis: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Create a sphere mesh. +""" + +from dataclasses import dataclass + +import numpy as np +import open3d as o3d + +from znvis.transformations.rotation_matrices import rotation_matrix + +from znvis.mesh import Mesh + + +@dataclass +class Arrow(Mesh): + """ + A class to produce arrow meshes. + + Attributes + ---------- + scale : float + Scale of the arrow + resolution : int + Resolution of the mesh. + """ + scale: float = 1.0 + resolution: int = 10 + + def create_mesh( + self, starting_position: np.ndarray, direction: np.ndarray = None + ) -> o3d.geometry.TriangleMesh: + """ + Create a mesh object defined by the dataclass. + + Parameters + ---------- + starting_position : np.ndarray shape=(3,) + Starting position of the mesh. + direction : np.ndarray shape=(3,) (default = None) + Direction of the mesh. + + Returns + ------- + mesh : o3d.geometry.TriangleMesh + """ + + direction_length = np.linalg.norm(direction) + + cylinder_radius = 0.06 * direction_length * self.scale + cylinder_height = 0.85 * direction_length * self.scale + cone_radius = 0.15 * direction_length * self.scale + cone_height = 0.15 * direction_length * self.scale + + arrow = o3d.geometry.TriangleMesh.create_arrow( + cylinder_radius=cylinder_radius, + cylinder_height=cylinder_height, + cone_radius=cone_radius, + cone_height=cone_height, + resolution=self.resolution + ) + + arrow.compute_vertex_normals() + matrix = rotation_matrix(np.array([0, 0, 1]), direction) + arrow.rotate(matrix, center=(0, 0, 0)) + + # Translate the arrow to the starting position and center the origin + arrow.translate(starting_position.astype(float)) + + return arrow diff --git a/znvis/particle/vector_field.py b/znvis/particle/vector_field.py new file mode 100644 index 0000000..871a3b6 --- /dev/null +++ b/znvis/particle/vector_field.py @@ -0,0 +1,111 @@ +""" +ZnVis: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the particle parent class +""" + +import typing +from dataclasses import dataclass + +import numpy as np +from rich.progress import track + +from znvis.mesh.arrow import Arrow + + +@dataclass +class VectorField: + """ + A class to represent a vector field. + + Attributes + ---------- + name : str + Name of the vector field + mesh : Mesh + Mesh to use + position : np.ndarray + Position tensor of the shape (n_steps, n_vectors, n_dims) + direction : np.ndarray + Direction tensor of the shape (n_steps, n_vectors, n_dims) + mesh_list : list + A list of mesh objects, one for each time step. + smoothing : bool (default=False) + If true, apply smoothing to each mesh object as it is rendered. + This will slow down the initial construction of the mesh objects + but not the deployment. + """ + + name: str + mesh: Arrow = None # Should be an instance of the Arrow class + position: np.ndarray = None + direction: np.ndarray = None + mesh_list: typing.List[Arrow] = None + + smoothing: bool = False + + def _create_mesh(self, position: np.ndarray, direction: np.ndarray): + """ + Create a mesh object for the vector field. + + Parameters + ---------- + position : np.ndarray + Position of the arrow + direction : np.ndarray + Direction of the arrow + + Returns + ------- + mesh : o3d.geometry.TriangleMesh + A mesh object + """ + + mesh = self.mesh.create_mesh(position, direction) + if self.smoothing: + return mesh.filter_smooth_taubin(100) + else: + return mesh + + def construct_mesh_list(self): + """ + Constructor the mesh list for the class. + + The mesh list is a list of mesh objects for each + time step in the parsed trajectory. + + Returns + ------- + Updates the class attributes mesh_list + """ + self.mesh_list = [] + try: + n_particles = int(self.position.shape[1]) + n_time_steps = int(self.position.shape[0]) + except ValueError: + raise ValueError("There is no data for this vector field.") + + for i in track(range(n_time_steps), description=f"Building {self.name} Mesh"): + for j in range(n_particles): + if j == 0: + mesh = self._create_mesh(self.position[i][j], self.direction[i][j]) + else: + mesh += self._create_mesh(self.position[i][j], self.direction[i][j]) + self.mesh_list.append(mesh) diff --git a/znvis/visualizer/visualizer.py b/znvis/visualizer/visualizer.py index 00f5502..ddfbe62 100644 --- a/znvis/visualizer/visualizer.py +++ b/znvis/visualizer/visualizer.py @@ -61,6 +61,7 @@ class Visualizer: def __init__( self, particles: typing.List[znvis.Particle], + vector_field: typing.List[znvis.VectorField] = None, output_folder: typing.Union[str, pathlib.Path] = "./", frame_rate: int = 24, number_of_steps: int = None, @@ -89,6 +90,7 @@ def __init__( The format of the video to be generated. """ self.particles = particles + self.vector_field = vector_field self.frame_rate = frame_rate self.bounding_box = bounding_box() if bounding_box else None @@ -305,6 +307,11 @@ def _initialize_particles(self): self._draw_particles(initial=True) + def _initialize_vector_field(self): + for item in self.vector_field: + item.construct_mesh_list() + self._draw_vector_field(initial=True) + def _draw_particles(self, visualizer=None, initial: bool = False): """ Draw the particles on the visualizer. @@ -344,6 +351,36 @@ def _draw_particles(self, visualizer=None, initial: bool = False): item.name, item.mesh_list[self.counter], item.mesh.o3d_material ) + + def _draw_vector_field(self, visualizer=None, initial: bool = False): + """ + Draw the vector field on the visualizer. + + Parameters + ---------- + initial : bool (default = True) + If true, no particles are removed. + + Returns + ------- + updates the information in the visualizer. + ----- + """ + if visualizer is None: + visualizer = self.vis + + if initial: + for i, item in enumerate(self.vector_field): + visualizer.add_geometry( + item.name, item.mesh_list[self.counter], item.mesh.o3d_material + ) + else: + for i, item in enumerate(self.vector_field): + visualizer.remove_geometry(item.name) + visualizer.add_geometry( + item.name, item.mesh_list[self.counter], item.mesh.o3d_material + ) + def _continuous_trajectory(self, vis): """ Button command for running the simulation in the visualizer. @@ -509,6 +546,11 @@ def _update_particles(self, visualizer=None, step: int = None): step = self.counter self._draw_particles(visualizer=visualizer) # draw the particles. + + # draw the vector field if it exists. + if self.vector_field is not None: + self._draw_vector_field(visualizer=visualizer) + visualizer.post_redraw() # re-draw the window. def run_visualization(self): @@ -521,6 +563,8 @@ def run_visualization(self): """ self._initialize_app() self._initialize_particles() + if self.vector_field is not None: + self._initialize_vector_field() self.vis.reset_camera_to_default() self.app.run()