Skip to content

Commit

Permalink
fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Aug 29, 2024
1 parent 3db5fdb commit 710fd0c
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 24 deletions.
2 changes: 1 addition & 1 deletion _unittests/ut_tools/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def test_bdn_in_bdi(self):
inp1 = numpy.arange(2 * 3 * 5).reshape((2, 3, 5)).astype(numpy.float32)
inp2 = numpy.arange(5 * 7).reshape((7, 5)).astype(numpy.float32)
exp = numpy.einsum(equation, inp1, inp2)
got = apply_einsum_sequence(seq, inp1, inp2)
got = apply_einsum_sequence(seq, inp1, inp2, verbose=0)
self.assertEqualArray(exp, got)

onx = seq.to_onnx("Y", "X1", "X2")
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_tools/test_einsum_bug.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def common_test_equation(self, equation, dim1, dim2):
f.write(onx.SerializeToString())
a = numpy.random.rand(*list((2,) * dim1))
b = numpy.random.rand(*list((2,) * dim2))
oinf = CReferenceEvaluator(onx)
oinf = CReferenceEvaluator(onx, verbose=0)
got = oinf.run(None, {"X1": a, "X2": b})
expected = numpy.einsum(equation, a, b)
self.assertEqualArray(expected, got[0], atol=1e-15)
Expand Down
14 changes: 8 additions & 6 deletions onnx_extended/tools/einsum/einsum_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,10 @@ def _decompose_einsum_equation_simple(
diag = None
tr_row = mat[i]

for op_ in _apply_transpose_reshape(op, tr_row):
op_.compute_output_row(rows[1, :], verbose=verbose)
marked = graph.append(op_)
for iop in _apply_transpose_reshape(op, tr_row):
op = iop
iop.compute_output_row(rows[1, :], verbose=verbose)
marked = graph.append(iop)

# Reduction? (a dimension not used later)
red = []
Expand Down Expand Up @@ -560,7 +561,8 @@ def _decompose_einsum_equation_simple(
op.compute_output_row(rows[1, :], verbose=verbose)

# Removes empty axes.
for op_ in _apply_squeeze_transpose(op, rows[1, :], mat[len(shapes), :]):
op_.compute_output_row(rows[1, :], verbose=verbose)
graph.append(op_)
for iop in _apply_squeeze_transpose(op, rows[1, :], mat[len(shapes), :]):
op = iop
iop.compute_output_row(rows[1, :], verbose=verbose)
graph.append(iop)
return graph
23 changes: 8 additions & 15 deletions onnx_extended/tools/einsum/einsum_impl_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,17 +466,12 @@ def _check_shape_(self, m: numpy.ndarray):

def _get_data(self, data: Dict[int, Any], key: Union[int, "EinsumSubOp"]) -> Any:
if isinstance(key, int):
assert key in data, "Unable to find key %d in %r." % (
key,
list(sorted(data)),
)
assert key in data, f"Unable to find key {key!r} in {list(sorted(data))}"
return data[key]
if isinstance(key, EinsumSubOp):
assert id(key) in data, "Unable to find key %d in %r." % (
id(key),
list(sorted(data)),
)
return data[id(key)]
skey = id(key)
assert skey in data, f"Unable to find key {skey!r} in {list(sorted(data))}"
return data[skey]
raise TypeError(f"Unexpected input type {type(key)!r}.")

def _apply_id(
Expand Down Expand Up @@ -759,7 +754,7 @@ def apply(
return output

def _onnx_name(self) -> str:
return "einsum%d_%s" % (id(self), self.name[:2])
return f"einsum{id(self)}_{self.name}"

def _check_onnx_opset_(self, opset: Optional[int], limit: int):
if opset is not None and opset < limit:
Expand Down Expand Up @@ -1264,13 +1259,13 @@ def append(self, op: Union[int, EinsumSubOp]) -> Optional[EinsumSubOp]:
:return: op or None if op is an integer
"""
if isinstance(op, int):
assert op not in self._nodes, "Key %d already added." % op
assert op not in self._nodes, f"Key {op!r} already added."
self._nodes[op] = op
self.last_added_op = op
self._inputs[op] = op
return None
if isinstance(op, EinsumSubOp):
assert op not in self._nodes, "Key %d already added, op=%r." % (id(op), op)
assert op not in self._nodes, f"Key {id(op)} already added, op={op!r}."
self._nodes[id(op)] = op
self._ops.append(op)
self.last_added_op = op
Expand Down Expand Up @@ -1521,9 +1516,7 @@ def _pprint_forward(self) -> str:
rows.append(line)
return "\n".join(rows)

def _replace_node_sequence(
self, added: List[EinsumSubOp], deleted: List[EinsumSubOp]
):
def _replace_node_sequence(self, added: EinsumSubOp, deleted: List[EinsumSubOp]):
"""
Removes a sequence of nodes. The method does not check
that the graph remains consistent.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ ignore_missing_imports = true
packages = ["onnx_extended"]
exclude = [
"^_doc/examples", # skips examples in the documentation
"^_doc/auto_examples", # skips examples in the documentation
"^_unittests", # skips unit tests
"^build", # skips build
"^dist", # skips dist
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_cmake_args(self, cfg: str) -> List[str]:
"""
iswin = is_windows()
isdar = is_darwin()
cmake_cmd_args = []
cmake_cmd_args: List[str] = []

path = sys.executable
vers = (
Expand Down

0 comments on commit 710fd0c

Please sign in to comment.