-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from phohenberger/vector_field
Add vector field support
- Loading branch information
Showing
4 changed files
with
248 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
""" | ||
ZnVis: A Zincwarecode package. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: zincwarecode@gmail.com | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
Summary | ||
------- | ||
Create a sphere mesh. | ||
""" | ||
|
||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
import open3d as o3d | ||
|
||
from znvis.transformations.rotation_matrices import rotation_matrix | ||
|
||
from znvis.mesh import Mesh | ||
|
||
|
||
@dataclass | ||
class Arrow(Mesh): | ||
""" | ||
A class to produce arrow meshes. | ||
Attributes | ||
---------- | ||
scale : float | ||
Scale of the arrow | ||
resolution : int | ||
Resolution of the mesh. | ||
""" | ||
scale: float = 1.0 | ||
resolution: int = 10 | ||
|
||
def create_mesh( | ||
self, starting_position: np.ndarray, direction: np.ndarray = None | ||
) -> o3d.geometry.TriangleMesh: | ||
""" | ||
Create a mesh object defined by the dataclass. | ||
Parameters | ||
---------- | ||
starting_position : np.ndarray shape=(3,) | ||
Starting position of the mesh. | ||
direction : np.ndarray shape=(3,) (default = None) | ||
Direction of the mesh. | ||
Returns | ||
------- | ||
mesh : o3d.geometry.TriangleMesh | ||
""" | ||
|
||
direction_length = np.linalg.norm(direction) | ||
|
||
cylinder_radius = 0.06 * direction_length * self.scale | ||
cylinder_height = 0.85 * direction_length * self.scale | ||
cone_radius = 0.15 * direction_length * self.scale | ||
cone_height = 0.15 * direction_length * self.scale | ||
|
||
arrow = o3d.geometry.TriangleMesh.create_arrow( | ||
cylinder_radius=cylinder_radius, | ||
cylinder_height=cylinder_height, | ||
cone_radius=cone_radius, | ||
cone_height=cone_height, | ||
resolution=self.resolution | ||
) | ||
|
||
arrow.compute_vertex_normals() | ||
matrix = rotation_matrix(np.array([0, 0, 1]), direction) | ||
arrow.rotate(matrix, center=(0, 0, 0)) | ||
|
||
# Translate the arrow to the starting position and center the origin | ||
arrow.translate(starting_position.astype(float)) | ||
|
||
return arrow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
ZnVis: A Zincwarecode package. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: zincwarecode@gmail.com | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
Summary | ||
------- | ||
Module for the particle parent class | ||
""" | ||
|
||
import typing | ||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
from rich.progress import track | ||
|
||
from znvis.mesh.arrow import Arrow | ||
|
||
|
||
@dataclass | ||
class VectorField: | ||
""" | ||
A class to represent a vector field. | ||
Attributes | ||
---------- | ||
name : str | ||
Name of the vector field | ||
mesh : Mesh | ||
Mesh to use | ||
position : np.ndarray | ||
Position tensor of the shape (n_steps, n_vectors, n_dims) | ||
direction : np.ndarray | ||
Direction tensor of the shape (n_steps, n_vectors, n_dims) | ||
mesh_list : list | ||
A list of mesh objects, one for each time step. | ||
smoothing : bool (default=False) | ||
If true, apply smoothing to each mesh object as it is rendered. | ||
This will slow down the initial construction of the mesh objects | ||
but not the deployment. | ||
""" | ||
|
||
name: str | ||
mesh: Arrow = None # Should be an instance of the Arrow class | ||
position: np.ndarray = None | ||
direction: np.ndarray = None | ||
mesh_list: typing.List[Arrow] = None | ||
|
||
smoothing: bool = False | ||
|
||
def _create_mesh(self, position: np.ndarray, direction: np.ndarray): | ||
""" | ||
Create a mesh object for the vector field. | ||
Parameters | ||
---------- | ||
position : np.ndarray | ||
Position of the arrow | ||
direction : np.ndarray | ||
Direction of the arrow | ||
Returns | ||
------- | ||
mesh : o3d.geometry.TriangleMesh | ||
A mesh object | ||
""" | ||
|
||
mesh = self.mesh.create_mesh(position, direction) | ||
if self.smoothing: | ||
return mesh.filter_smooth_taubin(100) | ||
else: | ||
return mesh | ||
|
||
def construct_mesh_list(self): | ||
""" | ||
Constructor the mesh list for the class. | ||
The mesh list is a list of mesh objects for each | ||
time step in the parsed trajectory. | ||
Returns | ||
------- | ||
Updates the class attributes mesh_list | ||
""" | ||
self.mesh_list = [] | ||
try: | ||
n_particles = int(self.position.shape[1]) | ||
n_time_steps = int(self.position.shape[0]) | ||
except ValueError: | ||
raise ValueError("There is no data for this vector field.") | ||
|
||
for i in track(range(n_time_steps), description=f"Building {self.name} Mesh"): | ||
for j in range(n_particles): | ||
if j == 0: | ||
mesh = self._create_mesh(self.position[i][j], self.direction[i][j]) | ||
else: | ||
mesh += self._create_mesh(self.position[i][j], self.direction[i][j]) | ||
self.mesh_list.append(mesh) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters