Skip to content

Commit

Permalink
def_fun_prop and def_fun_trans tests and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Oct 3, 2024
1 parent 883f600 commit 4283759
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 351 deletions.
86 changes: 53 additions & 33 deletions SciLean/Meta/GenerateFunProp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ namespace FunProp

syntax Parser.suffix := "add_suffix" ident
syntax Parser.trans := "with_transitive"
syntax Parser.config := Parser.suffix <|> Parser.trans
syntax Parser.argSubsets := "arg_subsets"
syntax Parser.config := Parser.suffix <|> Parser.trans <|> Parser.argSubsets

syntax Parser.funPropProof := "by" tacticSeq

Expand All @@ -197,6 +198,7 @@ open Lean
structure DefFunPropConfig where
withTransitive := false
suffix : Option Name := none
argSubsets := false

open Lean Syntax Elab
def parseDefFunPropConfig (stx : TSyntaxArray ``Parser.config) : MetaM DefFunPropConfig := do
Expand All @@ -211,6 +213,7 @@ def parseDefFunPropConfig (stx : TSyntaxArray ``Parser.config) : MetaM DefFunPro
throwErrorAt s.raw s!"suffix already specified as `{cfg.suffix.get!}`"
pure {cfg with suffix := id.getId}
| `(Parser.trans| with_transitive) => pure {cfg with withTransitive := true}
| `(Parser.argSubsets| arg_subsets) => pure {cfg with argSubsets := true}
| _ => throwErrorAt s.raw "invalid option {s}"

