Skip to content

Commit

Permalink
fix: Use zip(..., strict=True) instead of manually checking array shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
adigitoleo committed Oct 18, 2023
1 parent 48ecf30 commit f095c6e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
8 changes: 5 additions & 3 deletions src/pydrex/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ def read_scsv(file):
x,
)
)
for f, fill, x in zip(coltypes, fillvals, zip(*list(reader)))
for f, fill, x in zip(
coltypes, fillvals, zip(*list(reader), strict=True), strict=True
)
]
)

Expand Down Expand Up @@ -204,9 +206,9 @@ def save_scsv(file, schema, data, **kwargs):
stream, delimiter=schema["delimiter"], lineterminator=os.linesep
)
writer.writerow(names)
for col in zip(*data):
for col in zip(*data, strict=True):
row = []
for i, (d, t, f) in enumerate(zip(col, types, fills)):
for i, (d, t, f) in enumerate(zip(col, types, fills, strict=True)):
try:
_parse_scsv_cell(
t, str(d), missingstr=schema["missing"], fillval=f
Expand Down
2 changes: 1 addition & 1 deletion src/pydrex/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def resample_orientations(orientations, fractions, n_samples=None, seed=None):
n_samples = _fractions.shape[1]
out_orientations = np.empty((len(_fractions), n_samples, 3, 3))
out_fractions = np.empty((len(_fractions), n_samples))
for i, (frac, orient) in enumerate(zip(_fractions, _orientations)):
for i, (frac, orient) in enumerate(zip(_fractions, _orientations, strict=True)):
sort_ascending = np.argsort(frac)
# Cumulative volume fractions.
frac_ascending = frac[sort_ascending]
Expand Down
31 changes: 14 additions & 17 deletions src/pydrex/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ def polefigures(
density_kwargs=kwargs,
)
if density:
for ax, pf in zip((ax100, ax010, ax001), (pf100, pf010, pf001)):
for ax, pf in zip(
(ax100, ax010, ax001), (pf100, pf010, pf001), strict=True
):
cbar = fig.colorbar(
pf,
ax=ax,
Expand Down Expand Up @@ -205,7 +207,7 @@ def pathline_box2d(

U = np.zeros_like(X_grid.ravel())
V = np.zeros_like(Y_grid.ravel())
for i, (x, y) in enumerate(zip(X_grid.ravel(), Y_grid.ravel())):
for i, (x, y) in enumerate(zip(X_grid.ravel(), Y_grid.ravel(), strict=True)):
p = np.zeros(3)
p[horizontal] = x
p[vertical] = y
Expand All @@ -230,7 +232,7 @@ def pathline_box2d(
C = np.asarray(
[
f * np.asarray([c[horizontal], c[vertical]])
for f, c in zip(cpo_strengths, cpo_vectors)
for f, c in zip(cpo_strengths, cpo_vectors, strict=True)
]
)
cpo = ax.quiver(
Expand Down Expand Up @@ -294,12 +296,11 @@ def alignment(
"""
_strains = np.atleast_2d(strains)
_angles = np.atleast_2d(angles)
if len(_strains) != len(_angles) != len(markers) != len(labels):
raise ValueError("mismatch in input dimensions")
if err is not None:
_angles_err = np.atleast_2d(err)
if not np.all(_angles.shape == _angles_err.shape):
raise ValueError("mismatch in shapes of `angles` and `err`")
if not np.all(_strains.shape == _angles.shape):
# Assume strains are all the same for each series in `angles`, try np.tile().
_strains = np.tile(_strains, (len(_angles), 1))

fig, ax = figure_unless(ax)
ax.set_ylabel("Mean angle ∈ [0, 90]°")
Expand All @@ -308,7 +309,7 @@ def alignment(
ax.set_xlim((np.min(strains), np.max(strains)))
_colors = []
for i, (strains, θ_cpo, marker, label) in enumerate(
zip(_strains, _angles, markers, labels)
zip(_strains, _angles, markers, labels, strict=True)
):
if colors is not None:
ax.scatter(
Expand Down Expand Up @@ -379,12 +380,11 @@ def strengths(
"""
_strains = np.atleast_2d(strains)
_strengths = np.atleast_2d(strengths)
if len(_strains) != len(_strengths) != len(markers) != len(labels):
raise ValueError("mismatch in input dimensions")
if err is not None:
_strengths_err = np.atleast_2d(err)
if not np.all(_strengths.shape == _strengths_err.shape):
raise ValueError("mismatch in shapes of `strengths` and `err`")
if not np.all(_strains.shape == _strengths_err.shape):
# Assume strains are all the same for each series in `strengths`, try np.tile().
_strains = np.tile(_strains, (len(_strengths), 1))

fig, ax = figure_unless(ax)
ax.set_ylabel(ylabel)
Expand All @@ -396,7 +396,7 @@ def strengths(

_colors = []
for i, (strains, strength, marker, label) in enumerate(
zip(_strains, _strengths, markers, labels)
zip(_strains, _strengths, markers, labels, strict=True)
):
if colors is not None:
ax.scatter(
Expand Down Expand Up @@ -465,15 +465,12 @@ def show_Skemer2016_ShearStrainAngles(ax, studies, markers, colors, fillstyles,
the data series plots.
"""
if len(studies) != len(markers) != len(colors) != len(fillstyles) != len(labels):
raise ValueError("mismatch in lengths of inputs")
fig, ax = figure_unless(ax)

data_Skemer2016 = _io.read_scsv(
_io.data("thirdparty") / "Skemer2016_ShearStrainAngles.scsv"
)
for study, marker, color, fillstyle, label in zip(
studies, markers, colors, fillstyles, labels
studies, markers, colors, fillstyles, labels, strict=True
):
# Note: np.nonzero returns a tuple.
indices = np.nonzero(np.asarray(data_Skemer2016.study) == study)[0]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_vortex_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_xz(self, outdir, seed, n_grains):
_diagnostics.smallest_angle(
_diagnostics.bingham_average(a, axis="a"), get_velocity(x)
)
for a, x in zip(mineral.orientations, positions)
for a, x in zip(mineral.orientations, positions, strict=True)
]
if outdir is not None:
# First figure with the domain and pathline.
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_xz_ensemble(self, outdir, seeds_nearX45, ncpus, n_grains):
_diagnostics.smallest_angle(
_diagnostics.bingham_average(a, axis="a"), get_velocity(x)
)
for a, x in zip(mineral.orientations, positions)
for a, x in zip(mineral.orientations, positions, strict=True)
]
max_sizes[s] = np.log10(np.max(mineral.fractions, axis=1) * n_grains)

Expand Down

0 comments on commit f095c6e

Please sign in to comment.