Skip to content

Commit

Permalink
Add support for mutations
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Jul 29, 2023
1 parent 2cc04a6 commit 6d4d43f
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 64 deletions.
41 changes: 40 additions & 1 deletion docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,43 @@ def baz(a: i64Like, b: i64Like=i64(0)) -> i64:
baz(1)
```

### Mutating arguments

In order to support Python functions and methods which mutate their arguments, you can pass in the `mutate_first_arg` keyword argument to the `@egraph.function` decorator and the `mutates_self` argument to the `@egraph.method` decorator. This will cause the first argument to be mutated in place, instead of being copied.

```{code-cell} python
from copy import copy
mutate_egraph = EGraph()
@mutate_egraph.class_
class Int(Expr):
def __init__(self, i: i64Like) -> None:
...
def __add__(self, other: Int) -> Int: # type: ignore[empty-body]
...
@mutate_egraph.function(mutates_first_arg=True)
def incr(x: Int) -> None:
...
i = var("i", Int)
incr_i = copy(i)
incr(incr_i)
x = Int(10)
incr(x)
mutate_egraph.register(rewrite(incr_i).to(i + Int(1)), x)
mutate_egraph.run(10)
mutate_egraph.check(eq(x).to(Int(10) + Int(1)))
mutate_egraph
```

Any function which mutates its first argument must return `None`. In egglog, this is translated into a function which
returns the type of its first argument.

Note that dunder methods such as `__setitem__` will automatically be marked as mutating their first argument.

### Datatype functions

In egglog, the `(datatype ...)` command can also be used to declare functions. All of the functions declared in this block return the type of the declared datatype. Similarily, in Python, we can use the `@egraph.class_` decorator on a class to define a number of functions associated with that class. These
Expand Down Expand Up @@ -534,7 +571,9 @@ egraph.register(
# (extract y :variants 2)
y = egraph.define("y", Math(6) + Math(2) * Math.var("x"))
egraph.run(10)
egraph.extract_multiple(y, 2)
# TODO: For some reason this is extracting temp vars
# egraph.extract_multiple(y, 2)
egraph
```

