-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #125 from JakeGinesin/main
Coq and Lean
- Loading branch information
Showing
5 changed files
with
327 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import re | ||
import ast | ||
from typing import List | ||
|
||
# Matches a newline followed by any run of whitespace (i.e. the indentation at
# the start of a docstring line). Written as a raw string so that `\s` reaches
# the regex engine instead of being an invalid Python string escape.
docstring_linestart_re = re.compile(r"\n(\s+)*")
|
||
class Translator: | ||
'''Translate Python to Coq | ||
''' | ||
|
||
stop = ["\nFixpoint", "\nDefinition", "\nExample", "\nProof", "\nAxiom", "\n(*"] | ||
|
||
def translate_prompt(self, name: str, args: List[ast.arg], returns: ast.expr, description: str) -> str: | ||
coq_description = "(*\n" + description + "\n*)\n" | ||
self.fn_name = name | ||
self.type = [[arg.annotation for arg in args], returns] | ||
|
||
prefix = "Require Import ZArith.\nRequire Import Reals.\nRequire Import Coq.Strings.String.\nOpen Scope string_scope.\n\n" | ||
|
||
def translate_arg(arg): | ||
ty = " : " + self.pytype_to_coqtype(arg.annotation) | ||
return "(" + arg.arg + ty + ")" | ||
|
||
coq_args = " ".join(map(str,[translate_arg(arg) for arg in args])) | ||
coq_ret = self.pytype_to_coqtype(returns) | ||
|
||
return f"{prefix}{coq_description}Definition {name} {coq_args} : {coq_ret} :=\n" | ||
|
||
def pytype_to_coqtype(self, ann: ast.expr | None) -> str: | ||
""" | ||
Traverses AST and translates Python type annotation to Lean type annotation | ||
""" | ||
|
||
if ann == None : raise Exception(f"No annotation") | ||
|
||
match ann: | ||
# Subscripts | ||
case ast.Subscript(ast.Name(id), slice, ctx): | ||
match id: | ||
case "List": | ||
return "list " + self.pytype_to_coqtype(slice) | ||
case "Union": | ||
raise Exception("Coq has no support for untagged unions.") | ||
case "Tuple": | ||
match slice: | ||
case ast.Tuple(elts, _ctx): | ||
tys = [self.pytype_to_coqtype(elem) for elem in elts] | ||
return f"({' × '.join(tys)})" | ||
case other: | ||
raise Exception(f"Bad tuple: {slice}") | ||
case "Dict": | ||
raise Exception("Coq has no support for dictionaries. Yikes.") | ||
case "Optional": | ||
return "option " + self.pytype_to_coqtype(slice) | ||
case other: | ||
raise Exception(f"Bad generic {other}") | ||
|
||
# Literals | ||
case ast.Name(id="str") | "str": | ||
return "string" | ||
case ast.Name(id="int") | "int": | ||
return "Z" | ||
case ast.Name(id="float") | "float": | ||
return "R" | ||
case ast.Name(id="bool") | "bool": | ||
return "bool" | ||
case ast.Name(id="None") | "None": | ||
raise Exception("Coq does not have type None") | ||
|
||
# Misc | ||
case None: | ||
raise Exception("Implicitly untyped argument None") | ||
case ast.Name("Any"): | ||
raise Exception("Coq does not have Any") | ||
case ast.Name(x): | ||
raise Exception(f"Unknown name {x}") | ||
case ast.Constant(Ellipsis): | ||
raise Exception("No ellipsis!") | ||
case _other: | ||
raise Exception(f"Unknown annotation: {ann}") | ||
|
||
def test_suite_prefix_lines(self, _): | ||
return [] | ||
|
||
def test_suite_suffix_lines(self): | ||
return [] | ||
|
||
def file_ext(self): | ||
return "v" | ||
|
||
def __init__(self): | ||
self.type = None | ||
self.testnum = 0 | ||
|
||
def deep_equality(self, left, right): | ||
""" | ||
we use the reflexivity tactic to ensure type equivalence. | ||
Similar to Lean. See humaneval_to_lean for more. | ||
""" | ||
self.testnum+=1 | ||
return f"Example t{self.testnum} : {left} = {right}. Proof. reflexivity. Qed." | ||
|
||
def gen_var(self, name: str): | ||
return name | ||
|
||
# TODO: find some better alternative to define arbitrarily typed list in single line | ||
def gen_list(self, elements: List[str]): | ||
return "(cons " + " (cons ".join(map(str,elements)) + " nil" + ")"*(len(elements) + 1) | ||
|
||
def gen_tuple(self, elements: List[str]): | ||
return f"({' × '.join(elements)})" | ||
|
||
def gen_literal(self, c: bool | str | int | float | None): | ||
# switch doesn't work here for some reason xd idk python well enough | ||
if type(c) == bool: | ||
return str(c) | ||
if type(c) == str: | ||
return f'"{c}"' | ||
if type(c) == int: | ||
return f"{c}%Z" | ||
if type(c) == float: | ||
return f"{c}%R" | ||
|
||
return "Nothing" | ||
|
||
def gen_call(self, func: str, args: List[str]): | ||
#args = [coerce(arg, self.type[0][i]) for i, arg in enumerate(args)] | ||
return self.fn_name + " " + " ".join(map(str,args)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import re | ||
import ast | ||
from typing import List | ||
from types import NoneType | ||
|
||
# Matches a newline followed by any run of whitespace (i.e. the indentation at
# the start of a docstring line). Written as a raw string so that `\s` reaches
# the regex engine instead of being an invalid Python string escape.
docstring_linestart_re = re.compile(r"\n(\s+)*")
|
||
class Translator: | ||
'''Translate Python to Lean. | ||
''' | ||
|
||
stop = ["\n//", "\n/*", "\nfunction", "\nmethod", "\ntheorem", "\nlemma", "\nprotected def", "\ndef"] | ||
|
||
def translate_prompt(self, name: str, args: List[ast.arg], returns: ast.expr, description: str) -> str: | ||
lean_description = "/-\n" + description + "\n-/\n" | ||
self.fn_name = name | ||
self.type = [[arg.annotation for arg in args], returns] | ||
|
||
def translate_arg(arg): | ||
ty = " : " + self.pytype_to_leantype(arg.annotation) | ||
return "(" + arg.arg + ty + ")" | ||
|
||
lean_args = " ".join(map(str,[translate_arg(arg) for arg in args])) | ||
lean_ret = self.pytype_to_leantype(returns) | ||
|
||
return f"{lean_description}def {name} {lean_args} : {lean_ret} :=\n" | ||
|
||
def pytype_to_leantype(self, ann: ast.expr | None) -> str: | ||
""" | ||
Traverses AST and translates Python type annotation to Lean type annotation | ||
""" | ||
|
||
if ann == None : raise Exception(f"No annotation") | ||
|
||
match ann: | ||
# Subscripts | ||
case ast.Subscript(ast.Name(id), slice, ctx): | ||
match id: | ||
case "List": | ||
return "list " + self.pytype_to_leantype(slice) | ||
case "Union": | ||
raise Exception("Lean has no support for untagged unions.") | ||
case "Tuple": | ||
match slice: | ||
case ast.Tuple(elts, _ctx): | ||
tys = [self.pytype_to_leantype(elem) for elem in elts] | ||
return f"({' × '.join(tys)})" | ||
case other: | ||
raise Exception(f"Bad tuple: {slice}") | ||
case "Dict": | ||
raise Exception("Lean has no support for dictionaries. Yikes.") | ||
case "Optional": | ||
return "Option " + self.pytype_to_leantype(slice) | ||
case other: | ||
raise Exception(f"Bad generic {other}") | ||
|
||
# Literals | ||
case ast.Name(id="str") | "str": | ||
return "String" | ||
case ast.Name(id="int") | "int": | ||
return "Int" | ||
case ast.Name(id="float") | "float": | ||
raise Exception("Lean does not have inherent support for Reals (you need mathlib).") | ||
case ast.Name(id="bool") | "bool": | ||
return "Bool" | ||
case ast.Name(id="None") | "None": | ||
raise Exception("Lean does not have type None") | ||
|
||
# Misc | ||
case None: | ||
raise Exception("Implicitly untyped argument None") | ||
case ast.Name("Any"): | ||
raise Exception("Lean does not have Any") | ||
case ast.Name(x): | ||
raise Exception(f"Unknown name {x}") | ||
case ast.Constant(Ellipsis): | ||
raise Exception("No ellipsis!") | ||
case _other: | ||
raise Exception(f"Unknown annotation: {ann}") | ||
|
||
def __init__(self): | ||
self.type = None | ||
|
||
def file_ext(self): | ||
return "lean" | ||
|
||
def deep_equality(self, left, right): | ||
""" | ||
In Lean we acheive the notion of "deeper equality" via the reflexivity tactic | ||
and the type equivalence-checking "=" rather than programmically via "==" | ||
if we *were* to do it programmically, you'd have something like: | ||
def main : IO Unit := | ||
if [assertion here] then pure () else throw (IO.userError "assertion error") | ||
and, we'd run | ||
$ lean --run file.lean | ||
(also note, the "pure ()" business is via the IO monad in Lean) | ||
""" | ||
return f"example: {left} = {right} := by rfl" | ||
|
||
def test_suite_prefix_lines(self, _): | ||
return [] | ||
|
||
def test_suite_suffix_lines(self): | ||
return [] | ||
|
||
def gen_list(self, elements: List[str]): | ||
return f"[{', '.join(elements)}]" | ||
|
||
def gen_tuple(self, elements: List[str]): | ||
return f"({' × '.join(elements)})" | ||
|
||
# TODO: I think this is fine for the most part... need to verify | ||
def gen_literal(self, c: bool | str | int | float | None): | ||
return repr(c) | ||
|
||
def gen_var(self, name: str): | ||
return name | ||
|
||
def gen_call(self, func: str, args: List[str]): | ||
# args = [coerce(arg, self.type[0][i]) for i, arg in enumerate(args)] | ||
return self.fn_name + " " + " ".join(map(str,args)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from pathlib import Path | ||
from safe_subprocess import run | ||
import subprocess | ||
|
||
# return codes for coqc: | ||
# 0: compilation goes through | ||
# 1: some sort of error (nondescript) | ||
|
||
def eval_script(path: Path):
    """Compile a Coq file with ``coqc`` and classify the outcome.

    Args:
        path: Path of the ``.v`` file to check.

    Returns:
        A dict with ``status`` ("OK", "AssertionError", "SyntaxError" or
        "Timeout"), ``exit_code`` (-1 on timeout) and the decoded ``stdout``
        and ``stderr`` of the compiler.
    """
    cleanup_extensions = ['.vo', '.vok', '.vos']

    try:
        # sadly there seems to be no way to verify proofs in a coq file without compiling
        # Every argv entry is its own list element: "coqc -noglob" as a single
        # string would be looked up as the name of the executable.
        output = subprocess.run(["coqc", "-noglob", str(path)], capture_output=True, timeout=5)
        returncode = output.returncode

        if returncode == 0:
            status = "OK"
        # Search the compiler output itself rather than the repr of the
        # CompletedProcess, which also contains the command line and path.
        elif b"Unable to unify" in output.stdout + output.stderr:
            status = "AssertionError"
        else:
            status = "SyntaxError"

    except subprocess.TimeoutExpired as exc:
        status = "Timeout"
        output = exc
        returncode = -1
    finally:
        # cleanup: remove files generated by coqc, whatever the outcome
        for ext in cleanup_extensions:
            path.with_suffix(ext).unlink(missing_ok=True)

    return {
        "status": status,
        "exit_code": returncode,
        "stdout": "" if output.stdout is None else output.stdout.decode("utf-8"),
        "stderr": "" if output.stderr is None else output.stderr.decode("utf-8"),
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from pathlib import Path | ||
from safe_subprocess import run | ||
import subprocess | ||
|
||
def eval_script(path: Path):
    """Check a Lean file with ``lean`` and classify the outcome.

    Args:
        path: Path of the ``.lean`` file to check.

    Returns:
        A dict with ``status`` ("OK", "AssertionError", "SyntaxError" or
        "Timeout"), ``exit_code`` (-1 on timeout) and the decoded ``stdout``
        and ``stderr`` of the checker.
    """
    # The exit code is not used for classification; the outcome is derived
    # from the diagnostics that lean prints.
    try:
        output = subprocess.run(["lean", str(path)], capture_output=True, timeout=5)
        # Inspect the real output: str() of a CompletedProcess is never empty,
        # so comparing it to "" could never report success.
        outmessage = (output.stdout + output.stderr).decode("utf-8", errors="replace")

        if "error: tactic 'rfl' failed" in outmessage:
            status = "AssertionError"
        elif outmessage.strip() == "":
            # No diagnostics at all means every example was accepted.
            status = "OK"
        else:
            status = "SyntaxError"
        returncode = output.returncode

    except subprocess.TimeoutExpired as exc:
        status = "Timeout"
        output = exc
        returncode = -1
    return {
        "status": status,
        "exit_code": returncode,
        "stdout": "" if output.stdout is None else output.stdout.decode("utf-8"),
        "stderr": "" if output.stderr is None else output.stderr.decode("utf-8"),
    }