Skip to content

Commit

Permalink
Upstream updates
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Sep 14, 2023
1 parent 7621470 commit 959b382
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.1", features = ["extension-module"] }
# egglog = { git = "https://github.com/egraphs-good/egglog", rev = "c83fc750878755eb610a314da90f9273b3bfe25d" }
egglog = { git = "https://github.com/oflatt/egg-smol", rev = "f6df3ff831b65405665e1751b0ef71c61b025432" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4d67f262a6f27aa5cfb62a2cfc7df968959105df" }
# egglog = { git = "https://github.com/oflatt/egg-smol", rev = "f6df3ff831b65405665e1751b0ef71c61b025432" }
# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "c01695618ed4de2fbfa8116476e208bc1ca86612" }

pyo3-log = "0.8.1"
Expand Down
6 changes: 6 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea

## Unreleased

- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/c83fc750878755eb610a314da90f9273b3bfe25d...4d67f262a6f27aa5cfb62a2cfc7df968959105df)

### Breaking Changes

- Switch `RunReport` to include more granular timings

### New Features

- Add ability to pass `seminaive` flag to Egraph to replicate `--naive` CLI flag [#48](https://github.com/metadsl/egglog-python/pull/48)
- Add ability to inline leaves $n$ times instead of just once for visualization [#48](https://github.com/metadsl/egglog-python/pull/48)
- Add `Relation` and `PrintOverallStatistics` low level commands [#46](https://github.com/metadsl/egglog-python/pull/46)
- Add `count-matches` and `replace` string commands [#46](https://github.com/metadsl/egglog-python/pull/46)

### Bug fixes

Expand Down
21 changes: 15 additions & 6 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,20 @@ class IdentSort:
@final
class RunReport:
    """Report of timings collected from running a schedule of rules.

    Replaces the earlier aggregate ``search_time``/``apply_time``/``rebuild_time``
    fields with per-rule and per-ruleset breakdowns (see the accompanying
    Rust-side ``egglog::RunReport`` conversion).
    """

    # True if any rule application changed the e-graph during the run.
    updated: bool
    # Time spent searching for matches, keyed by individual rule name.
    search_time_per_rule: dict[str, timedelta]
    # Time spent applying matches, keyed by individual rule name.
    apply_time_per_rule: dict[str, timedelta]
    # Time spent searching, aggregated per ruleset name.
    search_time_per_ruleset: dict[str, timedelta]
    # Time spent applying, aggregated per ruleset name.
    apply_time_per_ruleset: dict[str, timedelta]
    # Time spent rebuilding (congruence closure), aggregated per ruleset name.
    rebuild_time_per_ruleset: dict[str, timedelta]

    def __init__(
        self,
        updated: bool,
        search_time_per_rule: dict[str, timedelta],
        apply_time_per_rule: dict[str, timedelta],
        search_time_per_ruleset: dict[str, timedelta],
        apply_time_per_ruleset: dict[str, timedelta],
        rebuild_time_per_ruleset: dict[str, timedelta],
    ) -> None: ...

@final
Expand Down Expand Up @@ -432,6 +436,10 @@ class Relation:

def __init__(self, constructor: str, inputs: list[str]) -> None: ...

@final
class PrintOverallStatistics:
    """Low-level command mirroring ``egglog::ast::Command::PrintOverallStatistics``.

    Takes no arguments; presumably prints cumulative run statistics when
    executed — TODO confirm against the upstream egglog behavior.
    """

    def __init__(self) -> None: ...

_Command = (
SetOption
| Datatype
Expand All @@ -458,6 +466,7 @@ _Command = (
| Include
| CheckProof
| Relation
| PrintOverallStatistics
)

def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ...
9 changes: 9 additions & 0 deletions python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class String(Expr):
def __init__(self, value: str):
...

    @BUILTINS.method(egg_fn="replace")
    def replace(self, old: StringLike, new: StringLike) -> String:  # type: ignore[empty-body]
        """Map to the egglog ``replace`` string primitive.

        Body is intentionally empty: the decorator wires the call through to
        the egg function named by ``egg_fn``.
        """
        ...


@BUILTINS.function(egg_fn="+")
def join(*strings: StringLike) -> String: # type: ignore[empty-body]
Expand Down Expand Up @@ -150,6 +154,11 @@ def to_string(self) -> String: # type: ignore[empty-body]
converter(int, i64, i64)


@BUILTINS.function(egg_fn="count-matches")
def count_matches(s: StringLike, pattern: StringLike) -> i64:  # type: ignore[empty-body]
    """Map to the egglog ``count-matches`` string primitive.

    Body is intentionally empty: the decorator wires the call through to
    the egg function named by ``egg_fn``.
    """
    ...


f64Like = Union[float, "f64"]


Expand Down
8 changes: 4 additions & 4 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

def extract_py(e: Expr) -> Any:
egraph = EGraph.current()
egraph.push()
# print(e)
egraph.register(e)
egraph.run((run() * 30).saturate())
Expand Down Expand Up @@ -60,8 +61,10 @@ def extract_py(e: Expr) -> Any:
except EggSmolError:
other_versions = egraph.extract_multiple(final_object, 10)
other_verions_str = "\n\n".join(map(str, other_versions))
egraph.graphviz().render(view=True)
raise Exception(f"Failed to extract:\n{other_verions_str}")
# print(res)
egraph.pop()
return res


Expand Down Expand Up @@ -1235,10 +1238,7 @@ def _linalg(x: NDArray, full_matrices: Bool):
# to analyze `any(((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0))) < NDArray.scalar(Value.int(Int(0)))).bool()``
##


@array_api_module.function
def greater_zero(value: Value) -> Unit:
...
greater_zero = array_api_module.relation("greater_zero", Value)


# @array_api_module.function
Expand Down
89 changes: 17 additions & 72 deletions python/tests/__snapshots__/test_array_api/test_sklearn_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,74 +9,31 @@
assume_value_one_of(_NDArray_2, _TupleValue_1)
_NDArray_3 = reshape(_NDArray_2, TupleInt(Int(-1)))
_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0)))
_NDArray_5 = zeros(
TupleInt(Int(3)) + TupleInt(Int(4)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)
)
_NDArray_5 = zeros(TupleInt(Int(3)) + TupleInt(Int(4)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device))
_MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice()))
_IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1)
_NDArray_5[_IndexKey_1] = mean(
_NDArray_1[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))],
OptionalIntOrTuple.int(Int(0)),
)
_NDArray_5[_IndexKey_1] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], OptionalIntOrTuple.int(Int(0)))
_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(1))) + _MultiAxisIndexKey_1)
_NDArray_5[_IndexKey_2] = mean(
_NDArray_1[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))],
OptionalIntOrTuple.int(Int(0)),
)
_NDArray_5[_IndexKey_2] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], OptionalIntOrTuple.int(Int(0)))
_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(2))) + _MultiAxisIndexKey_1)
_NDArray_5[_IndexKey_3] = mean(
_NDArray_1[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(2))))],
OptionalIntOrTuple.int(Int(0)),
)
_NDArray_5[_IndexKey_3] = mean(_NDArray_1[ndarray_index(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(2))))], OptionalIntOrTuple.int(Int(0)))
_NDArray_6 = concat(
TupleNDArray(
_NDArray_1[ndarray_index(_NDArray_3 == NDArray.vector(_TupleValue_1)[IndexKey.int(Int(0))])]
- _NDArray_5[_IndexKey_1]
)
TupleNDArray(_NDArray_1[ndarray_index(_NDArray_3 == NDArray.vector(_TupleValue_1)[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1])
+ (
TupleNDArray(
_NDArray_1[ndarray_index(_NDArray_3 == NDArray.vector(_TupleValue_1)[IndexKey.int(Int(1))])]
- _NDArray_5[_IndexKey_2]
)
+ TupleNDArray(
_NDArray_1[ndarray_index(_NDArray_3 == NDArray.vector(_TupleValue_1)[IndexKey.int(Int(2))])]
- _NDArray_5[_IndexKey_3]
)
TupleNDArray(_NDArray_1[ndarray_index(_NDArray_3 == NDArray.vector(_TupleValue_1)[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2])
+ TupleNDArray(_NDArray_1[ndarray_index(_NDArray_3 == NDArray.vector(_TupleValue_1)[IndexKey.int(Int(2))])] - _NDArray_5[_IndexKey_3])
),
OptionalInt.some(Int(0)),
)
_NDArray_7 = std(_NDArray_6, OptionalIntOrTuple.int(Int(0)))
_NDArray_7[
ndarray_index(std(_NDArray_6, OptionalIntOrTuple.int(Int(0))) == NDArray.scalar(Value.int(Int(0))))
] = NDArray.scalar(Value.float(Float(1.0)))
_TupleNDArray_1 = svd(
sqrt(NDArray.scalar(Value.int(NDArray.scalar(Value.float(Float(1.0))).to_int() / Int(147))))
* (_NDArray_6 / _NDArray_7),
FALSE,
)
_Slice_1 = Slice(
OptionalInt.none,
OptionalInt.some(
astype(sum(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001)))), DType.int32).to_int()
),
)
_NDArray_8 = (
_TupleNDArray_1[Int(2)][
IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)
]
/ _NDArray_7
).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)]
_NDArray_7[ndarray_index(std(_NDArray_6, OptionalIntOrTuple.int(Int(0))) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0)))
_TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.int(NDArray.scalar(Value.float(Float(1.0))).to_int() / Int(147)))) * (_NDArray_6 / _NDArray_7), FALSE)
_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_int()))
_NDArray_8 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_7).T / _TupleNDArray_1[
Int(1)
][IndexKey.slice(_Slice_1)]
_TupleNDArray_2 = svd(
(
sqrt(
NDArray.scalar(
Value.int(
(Int(150) * _NDArray_4.to_int()) * (NDArray.scalar(Value.float(Float(1.0))).to_int() / Int(2))
)
)
)
* (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T
).T
(sqrt(NDArray.scalar(Value.int((Int(150) * _NDArray_4.to_int()) * (NDArray.scalar(Value.float(Float(1.0))).to_int() / Int(2))))) * (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T).T
@ _NDArray_8,
FALSE,
)
Expand All @@ -92,15 +49,8 @@
Slice(
OptionalInt.none,
OptionalInt.some(
astype(
sum(
_TupleNDArray_2[Int(1)]
> (
NDArray.scalar(Value.float(Float(0.0001)))
* _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]
)
),
DType.int32,
sum(
astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32)
).to_int()
),
)
Expand All @@ -109,9 +59,4 @@
)
]
)
)[
IndexKey.multi_axis(
_MultiAxisIndexKey_1
+ MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2)))))
)
]
)[IndexKey.multi_axis(_MultiAxisIndexKey_1 + MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice(OptionalInt.none, OptionalInt.some(Int(2))))))]
9 changes: 4 additions & 5 deletions python/tests/__snapshots__/test_array_api/test_to_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ def my_fn(X, y):
assert X.dtype == np.float64
assert y.dtype == np.int64
assert X.dtype == np.float64
_0 = np.array(150.0)
assert y.shape == (150,)
assert y.shape == (150,)
_0 = np.array(150.0)
assert X.shape == (150,) + (4,)
assert X.shape == (150,) + (4,)
assert y.shape == (150,)
Expand All @@ -20,9 +20,8 @@ def my_fn(X, y):
_4 = _2[1].astype(np.float64)
_5 = _4 / _0
_6 = np.zeros((3,) + (4,), dtype=np.float64)
_7 = _5 + X
_6 = np.zeros((3,) + (4,), dtype=np.float64)
_7 = _7 + _6
_7 = _7 + _6
_7 = _5 + X
return _7
_8 = _7 + _6
_8 = _7 + _6
return _8
3 changes: 0 additions & 3 deletions python/tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ def test_run_rules(self):