### Simplify
Expand Down
136 changes: 109 additions & 27 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"ExprDecl",
"TypedExprDecl",
"ClassDecl",
"PrettyContext",
]
# Special methods which we might want to use as functions
# Mapping to the operator they represent for pretty printing them
Expand Down Expand Up @@ -288,7 +289,7 @@ def register_constant_callable(
self._decl.set_constant_type(ref, type_ref)
# Create a function decleartion for a constant function. This is similar to how egglog compiles
# the `declare` command.
return FunctionDecl((), (), (), type_ref.to_var()).to_commands(self, egg_name or ref.generate_egg_name())
return FunctionDecl((), (), (), type_ref.to_var(), False).to_commands(self, egg_name or ref.generate_egg_name())

def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
self._decl._classes[class_].preserved_methods[method] = fn
Expand Down Expand Up @@ -337,7 +338,14 @@ def to_constant_function_decl(self) -> FunctionDecl:
Create a function declaration for a constant function. This is similar to how egglog compiles
the `constant` command.
"""
return FunctionDecl(arg_types=(), arg_names=(), arg_defaults=(), return_type=self.to_var(), var_arg_type=None)
return FunctionDecl(
arg_types=(),
arg_names=(),
arg_defaults=(),
return_type=self.to_var(),
mutates_first_arg=False,
var_arg_type=None,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -432,8 +440,14 @@ class FunctionDecl:
arg_names: Optional[tuple[str, ...]]
arg_defaults: tuple[Optional[ExprDecl], ...]
return_type: TypeOrVarRef
mutates_first_arg: bool
var_arg_type: Optional[TypeOrVarRef] = None

def __post_init__(self):
# If we mutate the first arg, then the first arg should be the same type as the return
if self.mutates_first_arg:
assert self.arg_types[0] == self.return_type

def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature:
arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types)))
parameters = [
Expand Down Expand Up @@ -491,7 +505,7 @@ def from_egg(cls, var: bindings.Var) -> TypedExprDecl:
def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var:
return bindings.Var(self.name)

def pretty(self, mod_decls: ModuleDeclarations, **kwargs) -> str:
def pretty(self, context: PrettyContext, **kwargs) -> str:
return self.name


Expand Down Expand Up @@ -525,7 +539,7 @@ def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit:
return bindings.Lit(bindings.String(self.value))
assert_never(self.value)

def pretty(self, mod_decls: ModuleDeclarations, wrap_lit=True, **kwargs) -> str:
def pretty(self, context: PrettyContext, wrap_lit=True, **kwargs) -> str:
"""
Returns a string representation of the literal.
Expand Down Expand Up @@ -581,7 +595,7 @@ def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call:
egg_fn = mod_decls.get_egg_fn(self.callable)
return bindings.Call(egg_fn, [a.to_egg(mod_decls) for a in self.args])

def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
def pretty(self, context: PrettyContext, parens=True, **kwargs) -> str:
"""
Pretty print the call.
Expand All @@ -590,8 +604,13 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
ref, args = self.callable, [a.expr for a in self.args]
# Special case != since it doesn't have a decl
if isinstance(ref, MethodRef) and ref.method_name == "__ne__":
return f"{args[0].pretty(mod_decls, wrap_lit=True)} != {args[1].pretty(mod_decls, wrap_lit=True)}"
defaults = mod_decls.get_function_decl(ref).arg_defaults
return f"{args[0].pretty(context, wrap_lit=True)} != {args[1].pretty(context, wrap_lit=True)}"
function_decl = context.mod_decls.get_function_decl(ref)
defaults = function_decl.arg_defaults
if function_decl.mutates_first_arg:
mutated_arg_type = function_decl.arg_types[0].to_just().name
else:
mutated_arg_type = None
if isinstance(ref, FunctionRef):
fn_str = ref.name
elif isinstance(ref, ClassMethodRef):
Expand All @@ -605,23 +624,37 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
slf, *args = args
defaults = defaults[1:]
if name in UNARY_METHODS:
return f"{UNARY_METHODS[name]}{slf.pretty(mod_decls)}"
return f"{UNARY_METHODS[name]}{slf.pretty(context)}"
elif name in BINARY_METHODS:
assert len(args) == 1
expr = f"{slf.pretty(mod_decls )} {BINARY_METHODS[name]} {args[0].pretty(mod_decls, wrap_lit=False)}"
expr = f"{slf.pretty(context )} {BINARY_METHODS[name]} {args[0].pretty(context, wrap_lit=False)}"
return expr if not parens else f"({expr})"
elif name == "__getitem__":
assert len(args) == 1
return f"{slf.pretty(mod_decls)}[{args[0].pretty(mod_decls, wrap_lit=False)}]"
return f"{slf.pretty(context)}[{args[0].pretty(context, wrap_lit=False)}]"
elif name == "__call__":
return f"{slf.pretty(mod_decls)}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})"
fn_str = f"{slf.pretty(mod_decls)}.{name}"
return f"{slf.pretty(context)}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})"
elif name == "__delitem__":
assert len(args) == 1
assert mutated_arg_type
name = context.name_expr(mutated_arg_type, slf)
context.statements.append(f"del {name}[{args[0].pretty(context, parens=False, wrap_lit=False)}]")
return name
elif name == "__setitem__":
assert len(args) == 2
assert mutated_arg_type
name = context.name_expr(mutated_arg_type, slf)
context.statements.append(
f"{name}[{args[0].pretty(context, parens=False, wrap_lit=False)}] = {args[1].pretty(context, parens=False, wrap_lit=False)}"
)
return name
fn_str = f"{slf.pretty(context)}.{name}"
elif isinstance(ref, ConstantRef):
return ref.name
elif isinstance(ref, ClassVariableRef):
return f"{ref.class_name}.{ref.variable_name}"
elif isinstance(ref, PropertyRef):
return f"{args[0].pretty(mod_decls)}.{ref.property_name}"
return f"{args[0].pretty(context)}.{ref.property_name}"
else:
assert_never(ref)
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
Expand All @@ -632,36 +665,85 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
n_defaults += 1
if n_defaults:
args = args[:-n_defaults]
return f"{fn_str}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})"
if mutated_arg_type:
name = context.name_expr(mutated_arg_type, args[0])
context.statements.append(
f"{fn_str}({', '.join({name}, *(a.pretty(context, wrap_lit=False) for a in args[1:]))})"
)
return name
return f"{fn_str}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})"


@dataclass
class PrettyContext:
mod_decls: ModuleDeclarations
# List of statements of "context" setting variable for the expr
statements: list[str] = field(default_factory=list)

_gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))

def generate_name(self, typ: str) -> str:
self._gen_name_types[typ] += 1
return f"_{typ}_{self._gen_name_types[typ]}"

def name_expr(self, expr_type: str, expr: ExprDecl) -> str:
name = self.generate_name(expr_type)
self.statements.append(f"{name} = copy({expr.pretty(self, parens=False)})")
return name

def render(self, expr: str) -> str:
return "\n".join(self.statements + [expr])


def test_expr_pretty():
mod_decls = ModuleDeclarations(Declarations())
assert VarDecl("x").pretty(mod_decls) == "x"
assert LitDecl(42).pretty(mod_decls) == "i64(42)"
assert LitDecl("foo").pretty(mod_decls) == 'String("foo")'
assert LitDecl(None).pretty(mod_decls) == "unit()"
context = PrettyContext(ModuleDeclarations(Declarations()))
assert VarDecl("x").pretty(context) == "x"
assert LitDecl(42).pretty(context) == "i64(42)"
assert LitDecl("foo").pretty(context) == 'String("foo")'
assert LitDecl(None).pretty(context) == "unit()"

def v(x: str) -> TypedExprDecl:
return TypedExprDecl(JustTypeRef(""), VarDecl(x))

assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(mod_decls) == "foo(x)"
assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(mod_decls) == "foo(x, y, z)"
assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(mod_decls) == "x + y"
assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(mod_decls) == "x[y]"
assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(mod_decls) == "foo(x, y)"
assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(mod_decls) == "foo.bar(x, y)"
assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(mod_decls) == "x(y)"
assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)"
assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)"
assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y"
assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]"
assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)"
assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)"
assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)"
assert (
CallDecl(
ClassMethodRef("Map", "__init__"),
(),
(JustTypeRef("i64"), JustTypeRef("Unit")),
).pretty(mod_decls)
).pretty(context)
== "Map[i64, Unit]()"
)


def test_setitem_pretty():
context = PrettyContext(ModuleDeclarations(Declarations()))

def v(x: str) -> TypedExprDecl:
return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))

final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context)
assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1"


def test_delitem_pretty():
context = PrettyContext(ModuleDeclarations(Declarations()))

def v(x: str) -> TypedExprDecl:
return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))

final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context)
assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1"


# TODO: Multiple mutations,

ExprDecl = Union[VarDecl, LitDecl, CallDecl]


Expand Down
Loading

0 comments on commit 6d4d43f

Please sign in to comment.