Skip to content

Commit

Permalink
Add matplotlib, safeguard in plotter.py
Browse files Browse the repository at this point in the history
  • Loading branch information
robertodr committed Aug 26, 2023
1 parent fcc81dc commit 7e17dc5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy
matplotlib
pytest
pybind11-global
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ setup_requires =

install_requires =
numpy >= 1.15.0
matplotlib

test_suite =
tests
Expand Down
41 changes: 30 additions & 11 deletions src/vampyr/plotter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import matplotlib.pyplot as plt
import numpy as np

Check warning on line 1 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L1

Added line #L1 was not covered by tests
from mpl_toolkits.mplot3d import Axes3D

try:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
except ImportError:
from warnings import warn

Check warning on line 7 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L3-L7

Added lines #L3 - L7 were not covered by tests

warn(

Check warning on line 9 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L9

Added line #L9 was not covered by tests
"Please install matplotlib to use plotting functionality!",
UserWarning,
stacklevel=2,
)


def plot_surface_xy(x=0.0, y=0.0, z=0.0, length=1.0):
Expand Down Expand Up @@ -46,8 +56,10 @@ def plot_cube(corner, length):


def grid_plotter(tree=None, dpi=150, lw=0.08, color=(1, 0, 0, 0.01)):
assert len(tree.MRA().world().upperBounds()) == 3, "basis plotter only works for 3D FunctionTrees"
fig, ax = plt.subplots(figsize=(6, 6), dpi=150, subplot_kw={'projection': '3d'})
assert (

Check warning on line 59 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L58-L59

Added lines #L58 - L59 were not covered by tests
len(tree.MRA().world().upperBounds()) == 3
), "basis plotter only works for 3D FunctionTrees"
fig, ax = plt.subplots(figsize=(6, 6), dpi=150, subplot_kw={"projection": "3d"})
ax.grid(False)
ax.axis("off")

Check warning on line 64 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L62-L64

Added lines #L62 - L64 were not covered by tests

Expand All @@ -60,21 +72,25 @@ def grid_plotter(tree=None, dpi=150, lw=0.08, color=(1, 0, 0, 0.01)):
for i in range(tree.nEndNodes()):
data = plot_cube(corners[i], lengths[i])
for d in data:
ax.plot_surface(d[0], d[1], d[2], color=color, edgecolor='black', lw=lw)
ax.plot_surface(d[0], d[1], d[2], color=color, edgecolor="black", lw=lw)

Check warning on line 75 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L72-L75

Added lines #L72 - L75 were not covered by tests

return fig, ax

Check warning on line 77 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L77

Added line #L77 was not covered by tests


def representation_vs_basis(tree, type="scaling"):
assert len(tree.MRA().world().upperBounds()) == 1, "basis plotter only works for 1D FunctionTrees"
assert (

Check warning on line 81 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L80-L81

Added lines #L80 - L81 were not covered by tests
len(tree.MRA().world().upperBounds()) == 1
), "basis plotter only works for 1D FunctionTrees"
mra = tree.MRA()
k = mra.basis().scalingOrder()
n = tree.depth() - 1

Check warning on line 86 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L84-L86

Added lines #L84 - L86 were not covered by tests

upper_bound = mra.world().upperBound(0)
lower_bound = mra.world().lowerBound(0)
x = np.arange(lower_bound, upper_bound, 0.001)
y = [tree([_]) for _ in x] # Plot f1 to f4 to see how the function representation improves.
y = [

Check warning on line 91 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L88-L91

Added lines #L88 - L91 were not covered by tests
tree([_]) for _ in x
] # Plot f1 to f4 to see how the function representation improves.

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[1].title.set_text(f"Basis")
Expand All @@ -95,10 +111,13 @@ def representation_vs_basis(tree, type="scaling"):
n = idx.scale()
l = idx.translation()[0]

Check warning on line 112 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L108-L112

Added lines #L108 - L112 were not covered by tests

for i in range(k+1):
y = [basis_polys(i=i, l=l, n=n)([x]) if basis_polys(i=i, l=l, n=n)([x]) != 0.0 else np.nan for x in x]
for i in range(k + 1):
y = [

Check warning on line 115 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L114-L115

Added lines #L114 - L115 were not covered by tests
basis_polys(i=i, l=l, n=n)([x])
if basis_polys(i=i, l=l, n=n)([x]) != 0.0
else np.nan
for x in x
]
ax[1].plot(x, y)

Check warning on line 121 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L121

Added line #L121 was not covered by tests

return fig, ax

Check warning on line 123 in src/vampyr/plotter.py

View check run for this annotation

Codecov / codecov/patch

src/vampyr/plotter.py#L123

Added line #L123 was not covered by tests


0 comments on commit 7e17dc5

Please sign in to comment.