Skip to content

Commit

Permalink
Merge pull request #42 from phohenberger/vector_field
Browse files Browse the repository at this point in the history
Add vector field support
  • Loading branch information
SamTov authored May 24, 2024
2 parents 5b16008 + c4fe08a commit 4a043d5
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 0 deletions.
4 changes: 4 additions & 0 deletions znvis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
from znvis.mesh.custom import CustomMesh
from znvis.mesh.cylinder import Cylinder
from znvis.mesh.sphere import Sphere
from znvis.mesh.arrow import Arrow
from znvis.particle.particle import Particle
from znvis.particle.vector_field import VectorField
from znvis.visualizer.visualizer import Visualizer

__all__ = [
Particle.__name__,
Sphere.__name__,
Arrow.__name__,
VectorField.__name__,
Visualizer.__name__,
Cylinder.__name__,
CustomMesh.__name__,
Expand Down
89 changes: 89 additions & 0 deletions znvis/mesh/arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
ZnVis: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Create a sphere mesh.
"""

from dataclasses import dataclass

import numpy as np
import open3d as o3d

from znvis.transformations.rotation_matrices import rotation_matrix

from znvis.mesh import Mesh


@dataclass
class Arrow(Mesh):
"""
A class to produce arrow meshes.
Attributes
----------
scale : float
Scale of the arrow
resolution : int
Resolution of the mesh.
"""
scale: float = 1.0
resolution: int = 10

def create_mesh(
self, starting_position: np.ndarray, direction: np.ndarray = None
) -> o3d.geometry.TriangleMesh:
"""
Create a mesh object defined by the dataclass.
Parameters
----------
starting_position : np.ndarray shape=(3,)
Starting position of the mesh.
direction : np.ndarray shape=(3,) (default = None)
Direction of the mesh.
Returns
-------
mesh : o3d.geometry.TriangleMesh
"""

direction_length = np.linalg.norm(direction)

cylinder_radius = 0.06 * direction_length * self.scale
cylinder_height = 0.85 * direction_length * self.scale
cone_radius = 0.15 * direction_length * self.scale
cone_height = 0.15 * direction_length * self.scale

arrow = o3d.geometry.TriangleMesh.create_arrow(
cylinder_radius=cylinder_radius,
cylinder_height=cylinder_height,
cone_radius=cone_radius,
cone_height=cone_height,
resolution=self.resolution
)

arrow.compute_vertex_normals()
matrix = rotation_matrix(np.array([0, 0, 1]), direction)
arrow.rotate(matrix, center=(0, 0, 0))

# Translate the arrow to the starting position and center the origin
arrow.translate(starting_position.astype(float))

return arrow
111 changes: 111 additions & 0 deletions znvis/particle/vector_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
ZnVis: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Module for the particle parent class
"""

import typing
from dataclasses import dataclass

import numpy as np
from rich.progress import track

from znvis.mesh.arrow import Arrow


@dataclass
class VectorField:
"""
A class to represent a vector field.
Attributes
----------
name : str
Name of the vector field
mesh : Mesh
Mesh to use
position : np.ndarray
Position tensor of the shape (n_steps, n_vectors, n_dims)
direction : np.ndarray
Direction tensor of the shape (n_steps, n_vectors, n_dims)
mesh_list : list
A list of mesh objects, one for each time step.
smoothing : bool (default=False)
If true, apply smoothing to each mesh object as it is rendered.
This will slow down the initial construction of the mesh objects
but not the deployment.
"""

name: str
mesh: Arrow = None # Should be an instance of the Arrow class
position: np.ndarray = None
direction: np.ndarray = None
mesh_list: typing.List[Arrow] = None

smoothing: bool = False

def _create_mesh(self, position: np.ndarray, direction: np.ndarray):
"""
Create a mesh object for the vector field.
Parameters
----------
position : np.ndarray
Position of the arrow
direction : np.ndarray
Direction of the arrow
Returns
-------
mesh : o3d.geometry.TriangleMesh
A mesh object
"""

mesh = self.mesh.create_mesh(position, direction)
if self.smoothing:
return mesh.filter_smooth_taubin(100)
else:
return mesh

def construct_mesh_list(self):
"""
Constructor the mesh list for the class.
The mesh list is a list of mesh objects for each
time step in the parsed trajectory.
Returns
-------
Updates the class attributes mesh_list
"""
self.mesh_list = []
try:
n_particles = int(self.position.shape[1])
n_time_steps = int(self.position.shape[0])
except ValueError:
raise ValueError("There is no data for this vector field.")

for i in track(range(n_time_steps), description=f"Building {self.name} Mesh"):
for j in range(n_particles):
if j == 0:
mesh = self._create_mesh(self.position[i][j], self.direction[i][j])
else:
mesh += self._create_mesh(self.position[i][j], self.direction[i][j])
self.mesh_list.append(mesh)
44 changes: 44 additions & 0 deletions znvis/visualizer/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Visualizer:
def __init__(
self,
particles: typing.List[znvis.Particle],
vector_field: typing.List[znvis.VectorField] = None,
output_folder: typing.Union[str, pathlib.Path] = "./",
frame_rate: int = 24,
number_of_steps: int = None,
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
The format of the video to be generated.
"""
self.particles = particles
self.vector_field = vector_field
self.frame_rate = frame_rate
self.bounding_box = bounding_box() if bounding_box else None

Expand Down Expand Up @@ -305,6 +307,11 @@ def _initialize_particles(self):

self._draw_particles(initial=True)

def _initialize_vector_field(self):
for item in self.vector_field:
item.construct_mesh_list()
self._draw_vector_field(initial=True)

def _draw_particles(self, visualizer=None, initial: bool = False):
"""
Draw the particles on the visualizer.
Expand Down Expand Up @@ -344,6 +351,36 @@ def _draw_particles(self, visualizer=None, initial: bool = False):
item.name, item.mesh_list[self.counter], item.mesh.o3d_material
)


def _draw_vector_field(self, visualizer=None, initial: bool = False):
"""
Draw the vector field on the visualizer.
Parameters
----------
initial : bool (default = True)
If true, no particles are removed.
Returns
-------
updates the information in the visualizer.
-----
"""
if visualizer is None:
visualizer = self.vis

if initial:
for i, item in enumerate(self.vector_field):
visualizer.add_geometry(
item.name, item.mesh_list[self.counter], item.mesh.o3d_material
)
else:
for i, item in enumerate(self.vector_field):
visualizer.remove_geometry(item.name)
visualizer.add_geometry(
item.name, item.mesh_list[self.counter], item.mesh.o3d_material
)

def _continuous_trajectory(self, vis):
"""
Button command for running the simulation in the visualizer.
Expand Down Expand Up @@ -509,6 +546,11 @@ def _update_particles(self, visualizer=None, step: int = None):
step = self.counter

self._draw_particles(visualizer=visualizer) # draw the particles.

# draw the vector field if it exists.
if self.vector_field is not None:
self._draw_vector_field(visualizer=visualizer)

visualizer.post_redraw() # re-draw the window.

def run_visualization(self):
Expand All @@ -521,6 +563,8 @@ def run_visualization(self):
"""
self._initialize_app()
self._initialize_particles()
if self.vector_field is not None:
self._initialize_vector_field()

self.vis.reset_camera_to_default()
self.app.run()

0 comments on commit 4a043d5

Please sign in to comment.