Skip to content

Commit

Permalink
Merge pull request #125 from JakeGinesin/main
Browse files Browse the repository at this point in the history
Coq and Lean
  • Loading branch information
cassanof authored Feb 18, 2024
2 parents ab441c1 + f4ef7e4 commit 773d609
Show file tree
Hide file tree
Showing 5 changed files with 327 additions and 0 deletions.
128 changes: 128 additions & 0 deletions dataset_builder/humaneval_to_coq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import re
import ast
from typing import List

docstring_linestart_re = re.compile("""\n(\s+)*""")

class Translator:
'''Translate Python to Coq
'''

stop = ["\nFixpoint", "\nDefinition", "\nExample", "\nProof", "\nAxiom", "\n(*"]

def translate_prompt(self, name: str, args: List[ast.arg], returns: ast.expr, description: str) -> str:
coq_description = "(*\n" + description + "\n*)\n"
self.fn_name = name
self.type = [[arg.annotation for arg in args], returns]

prefix = "Require Import ZArith.\nRequire Import Reals.\nRequire Import Coq.Strings.String.\nOpen Scope string_scope.\n\n"

def translate_arg(arg):
ty = " : " + self.pytype_to_coqtype(arg.annotation)
return "(" + arg.arg + ty + ")"

coq_args = " ".join(map(str,[translate_arg(arg) for arg in args]))
coq_ret = self.pytype_to_coqtype(returns)

return f"{prefix}{coq_description}Definition {name} {coq_args} : {coq_ret} :=\n"

def pytype_to_coqtype(self, ann: ast.expr | None) -> str:
"""
Traverses AST and translates Python type annotation to Lean type annotation
"""

if ann == None : raise Exception(f"No annotation")

match ann:
# Subscripts
case ast.Subscript(ast.Name(id), slice, ctx):
match id:
case "List":
return "list " + self.pytype_to_coqtype(slice)
case "Union":
raise Exception("Coq has no support for untagged unions.")
case "Tuple":
match slice:
case ast.Tuple(elts, _ctx):
tys = [self.pytype_to_coqtype(elem) for elem in elts]
return f"({' × '.join(tys)})"
case other:
raise Exception(f"Bad tuple: {slice}")
case "Dict":
raise Exception("Coq has no support for dictionaries. Yikes.")
case "Optional":
return "option " + self.pytype_to_coqtype(slice)
case other:
raise Exception(f"Bad generic {other}")

# Literals
case ast.Name(id="str") | "str":
return "string"
case ast.Name(id="int") | "int":
return "Z"
case ast.Name(id="float") | "float":
return "R"
case ast.Name(id="bool") | "bool":
return "bool"
case ast.Name(id="None") | "None":
raise Exception("Coq does not have type None")

# Misc
case None:
raise Exception("Implicitly untyped argument None")
case ast.Name("Any"):
raise Exception("Coq does not have Any")
case ast.Name(x):
raise Exception(f"Unknown name {x}")
case ast.Constant(Ellipsis):
raise Exception("No ellipsis!")
case _other:
raise Exception(f"Unknown annotation: {ann}")

def test_suite_prefix_lines(self, _):
return []

def test_suite_suffix_lines(self):
return []

def file_ext(self):
return "v"

def __init__(self):
self.type = None
self.testnum = 0

def deep_equality(self, left, right):
"""
we use the reflexivity tactic to ensure type equivalence.
Similar to Lean. See humaneval_to_lean for more.
"""
self.testnum+=1
return f"Example t{self.testnum} : {left} = {right}. Proof. reflexivity. Qed."

def gen_var(self, name: str):
return name

# TODO: find some better alternative to define arbitrarily typed list in single line
def gen_list(self, elements: List[str]):
return "(cons " + " (cons ".join(map(str,elements)) + " nil" + ")"*(len(elements) + 1)

def gen_tuple(self, elements: List[str]):
return f"({' × '.join(elements)})"

def gen_literal(self, c: bool | str | int | float | None):
# switch doesn't work here for some reason xd idk python well enough
if type(c) == bool:
return str(c)
if type(c) == str:
return f'"{c}"'
if type(c) == int:
return f"{c}%Z"
if type(c) == float:
return f"{c}%R"

return "Nothing"

def gen_call(self, func: str, args: List[str]):
#args = [coerce(arg, self.type[0][i]) for i, arg in enumerate(args)]
return self.fn_name + " " + " ".join(map(str,args))
126 changes: 126 additions & 0 deletions dataset_builder/humaneval_to_lean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import re
import ast
from typing import List
from types import NoneType

docstring_linestart_re = re.compile("""\n(\s+)*""")

class Translator:
'''Translate Python to Lean.
'''

stop = ["\n//", "\n/*", "\nfunction", "\nmethod", "\ntheorem", "\nlemma", "\nprotected def", "\ndef"]

def translate_prompt(self, name: str, args: List[ast.arg], returns: ast.expr, description: str) -> str:
lean_description = "/-\n" + description + "\n-/\n"
self.fn_name = name
self.type = [[arg.annotation for arg in args], returns]

def translate_arg(arg):
ty = " : " + self.pytype_to_leantype(arg.annotation)
return "(" + arg.arg + ty + ")"

lean_args = " ".join(map(str,[translate_arg(arg) for arg in args]))
lean_ret = self.pytype_to_leantype(returns)

return f"{lean_description}def {name} {lean_args} : {lean_ret} :=\n"

