diff --git a/docs/changelog.md b/docs/changelog.md index 07a42b84..c75f2199 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -20,6 +20,8 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea - Added to/from i64 to i64 methods. - Upgraded `egg-smol` dependency ([changes](https://github.com/saulshanabrook/egg-smol/compare/353c4387640019bd2066991ee0488dc6d5c54168...2ac80cb1162c61baef295d8e6d00351bfe84883f)) +- Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35) + ## 0.5.1 (2023-07-18) - Added support for negation on `f64` sort diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index edaa6101..a76a4d8e 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -160,6 +160,43 @@ def baz(a: i64Like, b: i64Like=i64(0)) -> i64: baz(1) ``` +### Mutating arguments + +In order to support Python functions and methods which mutate their arguments, you can pass in the `mutate_first_arg` keyword argument to the `@egraph.function` decorator and the `mutates_self` argument to the `@egraph.method` decorator. This will cause the first argument to be mutated in place, instead of being copied. + +```{code-cell} python +from copy import copy +mutate_egraph = EGraph() + +@mutate_egraph.class_ +class Int(Expr): + def __init__(self, i: i64Like) -> None: + ... + + def __add__(self, other: Int) -> Int: # type: ignore[empty-body] + ... + +@mutate_egraph.function(mutates_first_arg=True) +def incr(x: Int) -> None: + ... + +i = var("i", Int) +incr_i = copy(i) +incr(incr_i) + +x = Int(10) +incr(x) +mutate_egraph.register(rewrite(incr_i).to(i + Int(1)), x) +mutate_egraph.run(10) +mutate_egraph.check(eq(x).to(Int(10) + Int(1))) +mutate_egraph +``` + +Any function which mutates its first argument must return `None`. In egglog, this is translated into a function which +returns the type of its first argument. + +Note that dunder methods such as `__setitem__` will automatically be marked as mutating their first argument. + ### Datatype functions In egglog, the `(datatype ...)` command can also be used to declare functions. All of the functions declared in this block return the type of the declared datatype. Similarily, in Python, we can use the `@egraph.class_` decorator on a class to define a number of functions associated with that class. These @@ -534,7 +571,9 @@ egraph.register( # (extract y :variants 2) y = egraph.define("y", Math(6) + Math(2) * Math.var("x")) egraph.run(10) -egraph.extract_multiple(y, 2) +# TODO: For some reason this is extracting temp vars +# egraph.extract_multiple(y, 2) +egraph ``` ### Simplify diff --git a/docs/tutorials/array-api.ipynb b/docs/tutorials/array-api.ipynb index 39eacb4e..721520c2 100644 --- a/docs/tutorials/array-api.ipynb +++ b/docs/tutorials/array-api.ipynb @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -235,29 +235,49 @@ " )\n", " > NDArray.scalar_float(Float(1e-05))\n", ").bool()\n", - " -> (abs(((astype(NDArray.scalar_int(Int(150)), DType.float64) / NDArray.scalar_float(Float(150.0))) - NDArray.scalar_float(Float(1.0)))) > NDArray.scalar_float(Float(1e-05))).bool()\n" + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).shape[Int(1)] < (\n", + " unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n", + " Int(0)\n", + " ]\n", + " - Int(1)\n", + ")\n", + " -> NDArray.var(\"X\").shape[Int(1)] < (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1))\n", + " -> FALSE\n", + "(\n", + " unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).shape[\n", + " Int(0)\n", + " ]\n", + " - Int(1)\n", + ") < Int(2)\n", + " -> (unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)] - Int(1)) < Int(2)\n", + " -> FALSE\n", + "asarray(NDArray.var(\"X\")).shape.length()\n", + " -> NDArray.var(\"X\").ndim\n", + " -> Int(2)\n", + "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))).length()\n", + " -> Int(2)\n", + " -> Int(2)\n", + "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n", + "unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n", + " -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n", + " -> Int(3)\n" ] }, { - "ename": "EggSmolError", - "evalue": "Not found: fake expression Bool.to_py [Value { tag: \"Bool\", bits: 119 }]", + "ename": "TypeError", + "evalue": "'RuntimeExpr' object does not support item assignment", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mEggSmolError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[30], line 603\u001b[0m\n\u001b[1;32m 589\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 590\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 591\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 592\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 599\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 600\u001b[0m )\n\u001b[0;32m--> 603\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 605\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 606\u001b[0m \n\u001b[1;32m 607\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 608\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 711\u001b[0m\n\u001b[1;32m 697\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 698\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 699\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 700\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 707\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 708\u001b[0m )\n\u001b[0;32m--> 711\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 713\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 714\u001b[0m \n\u001b[1;32m 715\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 716\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n", "Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[39mwith\u001b[39;00m config_context(array_api_dispatch\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[39m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[39m=\u001b[39m lda\u001b[39m.\u001b[39;49mfit(X, y)\u001b[39m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[39mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[39m=\u001b[39m iris\u001b[39m.\u001b[39mtarget_names\n", "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[39m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[39mreturn\u001b[39;00m fit_method(estimator, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", - "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:602\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[39mif\u001b[39;00m xp\u001b[39m.\u001b[39many(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_ \u001b[39m<\u001b[39m \u001b[39m0\u001b[39m):\n\u001b[1;32m 600\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mpriors must be non-negative\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 602\u001b[0m \u001b[39mif\u001b[39;00m xp\u001b[39m.\u001b[39mabs(xp\u001b[39m.\u001b[39msum(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_) \u001b[39m-\u001b[39m \u001b[39m1.0\u001b[39m) \u001b[39m>\u001b[39m \u001b[39m1e-5\u001b[39m:\n\u001b[1;32m 603\u001b[0m warnings\u001b[39m.\u001b[39mwarn(\u001b[39m\"\u001b[39m\u001b[39mThe priors do not sum to 1. Renormalizing\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mUserWarning\u001b[39;00m)\n\u001b[1;32m 604\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_ \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_\u001b[39m.\u001b[39msum()\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:403\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[1;32m 402\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_typed_expr__\u001b[39m.\u001b[39mtp\u001b[39m.\u001b[39mname\u001b[39m}\u001b[39;00m\u001b[39m has no method \u001b[39m\u001b[39m{\u001b[39;00m__name\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 403\u001b[0m \u001b[39mreturn\u001b[39;00m method(\u001b[39mself\u001b[39;49m)\n", - "Cell \u001b[0;32mIn[30], line 319\u001b[0m, in \u001b[0;36mNDArray.__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[39m@egraph\u001b[39m\u001b[39m.\u001b[39mmethod(preserve\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 318\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__bool__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[0;32m--> 319\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mbool\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbool())\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:403\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[1;32m 402\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_typed_expr__\u001b[39m.\u001b[39mtp\u001b[39m.\u001b[39mname\u001b[39m}\u001b[39;00m\u001b[39m has no method \u001b[39m\u001b[39m{\u001b[39;00m__name\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 403\u001b[0m \u001b[39mreturn\u001b[39;00m method(\u001b[39mself\u001b[39;49m)\n", - "Cell \u001b[0;32mIn[30], line 40\u001b[0m, in \u001b[0;36mBool.__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[39m@egraph\u001b[39m\u001b[39m.\u001b[39mmethod(preserve\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 39\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__bool__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[0;32m---> 40\u001b[0m \u001b[39mreturn\u001b[39;00m extract_py(\u001b[39mself\u001b[39;49m)\n", - "Cell \u001b[0;32mIn[30], line 32\u001b[0m, in \u001b[0;36mextract_py\u001b[0;34m(e)\u001b[0m\n\u001b[1;32m 30\u001b[0m egraph\u001b[39m.\u001b[39mrun((run(runtime_ruleset, limit\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m) \u001b[39m+\u001b[39m run(limit\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m))\u001b[39m.\u001b[39msaturate())\n\u001b[1;32m 31\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m -> \u001b[39m\u001b[39m{\u001b[39;00megraph\u001b[39m.\u001b[39mextract(final_object)\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 32\u001b[0m res \u001b[39m=\u001b[39m egraph\u001b[39m.\u001b[39mload_object(egraph\u001b[39m.\u001b[39;49mextract(final_object\u001b[39m.\u001b[39;49mto_py()))\n\u001b[1;32m 33\u001b[0m \u001b[39mreturn\u001b[39;00m res\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:737\u001b[0m, in \u001b[0;36mEGraph.extract\u001b[0;34m(self, expr)\u001b[0m\n\u001b[1;32m 735\u001b[0m typed_expr \u001b[39m=\u001b[39m expr_parts(expr)\n\u001b[1;32m 736\u001b[0m egg_expr \u001b[39m=\u001b[39m typed_expr\u001b[39m.\u001b[39mto_egg(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mod_decls)\n\u001b[0;32m--> 737\u001b[0m extract_report \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_extract(egg_expr, \u001b[39m0\u001b[39;49m)\n\u001b[1;32m 738\u001b[0m new_typed_expr \u001b[39m=\u001b[39m TypedExprDecl\u001b[39m.\u001b[39mfrom_egg(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mod_decls, extract_report\u001b[39m.\u001b[39mexpr)\n\u001b[1;32m 739\u001b[0m \u001b[39mif\u001b[39;00m new_typed_expr\u001b[39m.\u001b[39mtp \u001b[39m!=\u001b[39m typed_expr\u001b[39m.\u001b[39mtp:\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:754\u001b[0m, in \u001b[0;36mEGraph._run_extract\u001b[0;34m(self, expr, n)\u001b[0m\n\u001b[1;32m 753\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_run_extract\u001b[39m(\u001b[39mself\u001b[39m, expr: bindings\u001b[39m.\u001b[39m_Expr, n: \u001b[39mint\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m bindings\u001b[39m.\u001b[39mExtractReport:\n\u001b[0;32m--> 754\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_process_commands([bindings\u001b[39m.\u001b[39;49mExtract(n, expr)])\n\u001b[1;32m 755\u001b[0m extract_report \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_egraph\u001b[39m.\u001b[39mextract_report()\n\u001b[1;32m 756\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m extract_report:\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:634\u001b[0m, in \u001b[0;36mEGraph._process_commands\u001b[0;34m(self, commands)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_process_commands\u001b[39m(\u001b[39mself\u001b[39m, commands: Iterable[bindings\u001b[39m.\u001b[39m_Command]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 634\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_egraph\u001b[39m.\u001b[39;49mrun_program(\u001b[39m*\u001b[39;49mcommands)\n", - "\u001b[0;31mEggSmolError\u001b[0m: Not found: fake expression Bool.to_py [Value { tag: \"Bool\", bits: 119 }]" + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcovariance estimator \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mis not supported \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mwith svd solver. Try another solver\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_solve_svd(X, y)\n\u001b[1;32m 630\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msolver \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlsqr\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n", + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:501\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 498\u001b[0m n_samples, n_features \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 499\u001b[0m n_classes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclasses_\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m--> 501\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmeans_ \u001b[39m=\u001b[39m _class_means(X, y)\n\u001b[1;32m 502\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstore_covariance:\n\u001b[1;32m 503\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_ \u001b[39m=\u001b[39m _class_cov(X, y, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_)\n", + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:121\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[39mprint\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m])\n\u001b[1;32m 120\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]):\n\u001b[0;32m--> 121\u001b[0m means[i, :] \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mmean(X[y \u001b[39m==\u001b[39m i], axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[39m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# from a performance-wise\u001b[39;00m\n\u001b[1;32m 125\u001b[0m cnt \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mbincount(y)\n", + "\u001b[0;31mTypeError\u001b[0m: 'RuntimeExpr' object does not support item assignment" ] } ], @@ -269,6 +289,7 @@ "from egglog.egraph import Unit\n", "import numpy as np\n", "import numbers\n", + "from types import SimpleNamespace\n", "\n", "from egglog import *\n", "\n", @@ -412,9 +433,18 @@ " def __init__(self, value: f64Like) -> None:\n", " ...\n", "\n", + " def abs(self) -> Float:\n", + " ...\n", "\n", "converter(float, Float, lambda x: Float(x))\n", "\n", + "@egraph.register\n", + "def _float(f: f64, f2: f64, r: Bool, o: Float):\n", + " return [\n", + " rewrite(Float(f).abs()).to(Float(f), f >= 0.0),\n", + " rewrite(Float(f).abs()).to(Float(-f), f < 0.0),\n", + " ]\n", + "\n", "\n", "@egraph.class_\n", "class Int(Expr):\n", @@ -441,12 +471,17 @@ "\n", " def __add__(self, other: Int) -> Int:\n", " ...\n", + " def __sub__(self, other: Int) -> Int: ...\n", "\n", " @egraph.method(preserve=True)\n", " def __int__(self) -> int:\n", " return extract_py(self)\n", "\n", " @egraph.method(preserve=True)\n", + " def __index__(self) -> int:\n", + " return extract_py(self)\n", + "\n", + " @egraph.method(preserve=True)\n", " def __float__(self) -> float:\n", " return float(int(self))\n", "\n", @@ -478,6 +513,7 @@ " yield rule(eq(o).to(Int(j))).then(set_(o.to_py()).to(PyObject.from_int(j)))\n", "\n", " yield rewrite(Int(i) + Int(j)).to(Int(i + j))\n", + " yield rewrite(Int(i) - Int(j)).to(Int(i - j))\n", "\n", "\n", "converter(int, Int, lambda x: Int(x))\n", @@ -544,12 +580,17 @@ " ...\n", "\n", "\n", + "\n", "converter(tuple, IndexKey, lambda x: IndexKey.tuple_int(convert(x, TupleInt)))\n", "converter(int, IndexKey, lambda x: IndexKey.int(Int(x)))\n", "converter(Int, IndexKey, lambda x: IndexKey.int(x))\n", "\n", "\n", "@egraph.class_\n", + "class Device(Expr): ...\n", + "\n", + "\n", + "@egraph.class_\n", "class NDArray(Expr):\n", " def __init__(self, py_array: PyObject) -> None:\n", " ...\n", @@ -572,6 +613,11 @@ " ...\n", "\n", " @property\n", + " def device(self) -> Device:\n", + " ...\n", + "\n", + "\n", + " @property\n", " def shape(self) -> TupleInt:\n", " ...\n", "\n", @@ -605,6 +651,9 @@ "\n", " def __gt__(self, other: NDArray) -> NDArray:\n", " ...\n", + " \n", + " def __eq__(self, other: NDArray) -> NDArray:\n", + " ...\n", "\n", " @classmethod\n", " def scalar_float(cls, other: Float) -> NDArray:\n", @@ -618,16 +667,28 @@ " def scalar_bool(cls, other: Bool) -> NDArray:\n", " ...\n", "\n", + "@egraph.function\n", + "def ndarray_index(x: NDArray) -> IndexKey:\n", + " ...\n", + "\n", + "converter(NDArray, IndexKey, ndarray_index)\n", + "\n", + "\n", "\n", "converter(float, NDArray, lambda x: NDArray.scalar_float(Float(x)))\n", "converter(int, NDArray, lambda x: NDArray.scalar_int(Int(x)))\n", "\n", "\n", "@egraph.register\n", - "def _ndarray(x: NDArray, b: Bool):\n", + "def _ndarray(x: NDArray, b: Bool, f: Float, fi1: f64, fi2: f64):\n", " return [\n", " rewrite(x.ndim).to(x.shape.length()),\n", " rewrite(NDArray.scalar_bool(b).bool()).to(b),\n", + " # TODO: Push these down to float\n", + " rewrite(NDArray.scalar_float(f) / NDArray.scalar_float(f)).to(NDArray.scalar_float(Float(1.0))),\n", + " rewrite(NDArray.scalar_float(f) - NDArray.scalar_float(f)).to(NDArray.scalar_float(Float(0.0))),\n", + " rewrite(NDArray.scalar_float(Float(fi1)) > NDArray.scalar_float(Float(fi2))).to(NDArray.scalar_bool(TRUE), fi1 > fi2),\n", + " rewrite(NDArray.scalar_float(Float(fi1)) > NDArray.scalar_float(Float(fi2))).to(NDArray.scalar_bool(FALSE), fi1 <= fi2),\n", " ]\n", "\n", "\n", @@ -703,6 +764,35 @@ "converter(type(None), OptionalDType, lambda x: OptionalDType.none)\n", "converter(DType, OptionalDType, lambda x: OptionalDType.some(x))\n", "\n", + "@egraph.class_\n", + "class OptionalDevice(Expr):\n", + " none: ClassVar[OptionalDevice]\n", + "\n", + " @classmethod\n", + " def some(cls, value: Device) -> OptionalDevice:\n", + " ...\n", + "\n", + "\n", + "converter(type(None), OptionalDevice, lambda x: OptionalDevice.none)\n", + "converter(Device, OptionalDevice, lambda x: OptionalDevice.some(x))\n", + "\n", + "\n", + "@egraph.class_\n", + "class OptionalTupleInt(Expr):\n", + " none: ClassVar[OptionalTupleInt]\n", + "\n", + " @classmethod\n", + " def some(cls, value: TupleInt) -> OptionalTupleInt:\n", + " ...\n", + "\n", + "\n", + "converter(type(None), OptionalTupleInt, lambda x: OptionalTupleInt.none)\n", + "converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))\n", + "converter(int, OptionalTupleInt, lambda x: OptionalTupleInt.some(TupleInt(Int(x))))\n", + "\n", + "\n", + "\n", + "\n", "\n", "@egraph.function\n", "def asarray(a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none) -> NDArray:\n", @@ -801,8 +891,7 @@ " return [\n", " rewrite(astype(x, dtype).dtype).to(dtype),\n", " rewrite(sum(astype(x, dtype))).to(astype(sum(x), dtype)),\n", - " rewrite(astype(NDArray.scalar_int(Int(i)), float64).to(NDArray.scalar_float(Float(\n", - "\n", + " rewrite(astype(NDArray.scalar_int(Int(i)), float64)).to(NDArray.scalar_float(Float(f64.from_i64(i))))\n", " ]\n", "\n", "\n", @@ -814,6 +903,44 @@ "def abs(x: NDArray) -> NDArray:\n", " ...\n", "\n", + "@egraph.register\n", + "def _abs(f: Float):\n", + " return [\n", + " rewrite(abs(NDArray.scalar_float(f))).to(NDArray.scalar_float(f)),\n", + " ]\n", + "\n", + "@egraph.function\n", + "def unique_inverse(x: NDArray) -> TupleNDArray:\n", + " ...\n", + "\n", + "@egraph.register\n", + "def _unique_inverse(x: NDArray):\n", + " return [\n", + " rewrite(unique_inverse(x).length()).to(Int(2)),\n", + " # Shape of unique_inverse first element is same as shape of unique_values\n", + " rewrite(unique_inverse(x)[Int(0)].shape).to(unique_values(x).shape),\n", + " ]\n", + "\n", + "@egraph.function\n", + "def zeros(shape: TupleInt, dtype: OptionalDType = OptionalDType.none, device: OptionalDevice = OptionalDevice.none) -> NDArray:\n", + " ...\n", + "@egraph.function\n", + "def mean(x: NDArray, axis: OptionalTupleInt = OptionalTupleInt.none) -> NDArray: ...\n", + "\n", + "\n", + "linalg = sys.modules[__name__]\n", + "\n", + "@egraph.function\n", + "def svd(x: NDArray) -> TupleNDArray:\n", + " ...\n", + "\n", + "\n", + "@egraph.register\n", + "def _linalg(x: NDArray):\n", + " return [\n", + " rewrite(svd(x).length()).to(Int(3)),\n", + " ]\n", + "\n", "##\n", "# Interval analysis\n", "#\n", @@ -873,187 +1000,12 @@ "# y_arr = NDArray(y_obj)" ] }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "RunReport(True, datetime.timedelta(microseconds=4), datetime.timedelta(0), datetime.timedelta(microseconds=2))" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sum((astype(unique_counts(NDArray.var(\"y\"))[Int(1)], DType.float64) / NDArray.scalar_float(Float(150.0))))" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "x = astype(unique_counts(NDArray.var(\"y\"))[Int(1)], DType.float64) / NDArray.scalar_float(Float(150.0))\n", - "egraph.check(\n", - "ndarray_all_greater_0(x)\n", - ")\n", - " " - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "class MyObj:\n", - "\n", - " def __len__(self):\n", - " return 1\n", - " \n", - " def __getitem__(self, i):\n", - " print(\"GETITEM\", i)\n", - " return 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GETITEM 0\n", - "GETITEM 1\n" - ] - }, - { - "data": { - "text/plain": [ - "1" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x = iter(MyObj())\n", - "next(x)\n", - "next(x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TupleInt(Int(1))[Int(1)] < TupleInt(Int(1))[Int(0)]\n", - " -> TupleInt(Int(1))[Int(1)] < Int(1)\n" - ] - }, - { - "ename": "EggSmolError", - "evalue": "Not found: fake expression Bool.to_py [Value { tag: \"Bool\", bits: 100 }]", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mEggSmolError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mmin\u001b[39;49m(TupleInt(Int(\u001b[39m1\u001b[39;49m)))\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:403\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[1;32m 402\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_typed_expr__\u001b[39m.\u001b[39mtp\u001b[39m.\u001b[39mname\u001b[39m}\u001b[39;00m\u001b[39m has no method \u001b[39m\u001b[39m{\u001b[39;00m__name\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 403\u001b[0m \u001b[39mreturn\u001b[39;00m method(\u001b[39mself\u001b[39;49m)\n", - "Cell \u001b[0;32mIn[3], line 37\u001b[0m, in \u001b[0;36mBool.__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[39m@egraph\u001b[39m\u001b[39m.\u001b[39mmethod(preserve\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 36\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__bool__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mbool\u001b[39m:\n\u001b[0;32m---> 37\u001b[0m \u001b[39mreturn\u001b[39;00m extract_py(\u001b[39mself\u001b[39;49m)\n", - "Cell \u001b[0;32mIn[3], line 29\u001b[0m, in \u001b[0;36mextract_py\u001b[0;34m(e)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[39mwith\u001b[39;00m egraph:\n\u001b[1;32m 28\u001b[0m egraph\u001b[39m.\u001b[39mrun((run(runtime_ruleset, limit\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m) \u001b[39m+\u001b[39m run(limit\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m))\u001b[39m.\u001b[39msaturate())\n\u001b[0;32m---> 29\u001b[0m res \u001b[39m=\u001b[39m egraph\u001b[39m.\u001b[39mload_object(egraph\u001b[39m.\u001b[39;49mextract(final_object\u001b[39m.\u001b[39;49mto_py()))\n\u001b[1;32m 30\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m -> \u001b[39m\u001b[39m{\u001b[39;00mres\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 31\u001b[0m \u001b[39mreturn\u001b[39;00m res\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:737\u001b[0m, in \u001b[0;36mEGraph.extract\u001b[0;34m(self, expr)\u001b[0m\n\u001b[1;32m 735\u001b[0m typed_expr \u001b[39m=\u001b[39m expr_parts(expr)\n\u001b[1;32m 736\u001b[0m egg_expr \u001b[39m=\u001b[39m typed_expr\u001b[39m.\u001b[39mto_egg(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mod_decls)\n\u001b[0;32m--> 737\u001b[0m extract_report \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_extract(egg_expr, \u001b[39m0\u001b[39;49m)\n\u001b[1;32m 738\u001b[0m new_typed_expr \u001b[39m=\u001b[39m TypedExprDecl\u001b[39m.\u001b[39mfrom_egg(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mod_decls, extract_report\u001b[39m.\u001b[39mexpr)\n\u001b[1;32m 739\u001b[0m \u001b[39mif\u001b[39;00m new_typed_expr\u001b[39m.\u001b[39mtp \u001b[39m!=\u001b[39m typed_expr\u001b[39m.\u001b[39mtp:\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:754\u001b[0m, in \u001b[0;36mEGraph._run_extract\u001b[0;34m(self, expr, n)\u001b[0m\n\u001b[1;32m 753\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_run_extract\u001b[39m(\u001b[39mself\u001b[39m, expr: bindings\u001b[39m.\u001b[39m_Expr, n: \u001b[39mint\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m bindings\u001b[39m.\u001b[39mExtractReport:\n\u001b[0;32m--> 754\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_process_commands([bindings\u001b[39m.\u001b[39;49mExtract(n, expr)])\n\u001b[1;32m 755\u001b[0m extract_report \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_egraph\u001b[39m.\u001b[39mextract_report()\n\u001b[1;32m 756\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m extract_report:\n", - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:634\u001b[0m, in \u001b[0;36mEGraph._process_commands\u001b[0;34m(self, commands)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_process_commands\u001b[39m(\u001b[39mself\u001b[39m, commands: Iterable[bindings\u001b[39m.\u001b[39m_Command]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 634\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_egraph\u001b[39m.\u001b[39;49mrun_program(\u001b[39m*\u001b[39;49mcommands)\n", - "\u001b[0;31mEggSmolError\u001b[0m: Not found: fake expression Bool.to_py [Value { tag: \"Bool\", bits: 100 }]" - ] - } - ], - "source": [ - "min(TupleInt(Int(1)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'egraph' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m egraph\n", - "\u001b[0;31mNameError\u001b[0m: name 'egraph' is not defined" - ] - } - ], - "source": [ - "egraph" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/xn/05ktz3056kqd9n8frgd6236h0000gn/T/ipykernel_75113/2632455728.py:1: UserWarning: The numpy.array_api submodule is still experimental. See NEP 47.\n", - " import numpy.array_api as npa\n" - ] - } - ], - "source": [ - "import numpy.array_api as npa" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "UniqueCountsResult(values=Array([0, 1, 2], dtype=int64), counts=Array([50, 50, 50], dtype=int64))" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "npa.unique_counts(npa.asarray(y))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "any(((astype(unique_counts(NDArray.var(\"y\"))[Int(1)], DType.float64) / NDArray.scalar_float(Float(150.0))) < NDArray.scalar_int(Int(0)))).bool()" - ] + "source": [] } ], "metadata": { diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 63a78198..374c4d8e 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -37,6 +37,7 @@ "ExprDecl", "TypedExprDecl", "ClassDecl", + "PrettyContext", ] # Special methods which we might want to use as functions # Mapping to the operator they represent for pretty printing them @@ -288,7 +289,7 @@ def register_constant_callable( self._decl.set_constant_type(ref, type_ref) # Create a function decleartion for a constant function. This is similar to how egglog compiles # the `declare` command. - return FunctionDecl((), (), (), type_ref.to_var()).to_commands(self, egg_name or ref.generate_egg_name()) + return FunctionDecl((), (), (), type_ref.to_var(), False).to_commands(self, egg_name or ref.generate_egg_name()) def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None: self._decl._classes[class_].preserved_methods[method] = fn @@ -337,7 +338,14 @@ def to_constant_function_decl(self) -> FunctionDecl: Create a function declaration for a constant function. This is similar to how egglog compiles the `constant` command. """ - return FunctionDecl(arg_types=(), arg_names=(), arg_defaults=(), return_type=self.to_var(), var_arg_type=None) + return FunctionDecl( + arg_types=(), + arg_names=(), + arg_defaults=(), + return_type=self.to_var(), + mutates_first_arg=False, + var_arg_type=None, + ) @dataclass(frozen=True) @@ -432,8 +440,14 @@ class FunctionDecl: arg_names: Optional[tuple[str, ...]] arg_defaults: tuple[Optional[ExprDecl], ...] return_type: TypeOrVarRef + mutates_first_arg: bool var_arg_type: Optional[TypeOrVarRef] = None + def __post_init__(self): + # If we mutate the first arg, then the first arg should be the same type as the return + if self.mutates_first_arg: + assert self.arg_types[0] == self.return_type + def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature: arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types))) parameters = [ @@ -491,7 +505,7 @@ def from_egg(cls, var: bindings.Var) -> TypedExprDecl: def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var: return bindings.Var(self.name) - def pretty(self, mod_decls: ModuleDeclarations, **kwargs) -> str: + def pretty(self, context: PrettyContext, **kwargs) -> str: return self.name @@ -525,7 +539,7 @@ def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit: return bindings.Lit(bindings.String(self.value)) assert_never(self.value) - def pretty(self, mod_decls: ModuleDeclarations, wrap_lit=True, **kwargs) -> str: + def pretty(self, context: PrettyContext, wrap_lit=True, **kwargs) -> str: """ Returns a string representation of the literal. @@ -581,7 +595,7 @@ def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call: egg_fn = mod_decls.get_egg_fn(self.callable) return bindings.Call(egg_fn, [a.to_egg(mod_decls) for a in self.args]) - def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str: + def pretty(self, context: PrettyContext, parens=True, **kwargs) -> str: """ Pretty print the call. @@ -590,8 +604,13 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str: ref, args = self.callable, [a.expr for a in self.args] # Special case != since it doesn't have a decl if isinstance(ref, MethodRef) and ref.method_name == "__ne__": - return f"{args[0].pretty(mod_decls, wrap_lit=True)} != {args[1].pretty(mod_decls, wrap_lit=True)}" - defaults = mod_decls.get_function_decl(ref).arg_defaults + return f"{args[0].pretty(context, wrap_lit=True)} != {args[1].pretty(context, wrap_lit=True)}" + function_decl = context.mod_decls.get_function_decl(ref) + defaults = function_decl.arg_defaults + if function_decl.mutates_first_arg: + mutated_arg_type = function_decl.arg_types[0].to_just().name + else: + mutated_arg_type = None if isinstance(ref, FunctionRef): fn_str = ref.name elif isinstance(ref, ClassMethodRef): @@ -605,23 +624,37 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str: slf, *args = args defaults = defaults[1:] if name in UNARY_METHODS: - return f"{UNARY_METHODS[name]}{slf.pretty(mod_decls)}" + return f"{UNARY_METHODS[name]}{slf.pretty(context)}" elif name in BINARY_METHODS: assert len(args) == 1 - expr = f"{slf.pretty(mod_decls )} {BINARY_METHODS[name]} {args[0].pretty(mod_decls, wrap_lit=False)}" + expr = f"{slf.pretty(context )} {BINARY_METHODS[name]} {args[0].pretty(context, wrap_lit=False)}" return expr if not parens else f"({expr})" elif name == "__getitem__": assert len(args) == 1 - return f"{slf.pretty(mod_decls)}[{args[0].pretty(mod_decls, wrap_lit=False)}]" + return f"{slf.pretty(context)}[{args[0].pretty(context, wrap_lit=False)}]" elif name == "__call__": - return f"{slf.pretty(mod_decls)}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})" - fn_str = f"{slf.pretty(mod_decls)}.{name}" + return f"{slf.pretty(context)}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})" + elif name == "__delitem__": + assert len(args) == 1 + assert mutated_arg_type + name = context.name_expr(mutated_arg_type, slf) + context.statements.append(f"del {name}[{args[0].pretty(context, parens=False, wrap_lit=False)}]") + return name + elif name == "__setitem__": + assert len(args) == 2 + assert mutated_arg_type + name = context.name_expr(mutated_arg_type, slf) + context.statements.append( + f"{name}[{args[0].pretty(context, parens=False, wrap_lit=False)}] = {args[1].pretty(context, parens=False, wrap_lit=False)}" + ) + return name + fn_str = f"{slf.pretty(context)}.{name}" elif isinstance(ref, ConstantRef): return ref.name elif isinstance(ref, ClassVariableRef): return f"{ref.class_name}.{ref.variable_name}" elif isinstance(ref, PropertyRef): - return f"{args[0].pretty(mod_decls)}.{ref.property_name}" + return f"{args[0].pretty(context)}.{ref.property_name}" else: assert_never(ref) # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default @@ -632,36 +665,85 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str: n_defaults += 1 if n_defaults: args = args[:-n_defaults] - return f"{fn_str}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})" + if mutated_arg_type: + name = context.name_expr(mutated_arg_type, args[0]) + context.statements.append( + f"{fn_str}({', '.join({name}, *(a.pretty(context, wrap_lit=False) for a in args[1:]))})" + ) + return name + return f"{fn_str}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})" + + +@dataclass +class PrettyContext: + mod_decls: ModuleDeclarations + # List of statements of "context" setting variable for the expr + statements: list[str] = field(default_factory=list) + + _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0)) + + def generate_name(self, typ: str) -> str: + self._gen_name_types[typ] += 1 + return f"_{typ}_{self._gen_name_types[typ]}" + + def name_expr(self, expr_type: str, expr: ExprDecl) -> str: + name = self.generate_name(expr_type) + self.statements.append(f"{name} = copy({expr.pretty(self, parens=False)})") + return name + + def render(self, expr: str) -> str: + return "\n".join(self.statements + [expr]) def test_expr_pretty(): - mod_decls = ModuleDeclarations(Declarations()) - assert VarDecl("x").pretty(mod_decls) == "x" - assert LitDecl(42).pretty(mod_decls) == "i64(42)" - assert LitDecl("foo").pretty(mod_decls) == 'String("foo")' - assert LitDecl(None).pretty(mod_decls) == "unit()" + context = PrettyContext(ModuleDeclarations(Declarations())) + assert VarDecl("x").pretty(context) == "x" + assert LitDecl(42).pretty(context) == "i64(42)" + assert LitDecl("foo").pretty(context) == 'String("foo")' + assert LitDecl(None).pretty(context) == "unit()" def v(x: str) -> TypedExprDecl: return TypedExprDecl(JustTypeRef(""), VarDecl(x)) - assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(mod_decls) == "foo(x)" - assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(mod_decls) == "foo(x, y, z)" - assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(mod_decls) == "x + y" - assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(mod_decls) == "x[y]" - assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(mod_decls) == "foo(x, y)" - assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(mod_decls) == "foo.bar(x, y)" - assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(mod_decls) == "x(y)" + assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)" + assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)" + assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y" + assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]" + assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)" + assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)" + assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)" assert ( CallDecl( ClassMethodRef("Map", "__init__"), (), (JustTypeRef("i64"), JustTypeRef("Unit")), - ).pretty(mod_decls) + ).pretty(context) == "Map[i64, Unit]()" ) +def test_setitem_pretty(): + context = PrettyContext(ModuleDeclarations(Declarations())) + + def v(x: str) -> TypedExprDecl: + return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) + + final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context) + assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1" + + +def test_delitem_pretty(): + context = PrettyContext(ModuleDeclarations(Declarations())) + + def v(x: str) -> TypedExprDecl: + return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) + + final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context) + assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1" + + +# TODO: Multiple mutations, + ExprDecl = Union[VarDecl, LitDecl, CallDecl] diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 6edf9143..d4c97f79 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -86,6 +86,8 @@ _BUILTIN_DECLS: Declarations | None = None +ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"} + @dataclass class _BaseModule(ABC): @@ -189,12 +191,14 @@ def _class( default = method.default merge = method.merge on_merge = method.on_merge + mutates_first_arg = method.mutates_self if method.preserve: self._mod_decls.register_preserved_method(cls_name, method_name, fn) continue else: fn = method egg_fn, cost, default, merge, on_merge = None, None, None, None, None + mutates_first_arg = False if isinstance(fn, classmethod): fn = fn.__func__ is_classmethod = True @@ -225,6 +229,7 @@ def _class( cost, merge, on_merge, + mutates_first_arg or method_name in ALWAYS_MUTATES_SELF, "cls" if is_classmethod and not is_init else slf_type_ref, parameters, is_init, @@ -264,6 +269,7 @@ def method( # type: ignore cost: Optional[int] = None, merge: Optional[Callable[[Any, Any], Any]] = None, on_merge: Optional[Callable[[Any, Any], Iterable[ActionLike]]] = None, + mutates_self: bool = False, ) -> Callable[[CALLABLE], CALLABLE]: ... @@ -276,6 +282,7 @@ def method( default: Optional[EXPR] = None, merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None, on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None, + mutates_self: bool = False, ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... @@ -288,8 +295,9 @@ def method( merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None, on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None, preserve: bool = False, + mutates_self: bool = False, ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: - return lambda fn: _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve) + return lambda fn: _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self) @overload def function(self, fn: CALLABLE, /) -> CALLABLE: @@ -303,6 +311,7 @@ def function( # type: ignore cost: Optional[int] = None, merge: Optional[Callable[[Any, Any], Any]] = None, on_merge: Optional[Callable[[Any, Any], Iterable[ActionLike]]] = None, + mutates_first_arg: bool = False, ) -> Callable[[CALLABLE], CALLABLE]: ... @@ -315,6 +324,7 @@ def function( default: Optional[EXPR] = None, merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None, on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None, + mutates_first_arg: bool = False, ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ... @@ -335,6 +345,7 @@ def _function( self, fn: Callable[..., RuntimeExpr], hint_locals: dict[str, Any], + mutates_first_arg: bool = False, egg_fn: Optional[str] = None, cost: Optional[int] = None, default: Optional[RuntimeExpr] = None, @@ -346,7 +357,9 @@ def _function( """ name = fn.__name__ # Save function decleartion - self._register_function(FunctionRef(name), egg_fn, fn, hint_locals, default, cost, merge, on_merge) + self._register_function( + FunctionRef(name), egg_fn, fn, hint_locals, default, cost, merge, on_merge, mutates_first_arg + ) # Return a runtime function which will act like the decleration return RuntimeFunction(self._mod_decls, name) @@ -362,6 +375,7 @@ def _register_function( cost: Optional[int], merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr]], on_merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]]], + mutates_first_arg: bool, # The first arg is either cls, for a classmethod, a self type, or none for a function first_arg: Literal["cls"] | TypeOrVarRef | None = None, cls_typevars: list[TypeVar] = [], @@ -376,13 +390,6 @@ def _register_function( if cls_type_and_name: hint_globals[cls_type_and_name[1]] = cls_type_and_name[0] hints = get_type_hints(fn, hint_globals, hint_locals) - # If this is an init fn use the first arg as the return type - if is_init: - if not isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)): - raise ValueError("Init function must have a self type") - return_type = first_arg - else: - return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name) params = list(signature(fn).parameters.values()) arg_names = tuple(t.name for t in params) @@ -421,6 +428,17 @@ def _register_function( if isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)) and not is_init: arg_types = (first_arg,) + arg_types + # If this is an init fn use the first arg as the return type + if is_init: + assert not mutates_first_arg + if not isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)): + raise ValueError("Init function must have a self type") + return_type = first_arg + elif mutates_first_arg: + return_type = arg_types[0] + else: + return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name) + default_decl = None if default is None else default.__egg_typed_expr__.expr merge_decl = ( None @@ -446,6 +464,7 @@ def _register_function( arg_types=arg_types, arg_names=arg_names, arg_defaults=arg_defaults, + mutates_first_arg=mutates_first_arg, ) self._process_commands( self._mod_decls.register_function_callable( @@ -538,7 +557,9 @@ def relation(self, name: str, /, *tps: type, egg_fn: Optional[str] = None) -> Ca Defines a relation, which is the same as a function which returns unit. """ arg_types = tuple(self._resolve_type_annotation(cast(object, tp), [], None) for tp in tps) - fn_decl = FunctionDecl(arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit")) + fn_decl = FunctionDecl( + arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False + ) commands = self._mod_decls.register_function_callable( FunctionRef(name), fn_decl, egg_fn, cost=None, default=None, merge=None, merge_action=[] ) @@ -811,6 +832,7 @@ class _WrappedMethod(Generic[P, EXPR]): on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] fn: Callable[P, EXPR] preserve: bool + mutates_self: bool def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR: raise NotImplementedError("We should never call a wrapped method. Did you forget to wrap the class?") @@ -965,7 +987,7 @@ class Rule(Command): def __str__(self) -> str: head_str = ", ".join(map(str, self.head)) body_str = ", ".join(map(str, self.body)) - return f"rule({head_str}).then({body_str})" + return f"rule({body_str}).then({head_str})" def _to_egg_command(self) -> bindings.RuleCommand: return bindings.RuleCommand( diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index dba97f9a..a192f165 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -49,7 +49,7 @@ ] -BLACK_MODE = black.Mode() # type: ignore +BLACK_MODE = black.Mode(line_length=180) # type: ignore UNIT_CLASS_NAME = "Unit" UNARY_LIT_CLASS_NAMES = {"i64", "f64", "String"} @@ -116,7 +116,7 @@ class RuntimeClass: __egg_decls__: ModuleDeclarations __egg_name__: str - def __call__(self, *args: object) -> RuntimeExpr: + def __call__(self, *args: object) -> Optional[RuntimeExpr]: """ Create an instance of this kind by calling the __init__ classmethod """ @@ -182,7 +182,7 @@ def __post_init__(self): if len(self.__egg_tp__.args) != desired_args: raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}") - def __call__(self, *args: object) -> RuntimeExpr: + def __call__(self, *args: object) -> Optional[RuntimeExpr]: return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args) def __getattr__(self, name: str) -> RuntimeClassMethod: @@ -215,7 +215,7 @@ def __post_init__(self): self.__egg_fn_ref__ = FunctionRef(self.__egg_name__) self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__) - def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr: + def __call__(self, *args: object, **kwargs: object) -> Optional[RuntimeExpr]: return _call(self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__, args, kwargs) def __str__(self) -> str: @@ -230,7 +230,7 @@ def _call( args: Collection[object], kwargs: dict[str, object], bound_params: Optional[tuple[JustTypeRef, ...]] = None, -) -> RuntimeExpr: +) -> Optional[RuntimeExpr]: # Turn all keyword args into positional args if fn_decl: @@ -238,8 +238,10 @@ def _call( bound.apply_defaults() assert not bound.kwargs args = bound.args + mutates_first_arg = fn_decl.mutates_first_arg else: assert not kwargs + mutates_first_arg = False upcasted_args: list[RuntimeExpr] if fn_decl is not None: upcasted_args = [ @@ -263,6 +265,12 @@ def _call( return_tp = JustTypeRef("Unit") expr_decl = CallDecl(callable_ref, arg_decls, bound_params) + typed_expr_decl = TypedExprDecl(return_tp, expr_decl) + if mutates_first_arg: + first_arg = upcasted_args[0] + first_arg.__egg_typed_expr__ = typed_expr_decl + first_arg.__egg_decls__ = decls + return None return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl)) @@ -282,7 +290,7 @@ def __post_init__(self): except KeyError: raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") - def __call__(self, *args: object, **kwargs) -> RuntimeExpr: + def __call__(self, *args: object, **kwargs) -> Optional[RuntimeExpr]: bound_params = self.__egg_tp__.args if isinstance(self.__egg_tp__, JustTypeRef) else None return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs, bound_params) @@ -298,14 +306,13 @@ def class_name(self) -> str: @dataclass class RuntimeMethod: - __egg_decls__: ModuleDeclarations - __egg_typed_expr__: TypedExprDecl + __egg_self__: RuntimeExpr __egg_method_name__: str __egg_callable_ref__: MethodRef | PropertyRef = field(init=False) __egg_fn_decl__: Optional[FunctionDecl] = field(init=False) def __post_init__(self): - if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties: + if self.__egg_method_name__ in self.__egg_self__.__egg_decls__.get_class_decl(self.class_name).properties: self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__) else: self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__) @@ -315,18 +322,17 @@ def __post_init__(self): self.__egg_fn_decl__ = None else: try: - self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__) + self.__egg_fn_decl__ = self.__egg_self__.__egg_decls__.get_function_decl(self.__egg_callable_ref__) except KeyError: raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") - def __call__(self, *args: object, **kwargs) -> RuntimeExpr: - first_arg = RuntimeExpr(self.__egg_decls__, self.__egg_typed_expr__) - args = (first_arg, *args) - return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs) + def __call__(self, *args: object, **kwargs) -> Optional[RuntimeExpr]: + args = (self.__egg_self__, *args) + return _call(self.__egg_self__.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs) @property def class_name(self) -> str: - return self.__egg_typed_expr__.tp.name + return self.__egg_self__.__egg_typed_expr__.tp.name @dataclass @@ -334,14 +340,14 @@ class RuntimeExpr: __egg_decls__: ModuleDeclarations __egg_typed_expr__: TypedExprDecl - def __getattr__(self, name: str) -> RuntimeMethod | RuntimeExpr | Callable: + def __getattr__(self, name: str) -> RuntimeMethod | RuntimeExpr | Callable | None: class_decl = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name) preserved_methods = class_decl.preserved_methods if name in preserved_methods: return preserved_methods[name].__get__(self) - method = RuntimeMethod(self.__egg_decls__, self.__egg_typed_expr__, name) + method = RuntimeMethod(self, name) if isinstance(method.__egg_callable_ref__, PropertyRef): return method() return method @@ -353,13 +359,16 @@ def __repr__(self) -> str: return str(self) def __str__(self) -> str: - pretty_expr = self.__egg_typed_expr__.expr.pretty(self.__egg_decls__, parens=False) + context = PrettyContext(self.__egg_decls__) + pretty_expr = self.__egg_typed_expr__.expr.pretty(context, parens=False) try: if config.SHOW_TYPES: - s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}" - return black.format_str(s, mode=black.FileMode()).strip() + raise NotImplementedError() + # s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}" + # return black.format_str(s, mode=black.FileMode()).strip() else: - return black.format_str(pretty_expr, mode=black.FileMode(line_length=180)).strip() + pretty_statements = context.render(pretty_expr) + return black.format_str(pretty_statements, mode=BLACK_MODE).strip() except black.parsing.InvalidInput: return pretty_expr @@ -378,22 +387,31 @@ def __eq__(self, other: NoReturn) -> Expr: # type: ignore "Do not use == on RuntimeExpr. Compare the __egg_typed_expr__ attribute instead for structural equality." ) + # Implement these so that copy() works on this object + # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion + + def __getstate__(self): + return (self.__egg_decls__, self.__egg_typed_expr__) + + def __setstate__(self, d): + self.__egg_decls__, self.__egg_typed_expr__ = d + # Define each of the special methods, since we have already declared them for pretty printing -for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__"]: +for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]: - def _special_method(self: RuntimeExpr, *args: object, __name: str = name) -> RuntimeExpr: + def _special_method(self: RuntimeExpr, *args: object, __name: str = name) -> Optional[RuntimeExpr]: # First, try to resolve as preserved method try: method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name] except KeyError: - return RuntimeMethod(self.__egg_decls__, self.__egg_typed_expr__, __name)(*args) + return RuntimeMethod(self, __name)(*args) else: return method(self, *args) setattr(RuntimeExpr, name, _special_method) -for name in ["__bool__", "__len__", "__complex__", "__int__", "__float__", "__hash__", "__iter__"]: +for name in ["__bool__", "__len__", "__complex__", "__int__", "__float__", "__hash__", "__iter__", "__index__"]: def _preserved_method(self: RuntimeExpr, __name: str = name): try: @@ -414,7 +432,7 @@ def _resolve_callable(callable: object) -> CallableRef: if isinstance(callable, RuntimeClassMethod): return ClassMethodRef(callable.class_name, callable.__egg_method_name__) if isinstance(callable, RuntimeMethod): - return MethodRef(callable.__egg_typed_expr__.tp.name, callable.__egg_method_name__) + return MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__) if isinstance(callable, RuntimeClass): return ClassMethodRef(callable.__egg_name__, "__init__") raise NotImplementedError(f"Cannot turn {callable} into a callable ref") diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 335e178b..a3facb54 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -2,10 +2,18 @@ import importlib import pathlib +from copy import copy from typing import ClassVar import pytest from egglog import * +from egglog.declarations import ( + CallDecl, + FunctionRef, + JustTypeRef, + MethodRef, + TypedExprDecl, +) EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py")) @@ -331,3 +339,57 @@ def test_f64_negation() -> None: def test_not_equals(): egraph = EGraph() egraph.check(i64(10) != i64(2)) + + +class TestMutate: + def test_setitem_defaults(self): + egraph = EGraph() + + @egraph.class_ + class Foo(Expr): + def __init__(self) -> None: + ... + + def __setitem__(self, key: i64Like, value: i64Like) -> None: + ... + + foo = Foo() + foo[10] = 20 + assert str(foo) == "_Foo_1 = copy(Foo())\n_Foo_1[10] = 20\n_Foo_1" + assert expr_parts(foo) == TypedExprDecl( + JustTypeRef("Foo"), + CallDecl(MethodRef("Foo", "__setitem__"), (expr_parts(Foo()), expr_parts(i64(10)), expr_parts(i64(20)))), + ) + + def test_function(self): + egraph = EGraph() + + @egraph.class_ + class Math(Expr): + def __init__(self, i: i64Like) -> None: + ... + + def __add__(self, other: Math) -> Math: # type: ignore[empty-body] + ... + + @egraph.function(mutates_first_arg=True) + def incr(x: Math) -> None: + ... + + x = Math(i64(10)) + x_copied = copy(x) + incr(x) + assert expr_parts(x_copied) == expr_parts(Math(i64(10))) + assert expr_parts(x) == TypedExprDecl( + JustTypeRef("Math"), + CallDecl(FunctionRef("incr"), (expr_parts(x_copied),)), + ) + assert str(x) == "_Math_1 = copy(Math(10))\nincr(_Math_1)\n_Math_1" + assert str(x + Math(10)) == "_Math_1 = copy(Math(10))\nincr(_Math_1)\n_Math_1 + Math(10)" + + i, j = vars_("i j", Math) + incr_i = copy(i) + incr(incr_i) + egraph.register(rewrite(incr_i).to(i + Math(1)), x) + egraph.run(10) + egraph.check(eq(x).to(Math(10) + Math(1))) diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index 225757f7..10263320 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -32,13 +32,14 @@ def test_function_call(): (), (), TypeRefWithVars("i64"), + False, ), }, ) ) one = RuntimeFunction(decls, "one") assert ( - one().__egg_typed_expr__ + one().__egg_typed_expr__ # type: ignore == RuntimeExpr(decls, TypedExprDecl(JustTypeRef("i64"), CallDecl(FunctionRef("one")))).__egg_typed_expr__ ) @@ -60,6 +61,7 @@ def test_classmethod_call(): (), (), TypeRefWithVars("Map", (K, V)), + False, ) }, ), @@ -72,7 +74,7 @@ def test_classmethod_call(): i64 = RuntimeClass(decls, "i64") unit = RuntimeClass(decls, "unit") assert ( - Map[i64, unit].create().__egg_typed_expr__ + Map[i64, unit].create().__egg_typed_expr__ # type: ignore == RuntimeExpr( decls, TypedExprDecl( @@ -98,6 +100,7 @@ def test_expr_special(): (), (None, None), TypeRefWithVars("i64"), + False, ) }, class_methods={ @@ -106,6 +109,7 @@ def test_expr_special(): (), (None,), TypeRefWithVars("i64"), + False, ) }, ),