run_report = egraph.run_report()
assert isinstance(run_report, RunReport)
total_time = run_report.search_time + run_report.apply_time + run_report.rebuild_time
# Verify less than the total time (which includes time spent in Python).
assert total_time < (end_time - start_time)

def test_extract(self):
# Example from extraction-cost
Expand Down
31 changes: 25 additions & 6 deletions src/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ convert_enums!(
egglog::ast::Command::Relation {constructor, inputs} => Relation {
constructor: constructor.to_string(),
inputs: inputs.iter().map(|i| i.to_string()).collect()
}
};
PrintOverallStatistics()
_c -> egglog::ast::Command::PrintOverallStatistics,
egglog::ast::Command::PrintOverallStatistics => PrintOverallStatistics {}
};
egglog::ExtractReport: "{:?}" => ExtractReport {
Best(termdag: TermDag, cost: usize, term: Term)
Expand Down Expand Up @@ -356,12 +359,28 @@ convert_struct!(
i -> IdentSort {ident: i.ident.to_string(), sort: i.sort.to_string()};
egglog::RunReport: "{:?}" => RunReport(
updated: bool,
search_time: WrappedDuration,
apply_time: WrappedDuration,
rebuild_time: WrappedDuration
search_time_per_rule: HashMap<String, WrappedDuration>,
apply_time_per_rule: HashMap<String, WrappedDuration>,
search_time_per_ruleset: HashMap<String, WrappedDuration>,
apply_time_per_ruleset: HashMap<String, WrappedDuration>,
rebuild_time_per_ruleset: HashMap<String, WrappedDuration>
)
r -> egglog::RunReport {updated: r.updated, search_time: r.search_time.0, apply_time: r.apply_time.0, rebuild_time: r.rebuild_time.0},
r -> RunReport {updated: r.updated, search_time: r.search_time.into(), apply_time: r.apply_time.into(), rebuild_time: r.rebuild_time.into()}
r -> egglog::RunReport {
updated: r.updated,
search_time_per_rule: r.search_time_per_rule.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(),
apply_time_per_rule: r.apply_time_per_rule.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(),
search_time_per_ruleset: r.search_time_per_ruleset.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(),
apply_time_per_ruleset: r.apply_time_per_ruleset.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect(),
rebuild_time_per_ruleset: r.rebuild_time_per_ruleset.iter().map(|(k, v)| (k.clone().into(), v.clone().0)).collect()
},
r -> RunReport {
updated: r.updated,
search_time_per_rule: r.search_time_per_rule.iter().map(|(k, v)| (k.clone().to_string(), v.clone().into())).collect(),
apply_time_per_rule: r.apply_time_per_rule.iter().map(|(k, v)| (k.clone().to_string(), v.clone().into())).collect(),
search_time_per_ruleset: r.search_time_per_ruleset.iter().map(|(k, v)| (k.clone().to_string(), v.clone().into())).collect(),
apply_time_per_ruleset: r.apply_time_per_ruleset.iter().map(|(k, v)| (k.clone().to_string(), v.clone().into())).collect(),
rebuild_time_per_ruleset: r.rebuild_time_per_ruleset.iter().map(|(k, v)| (k.clone().to_string(), v.clone().into())).collect()
}
);

impl FromPyObject<'_> for Box<Schedule> {
Expand Down
1 change: 1 addition & 0 deletions stubtest_allow
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Add allows for all no arg classes because they somehow take all args at runtime
.*egglog.bindings.Unit.__init__.*
.*egglog.bindings.CheckProof.__init__.*
.*egglog.bindings.PrintOverallStatistics.__init__.*

0 comments on commit 959b382

Please sign in to comment.