def pytype_to_leantype(self, ann: ast.expr | None) -> str:
"""
Traverses AST and translates Python type annotation to Lean type annotation
"""

if ann == None : raise Exception(f"No annotation")

match ann:
# Subscripts
case ast.Subscript(ast.Name(id), slice, ctx):
match id:
case "List":
return "list " + self.pytype_to_leantype(slice)
case "Union":
raise Exception("Lean has no support for untagged unions.")
case "Tuple":
match slice:
case ast.Tuple(elts, _ctx):
tys = [self.pytype_to_leantype(elem) for elem in elts]
return f"({' × '.join(tys)})"
case other:
raise Exception(f"Bad tuple: {slice}")
case "Dict":
raise Exception("Lean has no support for dictionaries. Yikes.")
case "Optional":
return "Option " + self.pytype_to_leantype(slice)
case other:
raise Exception(f"Bad generic {other}")

# Literals
case ast.Name(id="str") | "str":
return "String"
case ast.Name(id="int") | "int":
return "Int"
case ast.Name(id="float") | "float":
raise Exception("Lean does not have inherent support for Reals (you need mathlib).")
case ast.Name(id="bool") | "bool":
return "Bool"
case ast.Name(id="None") | "None":
raise Exception("Lean does not have type None")

# Misc
case None:
raise Exception("Implicitly untyped argument None")
case ast.Name("Any"):
raise Exception("Lean does not have Any")
case ast.Name(x):
raise Exception(f"Unknown name {x}")
case ast.Constant(Ellipsis):
raise Exception("No ellipsis!")
case _other:
raise Exception(f"Unknown annotation: {ann}")

def __init__(self):
self.type = None

def file_ext(self):
return "lean"

def deep_equality(self, left, right):
"""
In Lean we acheive the notion of "deeper equality" via the reflexivity tactic
and the type equivalence-checking "=" rather than programmically via "=="
if we *were* to do it programmically, you'd have something like:
def main : IO Unit :=
if [assertion here] then pure () else throw (IO.userError "assertion error")
and, we'd run
$ lean --run file.lean
(also note, the "pure ()" business is via the IO monad in Lean)
"""
return f"example: {left} = {right} := by rfl"

def test_suite_prefix_lines(self, _):
return []

def test_suite_suffix_lines(self):
return []

def gen_list(self, elements: List[str]):
return f"[{', '.join(elements)}]"

def gen_tuple(self, elements: List[str]):
return f"({' × '.join(elements)})"

# TODO: I think this is fine for the most part... need to verify
def gen_literal(self, c: bool | str | int | float | None):
return repr(c)

def gen_var(self, name: str):
return name

def gen_call(self, func: str, args: List[str]):
# args = [coerce(arg, self.type[0][i]) for i, arg in enumerate(args)]
return self.fn_name + " " + " ".join(map(str,args))

4 changes: 4 additions & 0 deletions evaluation/src/containerized_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import eval_matlab
import eval_hs
import eval_elixir
import eval_coq
import eval_lean
import tempfile


Expand Down Expand Up @@ -55,6 +57,8 @@
"m": (eval_matlab.eval_script, ".m"),
"hs": (eval_hs.eval_script, ".hs"),
"elixir": (eval_elixir.eval_script, ".exs"),
"coq": (eval_coq.eval_script, ".v"),
"lean": (eval_lean.eval_script, ".lean"),
}

def eval_string_script(language, program):
Expand Down
40 changes: 40 additions & 0 deletions evaluation/src/eval_coq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path
from safe_subprocess import run
import subprocess

# return codes for coqc:
# 0: compilation goes through
# 1: some sort of error (nondescript)

def eval_script(path: Path):
cleanup_extensions = ['.vo', '.vok', '.vos']

try:
# sadly there seems to be no way to verify proofs in a coq file without compiling
output = subprocess.run(["coqc -noglob", str(path)], capture_output=True, timeout=5)
outmessage = str(output)

if output.returncode == 0:
status = "OK"
# cleanup: remove files generated by coqc
for ext in cleanup_extensions:
file_to_remove = path.with_suffix(ext)
if file_to_remove.exists():
file_to_remove.unlink()

elif "Unable to unify" in outmessage:
status = "AssertionError"
else:
status = "SyntaxError"
returncode = output.returncode

except subprocess.TimeoutExpired as exc:
status = "Timeout"
output = exc
returncode = -1
return {
"status": status,
"exit_code": returncode,
"stdout": "" if output.stdout is None else output.stdout.decode("utf-8"),
"stderr": "" if output.stderr is None else output.stderr.decode("utf-8"),
}
29 changes: 29 additions & 0 deletions evaluation/src/eval_lean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path
from safe_subprocess import run
import subprocess

def eval_script(path: Path):
# since lean is a theorem prover first and not a programming environment,
# the return code is always 1. idk.
try:
output = subprocess.run(["lean", str(path)], capture_output=True, timeout=5)
outmessage = str(output)

if "error: tactic 'rfl' failed" in outmessage: # :skull:
status = "AssertionError"
elif outmessage == "":
status = "OK"
else:
status = "SyntaxError"
returncode = output.returncode

except subprocess.TimeoutExpired as exc:
status = "Timeout"
output = exc
returncode = -1
return {
"status": status,
"exit_code": returncode,
"stdout": "" if output.stdout is None else output.stdout.decode("utf-8"),
"stderr": "" if output.stderr is None else output.stderr.decode("utf-8"),
}

0 comments on commit 773d609

Please sign in to comment.