Skip to content

Commit

Permalink
simplify draw_elements
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-andreas committed Dec 14, 2020
1 parent f9c8e32 commit 093c72e
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions apace/plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable
from itertools import zip_longest
from itertools import groupby
from math import inf
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import matplotlib as mpl
import matplotlib.gridspec as grid_spec
Expand All @@ -10,7 +10,6 @@
import numpy as np
from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, VPacker
from matplotlib.path import Path
from matplotlib.ticker import AutoMinorLocator, ScalarFormatter
from matplotlib.widgets import Slider

from .classes import (
Expand All @@ -37,9 +36,10 @@ class Color:
CYAN = "#06B6D4"
WHITE = "white"
BLACK = "black"
LIGHT_GRAY = "#E5E7EB"


ELEMENT_COLOR = {
ELEMENT_COLOR: Dict[type, str] = {
Dipole: Color.YELLOW,
Quadrupole: Color.RED,
Sextupole: Color.GREEN,
Expand Down Expand Up @@ -78,17 +78,16 @@ def draw_elements(
plt.hlines(y0, x_min, x_max, color="black", linewidth=1)
ax.set_ylim(y_min, y_max)

sequence = lattice.sequence
position = start = end = 0
sign = 1
for element, next_element in zip_longest(sequence, sequence[1:]):
position += element.length
if element is next_element or position <= x_min:
sign = -1
start = end = 0
for element, group in groupby(lattice.sequence):
start = end
end += element.length * sum(1 for _ in group)
if end <= x_min:
continue
elif start >= x_max:
break

start, end = end, position
try:
color = ELEMENT_COLOR[type(element)]
except KeyError:
Expand All @@ -109,11 +108,10 @@ def draw_elements(
)
)
if labels and type(element) in {Dipole, Quadrupole}:
# sign = (isinstance(element, Quadrupole) << 1) - 1
sign = -sign
ax.annotate(
element.name,
xy=((start + end) / 2, y0 - sign * rect_height),
xy=((start + end) / 2, y0 + sign * rect_height),
fontsize=FONT_SIZE,
ha="center",
va="center",
Expand All @@ -139,7 +137,7 @@ def draw_sub_lattices(
# if len(ticks) < 5:
# ax.xaxis.set_minor_locator(AutoMinorLocator())
# ax.xaxis.set_minor_formatter(ScalarFormatter())
ax.grid(axis="x", linestyle="--")
ax.grid(axis="x", color=Color.LIGHT_GRAY, linestyle="--", linewidth=1)

if labels:
y_min, y_max = ax.get_ylim()
Expand Down Expand Up @@ -406,10 +404,10 @@ def floor_plan(
end = np.zeros(2)
x_min = y_min = 0
x_max = y_max = 0
sequence = lattice.sequence
sign = 1
for element, next_element in zip_longest(sequence, sequence[1:]):
length = element.length
for element, group in groupby(lattice.sequence):
start = end.copy()
length = element.length * sum(1 for _ in group)
if isinstance(element, Drift):
color = Color.BLACK
line_width = 1
Expand All @@ -420,7 +418,7 @@ def floor_plan(
# TODO: refactor current angle
angle = 0
if isinstance(element, Dipole):
angle = element.angle
angle = element.k0 * length
radius = length / angle
vec = radius * np.array([np.sin(angle), 1 - np.cos(angle)])
sin = np.sin(current_angle)
Expand Down Expand Up @@ -462,8 +460,6 @@ def floor_plan(
y_max = max(y_max, end[1])

ax.add_patch(line) # TODO: currently splitted elements get drawn twice
if element is next_element:
continue

if labels and isinstance(element, (Dipole, Quadrupole)):
angle_center = (current_angle - angle / 2) + np.pi / 2
Expand All @@ -482,8 +478,6 @@ def floor_plan(
zorder=11,
)

start = end.copy()

margin = 0.01 * max((x_max - x_min), (y_max - y_min))
ax.set_xlim(x_min - margin, x_max + margin)
ax.set_ylim(y_min - margin, y_max + margin)
Expand Down

0 comments on commit 093c72e

Please sign in to comment.