return cfg
Expand All @@ -228,44 +231,17 @@ def parseFunPropTactic (fId : Name) (stx : Option (TSyntax ``Parser.funPropProof



/-- Define function property for a function in particular arguments.
Example:
```
def foo (x y z : ℝ) := x*x+y*z
def_fun_prop foo in x y z : Continuous
```
Proves continuity of `foo` in `x`, `y` and `z` as theorem `foo.arg_xyz.Continuous_rule`.
You can add additional assumptions, custom tactic to prove the property as demonstrated by the
following example:
```
def_fun_prop bar in x y
add_suffix _simple
with_transitive
(xy : R×R) (h : xy.2 ≠ 0) : (DifferentiableAt R · xy) by unfold bar; fun_prop (disch:=assumption)
```
where
- `add_suffix _simple` adds `_simple` to the end of the generated theorems
- `with_transitive` also generates all theorems that can be infered from the current theorem.
For example, `DifferentiableAt` implies `ContinuousAt`. All `fun_prop` transition theorems
are tried to infer additional function properties.
- `(xy : R×R) (h : xy.2 ≠ 0)` are additional assumptions added to the theorem. These assumptions
are stated in the context of the function so for example here we can use `R` without introducing it.
- `by unfold bar; fun_prop ...` you can specify custom tactic to prove the function property.
-/
elab "def_fun_prop " f:ident "in" args:ident* ppLine
cfg:Parser.config*
bs:bracketedBinder* " : " fprop:term proofTactic:(Parser.funPropProof)? : command => do

open Lean Meta Elab Term in
def defFunProp (f : Ident) (args : TSyntaxArray `ident)
(cfg : TSyntaxArray ``Parser.config) (bs : TSyntaxArray ``Parser.Term.bracketedBinder)
(fprop : TSyntax `term) (proof : Option (TSyntax ``Parser.funPropProof)) : Command.CommandElabM Unit := do
Elab.Command.liftTermElabM <| do
-- resolve function name
let fId ← resolveUniqueNamespace f
let info ← getConstInfo fId

let cfg ← parseDefFunPropConfig cfg
let tac ← parseFunPropTactic fId proofTactic
let tac ← parseFunPropTactic fId proof

forallTelescope info.type fun xs _ => do
Elab.Term.elabBinders bs.raw fun ctx₂ => do
Expand Down Expand Up @@ -309,3 +285,47 @@ elab "def_fun_prop " f:ident "in" args:ident* ppLine
defineTransitiveFunProp proof ctx cfg.suffix

pure ()



/-- Define function property for a function in particular arguments.
Example:
```
def foo (x y z : ℝ) := x*x+y*z
def_fun_prop foo in x y z : Continuous
```
Proves continuity of `foo` in `x`, `y` and `z` as theorem `foo.arg_xyz.Continuous_rule`.
You can add additional assumptions, custom tactic to prove the property as demonstrated by the
following example:
```
def_fun_prop bar in x y
add_suffix _simple
with_transitive
(xy : R×R) (h : xy.2 ≠ 0) : (DifferentiableAt R · xy) by unfold bar; fun_prop (disch:=assumption)
```
where
- `add_suffix _simple` adds `_simple` to the end of the generated theorems
- `with_transitive` also generates all theorems that can be infered from the current theorem.
For example, `DifferentiableAt` implies `ContinuousAt`. All `fun_prop` transition theorems
are tried to infer additional function properties.
- `(xy : R×R) (h : xy.2 ≠ 0)` are additional assumptions added to the theorem. These assumptions
are stated in the context of the function so for example here we can use `R` without introducing it.
- `by unfold bar; fun_prop ...` you can specify custom tactic to prove the function property.
-/
elab "def_fun_prop " f:ident "in" args:ident* ppLine
cfg:Parser.config*
bs:bracketedBinder* " : " fprop:term proof:(Parser.funPropProof)? : command => do

let c ← Lean.Elab.Command.liftTermElabM <| parseDefFunPropConfig cfg

defFunProp f args cfg bs fprop proof

-- generate the same with all argument subsets
if c.argSubsets then
for as in args.allSubsets do
if as.size = 0 || as.size = args.size then
continue
defFunProp f as cfg bs fprop proof
17 changes: 0 additions & 17 deletions SciLean/Meta/GenerateFunTrans.lean
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,6 @@ def generateFunTransDefAndTheorem (statement proof : Expr) (ctx : Array Expr)
levelParams := lvls.toList
}

IO.println (← ppExpr thmType)
IO.println (← ppExpr thmVal)

addDecl (.thmDecl thmDecl)
FunTrans.addTheorem thmName

Expand Down Expand Up @@ -278,20 +275,6 @@ def defFunTrans (f : Ident) (args : TSyntaxArray `ident)
pure ()


def _root_.Array.allSubsets {α} (a : Array α) : Array (Array α) := Id.run do
let mut as : Array (Array α) := #[]
let n := a.size
for i in [0:2^n] do
let mut ai : Array α := #[]
for h : j in [0:a.size] do
if (2^j).toUInt64 &&& i.toUInt64 ≠ 0 then
ai := ai.push (a[j])

as := as.push ai
return as



open Lean Meta Elab Term in
/-- Define function transformation for a function in particular arguments.
Expand Down
178 changes: 178 additions & 0 deletions test/def_fun_prop_trans.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import SciLean

open SciLean

variable
{K : Type} [RCLike K]
{X : Type} [NormedAddCommGroup X] [AdjointSpace K X]
{Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y]
{Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z]

set_default_scalar K

namespace DefFunPropTransTest

def mul (x y : K) : K := x * y


example : Differentiable K (fun xy : K×K => xy.1 * xy.2) := by fun_prop


def_fun_prop mul in x y
with_transitive
arg_subsets
: Differentiable K

def_fun_trans mul in x y
arg_subsets
: revFDeriv K


/--
info: DefFunPropTransTest.mul.arg_xy.Continuous_rule {K : Type} [RCLike K] : Continuous fun xy => mul xy.1 xy.2
-/
#guard_msgs in
#check mul.arg_xy.Continuous_rule

/--
info: DefFunPropTransTest.mul.arg_xy.Differentiable_rule {K : Type} [RCLike K] : Differentiable K fun xy => mul xy.1 xy.2
-/
#guard_msgs in
#check mul.arg_xy.Differentiable_rule


/--
info: DefFunPropTransTest.mul.arg_x.Differentiable_rule {K : Type} [RCLike K] (y : K) : Differentiable K fun x => mul x y
-/
#guard_msgs in
#check mul.arg_x.Differentiable_rule

/--
info: DefFunPropTransTest.mul.arg_y.Differentiable_rule {K : Type} [RCLike K] (x : K) : Differentiable K fun y => mul x y
-/
#guard_msgs in
#check mul.arg_y.Differentiable_rule


/--
info: DefFunPropTransTest.mul.arg_xy.revFDeriv_rule {K : Type} [RCLike K] : <∂ xy, mul xy.1 xy.2 = mul.arg_xy.revFDeriv
-/
#guard_msgs in
#check mul.arg_xy.revFDeriv_rule

/--
info: DefFunPropTransTest.mul.arg_xy.revFDeriv {K : Type} [RCLike K] (x : K × K) : K × (K → K × K)
-/
#guard_msgs in
#check mul.arg_xy.revFDeriv


example :
(revFDeriv K fun (x : K) => mul x x)
=
(fun x =>
let zdf := mul.arg_xy.revFDeriv (x, x);
(zdf.1, fun dz =>
let dy := zdf.2 dz;
dy.1 + dy.2)) := by autodiff; simp


def add (x y : X) : X := x + y

def_fun_prop add in x y
with_transitive
arg_subsets
{K} [RCLike K] [NormedSpace K X] : Differentiable K

def_fun_trans add in x y
arg_subsets
{K : Type} [RCLike K] [AdjointSpace K X] [CompleteSpace X] : revFDeriv K


/--
info: DefFunPropTransTest.add.arg_xy.Differentiable_rule.{u_1} {X : Type} [NormedAddCommGroup X] {K : Type u_1} [RCLike K]
[NormedSpace K X] : Differentiable K fun xy => add xy.1 xy.2
-/
#guard_msgs in
#check add.arg_xy.Differentiable_rule

/--
info: DefFunPropTransTest.add.arg_xy.revFDeriv_rule {X : Type} [NormedAddCommGroup X] {K : Type} [RCLike K] [AdjointSpace K X]
[CompleteSpace X] : <∂ xy, add xy.1 xy.2 = add.arg_xy.revFDeriv
-/
#guard_msgs in
#check add.arg_xy.revFDeriv_rule


def smul {X : Type} [SemiHilbert K X]
(x : K) (y : X) : X := x • y


example :
(revFDeriv K fun (x : K) =>
let x1 := mul x x
let x2 := mul x1 x1
let x3 := mul x2 x2
let x4 := mul x3 x3
let x5 := mul x4 x4
x5)
=
fun x =>
let zdf := mul.arg_xy.revFDeriv (x, x);
let ydg := zdf.1;
let zdf_1 := mul.arg_xy.revFDeriv (ydg, ydg);
let ydg := zdf_1.1;
let zdf_2 := mul.arg_xy.revFDeriv (ydg, ydg);
let ydg := zdf_2.1;
let zdf_3 := mul.arg_xy.revFDeriv (ydg, ydg);
let ydg := zdf_3.1;
let zdf_4 := mul.arg_xy.revFDeriv (ydg, ydg);
let ydg := zdf_4.1;
(ydg, fun dz =>
let dy := zdf_4.2 dz;
let dy := dy.1 + dy.2;
let dy := zdf_3.2 dy;
let dy := dy.1 + dy.2;
let dy := zdf_2.2 dy;
let dy := dy.1 + dy.2;
let dy := zdf_1.2 dy;
let dy := dy.1 + dy.2;
let dy := zdf.2 dy;
dy.1 + dy.2) :=
by
conv => lhs; autodiff

example :
(revFDeriv K fun (x : K) =>
let x1 := mul x x
let x2 := mul x1 (mul x x)
let x3 := mul x2 (mul x1 x)
x3)
=
fun x =>
let zdf := mul.arg_xy.revFDeriv (x, x);
let ydg := zdf.1;
let zdf_1 := mul.arg_xy.revFDeriv (x, x);
let zdf_2 := zdf_1.1;
let zdf_3 := mul.arg_xy.revFDeriv (ydg, zdf_2);
let ydg_1 := zdf_3.1;
let zdf_4 := mul.arg_xy.revFDeriv (ydg, x);
let zdf_5 := zdf_4.1;
let zdf_6 := mul.arg_xy.revFDeriv (ydg_1, zdf_5);
let ydg := zdf_6.1;
(ydg, fun dz =>
let dy := zdf_6.2 dz;
let dy_1 := zdf_4.2 dy.2;
let dxdy := dy_1.2;
let dxdy_1 := dy_1.1;
let dxdy_2 := dy.1;
let dy := zdf_3.2 dxdy_2;
let dy_2 := zdf_1.2 dy.2;
let dx := dy_2.1 + dy_2.2;
let dx_1 := dy.1;
let dxdy := dxdy + dx;
let dxdy_3 := dxdy_1 + dx_1;
let dy := zdf.2 dxdy_3;
let dx := dy.1 + dy.2;
dxdy + dx) := by
conv => lhs; autodiff
Loading

0 comments on commit 4283759

Please sign in to comment.