Skip to content

Commit

Permalink
BUG: make annotate_sphere and annotate_arrow safe when run after plot…
Browse files Browse the repository at this point in the history
… invalidation (#4699)

Co-authored-by: Clément Robert <cr52@protonmail.com>
  • Loading branch information
chrishavlin and neutrinoceros authored Oct 11, 2023
1 parent 762f902 commit 21da682
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
36 changes: 20 additions & 16 deletions yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,11 +1806,12 @@ def __call__(self, plot):
xx0, xx1, yy0, yy1 = self._plot_bounds(plot)
# normalize all of the kwarg lengths to the plot size
plot_diag = ((yy1 - yy0) ** 2 + (xx1 - xx0) ** 2) ** (0.5)
self.length *= plot_diag
self.width *= plot_diag
self.head_width *= plot_diag
length = self.length * plot_diag
width = self.width * plot_diag
head_width = self.head_width * plot_diag
head_length = None
if self.head_length is not None:
self.head_length *= plot_diag
head_length = self.head_length * plot_diag

if self.starting_pos is not None:
start_x, start_y = self._sanitize_coord_system(
Expand All @@ -1819,8 +1820,8 @@ def __call__(self, plot):
dx = x - start_x
dy = y - start_y
else:
dx = (xx1 - xx0) * 2 ** (0.5) * self.length
dy = (yy1 - yy0) * 2 ** (0.5) * self.length
dx = (xx1 - xx0) * 2 ** (0.5) * length
dy = (yy1 - yy0) * 2 ** (0.5) * length
# If the arrow is 0 length
if dx == dy == 0:
warnings.warn("The arrow has zero length. Not annotating.", stacklevel=2)
Expand All @@ -1833,9 +1834,9 @@ def __call__(self, plot):
y - dy,
dx,
dy,
width=self.width,
head_width=self.head_width,
head_length=self.head_length,
width=width,
head_width=head_width,
head_length=head_length,
transform=self.transform,
length_includes_head=True,
**self.plot_args,
Expand All @@ -1847,9 +1848,9 @@ def __call__(self, plot):
y[i] - dy,
dx,
dy,
width=self.width,
head_width=self.head_width,
head_length=self.head_length,
width=width,
head_width=head_width,
head_length=head_length,
transform=self.transform,
length_includes_head=True,
**self.plot_args,
Expand Down Expand Up @@ -2032,7 +2033,8 @@ def __call__(self, plot):

if is_sequence(self.radius):
self.radius = plot.data.ds.quan(self.radius[0], self.radius[1])
self.radius = np.float64(self.radius.in_units(plot.xlim[0].units))
self.radius = self.radius.in_units(plot.xlim[0].units)

if isinstance(self.radius, YTQuantity):
if isinstance(self.center, YTArray):
units = self.center.units
Expand All @@ -2045,16 +2047,18 @@ def __call__(self, plot):
# apply a different transform for a length in the same way
# you can for a coordinate.
if self.coord_system == "data" or self.coord_system == "plot":
self.radius = self.radius * self._pixel_scale(plot)[0]
scaled_radius = self.radius * self._pixel_scale(plot)[0]
else:
self.radius /= (plot.xlim[1] - plot.xlim[0]).v
scaled_radius = self.radius / (plot.xlim[1] - plot.xlim[0])

x, y = self._sanitize_coord_system(
plot, self.center, coord_system=self.coord_system
)

x, y = self._sanitize_xy_order(plot, x, y)
cir = Circle((x, y), self.radius, transform=self.transform, **self.circle_args)
cir = Circle(
(x, y), scaled_radius.v, transform=self.transform, **self.circle_args
)

plot._axes.add_patch(cir)
if self.text is not None:
Expand Down
17 changes: 17 additions & 0 deletions yt/visualization/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,23 @@ def test_sphere_callback():
assert_fname(p.save(prefix)[0])


def test_invalidated_annotations():
# check that annotate_sphere and annotate_arrow succeed on re-running after
# an operation that invalidates the plot (set_font_size), see
# https://github.com/yt-project/yt/issues/4698

ds = fake_amr_ds(fields=("density",), units=("g/cm**3",))
p = SlicePlot(ds, "z", ("gas", "density"))
p.annotate_sphere([0.5, 0.5, 0.5], 0.1)
p.set_font_size(24)
p.render()

p = SlicePlot(ds, "z", ("gas", "density"))
p.annotate_arrow([0.5, 0.5, 0.5])
p.set_font_size(24)
p.render()


def test_text_callback():
with _cleanup_fname() as prefix:
ax = "z"
Expand Down

0 comments on commit 21da682

Please sign in to comment.