Skip to content

Commit

Permalink
Merge pull request #43 from metadsl/next-sklearn
Browse files Browse the repository at this point in the history
Cleanup sklearn notebook
  • Loading branch information
saulshanabrook authored Aug 26, 2023
2 parents 8de8612 + 9277506 commit bb25e57
Show file tree
Hide file tree
Showing 6 changed files with 21,323 additions and 352 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
- Removes custom fork of egglog now that visualizations are in core
- Adds int and float to string functions
- Switches `define` to `let`
- Tidy up notebook appearence [#43](https://github.com/metadsl/egglog-python/pull/43)
- Display expressions as code in Jupyter notebook
- Display all expressions when graphing

## 0.5.1 (2023-07-18)

Expand Down
21,632 changes: 21,288 additions & 344 deletions docs/tutorials/array-api.ipynb

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ class EGraph:
def run_program(self, *commands: _Command) -> list[str]: ...
def extract_report(self) -> Optional[_ExtractReport]: ...
def run_report(self) -> Optional[RunReport]: ...
def to_graphviz_string(self) -> str: ...
def to_graphviz_string(
self,
*,
max_functions: Optional[int] = None,
max_calls_per_function: Optional[int] = None,
) -> str: ...
def save_object(self, __o: object, /) -> _Expr: ...
def load_object(self, __e: _Expr, /) -> object: ...

Expand Down
8 changes: 4 additions & 4 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@


def extract_py(e: Expr) -> Any:
print(e)
# print(e)
egraph.register(e)
egraph.run((run() * 10).saturate())
final_object = egraph.extract(e)
print(f" -> {final_object}")
# print(f" -> {final_object}")
# with egraph:
egraph.run((run(runtime_ruleset) * 10 + run() * 10).saturate())
# Run saturation again b/c sometimes it doesn't work the first time.
Expand All @@ -55,7 +55,7 @@ def extract_py(e: Expr) -> Any:
# final_object = egraph.extract(egraph.extract(final_object))
# egraph.run(run(limit=10).saturate())

print(f" -> {egraph.extract(final_object)}\n")
# print(f" -> {egraph.extract(final_object)}\n")
res = egraph.load_object(egraph.extract(final_object.to_py())) # type: ignore[attr-defined]
return res

Expand Down Expand Up @@ -172,7 +172,7 @@ def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind):
]


assert not bool(isdtype(DType.float32, IsDtypeKind.string("integral")))
# assert not bool(isdtype(DType.float32, IsDtypeKind.string("integral")))


@egraph.class_
Expand Down
5 changes: 5 additions & 0 deletions python/egglog/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,11 @@ def __str__(self) -> str:
except black.parsing.InvalidInput:
return pretty_expr

def _ipython_display_(self) -> None:
from IPython.display import Code, display

display(Code(str(self), language="python"))

def __dir__(self) -> Iterable[str]:
return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods)

Expand Down
20 changes: 17 additions & 3 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::error::EggResult;
use crate::py_object_sort::PyObjectSort;

use egglog::sort::Sort;
use egglog::SerializeConfig;
use log::info;
use pyo3::{prelude::*, PyTraverseError, PyVisit};
use std::path::PathBuf;
Expand Down Expand Up @@ -81,11 +82,24 @@ impl EGraph {
}

/// Returns the EGraph as graphviz string.
#[pyo3(signature = ())]
fn to_graphviz_string(&self) -> String {
#[pyo3(
signature = (*, max_functions=None, max_calls_per_function=None),
text_signature = "(self, *, max_functions=None, max_calls_per_function=None)"
)]
fn to_graphviz_string(
&self,
max_functions: Option<usize>,
max_calls_per_function: Option<usize>,
) -> String {
info!("Getting graphviz");
// TODO: Expose full serialized e-graph in the future
self.egraph.serialize_for_graphviz().to_dot()
let mut serialized = self.egraph.serialize(SerializeConfig {
max_functions,
max_calls_per_function,
include_temporary_functions: false,
});
serialized.inline_leaves();
serialized.to_dot()
}

/// Register a Python object with the EGraph and return the Expr which represents it.
Expand Down

0 comments on commit bb25e57

Please sign in to comment.