Skip to content

Commit

Permalink
Add matplotlib, safeguard in plotter.py
Browse files Browse the repository at this point in the history
  • Loading branch information
robertodr committed Aug 26, 2023
1 parent c67a6ae commit 3228706
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy
matplotlib
pytest
pybind11-global
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ setup_requires =

install_requires =
numpy >= 1.15.0
matplotlib

test_suite =
tests
Expand Down
41 changes: 30 additions & 11 deletions src/vampyr/plotter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

try:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
except ImportError:
from warnings import warn

warn(
"Please install matplotlib to use plotting functionality!",
UserWarning,
stacklevel=2,
)


def plot_surface_xy(x=0.0, y=0.0, z=0.0, length=1.0):
Expand Down Expand Up @@ -46,8 +56,10 @@ def plot_cube(corner, length):


def grid_plotter(tree=None, dpi=150, lw=0.08, color=(1, 0, 0, 0.01)):
assert len(tree.MRA().world().upperBounds()) == 3, "basis plotter only works for 3D FunctionTrees"
fig, ax = plt.subplots(figsize=(6, 6), dpi=150, subplot_kw={'projection': '3d'})
assert (
len(tree.MRA().world().upperBounds()) == 3
), "basis plotter only works for 3D FunctionTrees"
fig, ax = plt.subplots(figsize=(6, 6), dpi=150, subplot_kw={"projection": "3d"})
ax.grid(False)
ax.axis("off")

Expand All @@ -60,21 +72,25 @@ def grid_plotter(tree=None, dpi=150, lw=0.08, color=(1, 0, 0, 0.01)):
for i in range(tree.nEndNodes()):
data = plot_cube(corners[i], lengths[i])
for d in data:
ax.plot_surface(d[0], d[1], d[2], color=color, edgecolor='black', lw=lw)
ax.plot_surface(d[0], d[1], d[2], color=color, edgecolor="black", lw=lw)

return fig, ax


def representation_vs_basis(tree, type="scaling"):
assert len(tree.MRA().world().upperBounds()) == 1, "basis plotter only works for 1D FunctionTrees"
assert (
len(tree.MRA().world().upperBounds()) == 1
), "basis plotter only works for 1D FunctionTrees"
mra = tree.MRA()
k = mra.basis().scalingOrder()
n = tree.depth() - 1

upper_bound = mra.world().upperBound(0)
lower_bound = mra.world().lowerBound(0)
x = np.arange(lower_bound, upper_bound, 0.001)
y = [tree([_]) for _ in x] # Plot f1 to f4 to see how the function representation improves.
y = [
tree([_]) for _ in x
] # Plot f1 to f4 to see how the function representation improves.

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[1].title.set_text(f"Basis")
Expand All @@ -95,10 +111,13 @@ def representation_vs_basis(tree, type="scaling"):
n = idx.scale()
l = idx.translation()[0]

for i in range(k+1):
y = [basis_polys(i=i, l=l, n=n)([x]) if basis_polys(i=i, l=l, n=n)([x]) != 0.0 else np.nan for x in x]
for i in range(k + 1):
y = [
basis_polys(i=i, l=l, n=n)([x])
if basis_polys(i=i, l=l, n=n)([x]) != 0.0
else np.nan
for x in x
]
ax[1].plot(x, y)

return fig, ax


0 comments on commit 3228706

Please sign in to comment.