Skip to content

Commit

Permalink
Patch CI test for runtime gain (#49)
Browse files Browse the repository at this point in the history
* enhanced debugging

* mask runtime gains not reliably large in CS testing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jpn-- and pre-commit-ci[bot] authored Mar 26, 2024
1 parent 8cbab1a commit 3c3deda
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/walkthrough/one-dim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1561,7 +1561,7 @@
" ),\n",
" number=1,\n",
" )\n",
"assert masked_time * 2 < raw_time # generous buffer, should be nearly 7 times faster\n",
"assert masked_time < raw_time # generous, should be nearly 7 times faster\n",
"assert len(wide_flow.cache_misses[\"_imnl_plus1d\"]) == 3"
]
}
Expand Down
53 changes: 35 additions & 18 deletions sharrow/aster.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def __init__(
bool_wrapping=False,
swallow_errors=False,
get_default=False,
original_expr="???",
):
self.spacename = spacename
self.dim_slots = dim_slots
Expand All @@ -355,6 +356,7 @@ def __init__(
self.bool_wrapping = bool_wrapping
self.swallow_errors = swallow_errors
self.get_default = get_default
self.original_expr = original_expr

def log_event(self, tag, node1=None, node2=None):
if logger.getEffectiveLevel() <= 0:
Expand Down Expand Up @@ -1003,7 +1005,9 @@ def visit_Compare(self, node):
right_decoded = None
warnings.warn(
f"right hand value {right.value!r} not found in "
f"categories for {left_varname} in {self.spacename}",
f"categories for {left_varname} in {self.spacename}"
f"\nexpression: {self.original_expr}"
f"\ncategories: {left_dictionary}",
stacklevel=2,
)
if right_decoded is not None:
Expand Down Expand Up @@ -1034,7 +1038,9 @@ def visit_Compare(self, node):
left_decoded = None
warnings.warn(
f"left hand value {left.value!r} not found in "
f"categories for {right_varname} in {self.spacename}",
f"categories for {right_varname} in {self.spacename}"
f"\nexpression: {self.original_expr}"
f"\ncategories: {right_dictionary}",
stacklevel=2,
)
if left_decoded is not None:
Expand Down Expand Up @@ -1069,6 +1075,7 @@ def expression_for_numba(
bool_wrapping=False,
swallow_errors=False,
get_default=False,
original_expr=None,
):
"""
Rewrite an expression so numba can compile it.
Expand All @@ -1087,27 +1094,37 @@ def expression_for_numba(
prefer_name : str, optional
extra_vars : Mapping, optional
blenders : Mapping, optional
bool_wrapping : bool, optional
swallow_errors : bool, optional
get_default : bool, optional
original_expr : str, optional
Original (pre-processing) expression, used for debugging
Returns
-------
str
"""
return unparse_(
RewriteForNumba(
spacename,
dim_slots,
spacevars,
rawname,
rawalias,
digital_encodings,
prefer_name,
extra_vars,
blenders,
bool_wrapping,
swallow_errors,
get_default,
).visit(ast.parse(expr))
)
with warnings.catch_warnings(record=True) as warning_list:
result = unparse_(
RewriteForNumba(
spacename,
dim_slots,
spacevars,
rawname,
rawalias,
digital_encodings,
prefer_name,
extra_vars,
blenders,
bool_wrapping,
swallow_errors,
get_default,
original_expr=original_expr or expr,
).visit(ast.parse(expr))
)
for warning in warning_list:
warnings.warn(warning.message, warning.category, stacklevel=2)
return result


class Asterize:
Expand Down
9 changes: 9 additions & 0 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,7 @@ def init_sub_funcs(
extra_vars=self.tree.extra_vars,
blenders=blenders,
bool_wrapping=self.bool_wrapping,
original_expr=init_expr,
)
except KeyError as key_err:
# there was an error, but lets make sure we process the
Expand All @@ -1306,6 +1307,7 @@ def init_sub_funcs(
blenders=blenders,
bool_wrapping=self.bool_wrapping,
swallow_errors=True,
original_expr=init_expr,
)
# Now for the fallback processing...
if ".." in key_err.args[0]:
Expand All @@ -1330,6 +1332,7 @@ def init_sub_funcs(
extra_vars=self.tree.extra_vars,
blenders=blenders,
bool_wrapping=self.bool_wrapping,
original_expr=init_expr,
)
except KeyError as err: # noqa: F841
pass
Expand All @@ -1350,6 +1353,7 @@ def init_sub_funcs(
blenders=blenders,
bool_wrapping=self.bool_wrapping,
get_default=True,
original_expr=init_expr,
)
except KeyError as err: # noqa: F841
pass
Expand All @@ -1376,6 +1380,7 @@ def init_sub_funcs(
blenders=blenders,
bool_wrapping=self.bool_wrapping,
get_default=True,
original_expr=init_expr,
)
except KeyError as err: # noqa: F841
pass
Expand Down Expand Up @@ -1417,6 +1422,7 @@ def init_sub_funcs(
blenders=blenders,
bool_wrapping=self.bool_wrapping,
get_default=gd,
original_expr=init_expr,
)
except KeyError:
# there was an error, but lets make sure we process the
Expand All @@ -1434,6 +1440,7 @@ def init_sub_funcs(
bool_wrapping=self.bool_wrapping,
swallow_errors=True,
get_default=gd,
original_expr=init_expr,
)

# now find instances where an identifier is previously created in this flow.
Expand All @@ -1445,6 +1452,7 @@ def init_sub_funcs(
"_outputs",
extra_vars=self.tree.extra_vars,
bool_wrapping=self.bool_wrapping,
original_expr=init_expr,
)

aux_tokens = {
Expand All @@ -1461,6 +1469,7 @@ def init_sub_funcs(
prefer_name="aux_var",
extra_vars=self.tree.extra_vars,
bool_wrapping=self.bool_wrapping,
original_expr=init_expr,
)

if (k == init_expr) and (init_expr == expr) and k.isidentifier():
Expand Down

0 comments on commit 3c3deda

Please sign in to comment.