Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fractions functionality issue #395 #403

Merged
merged 22 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def visit_Name(self, node: TypedName) -> plt.AST:
if isinstance(node.typ, ClassType):
# if this is not an instance but a class, call the constructor
return node.typ.constr()
if hasattr(node, "is_wrapped") and node.is_wrapped:
return transform_ext_params_map(node.typ)(plt.Force(plt.Var(node.id)))
return plt.Force(plt.Var(node.id))

def visit_Expr(self, node: TypedExpr) -> CallAST:
Expand Down Expand Up @@ -433,7 +435,7 @@ def visit_Call(self, node: TypedCall) -> plt.AST:
assert isinstance(t, InstanceType)
# pass in all arguments evaluated with the statemonad
a_int = self.visit(a)
if isinstance(t.typ, AnyType):
if isinstance(t.typ, AnyType) or isinstance(t.typ, UnionType):
# if the function expects input of generic type data, wrap data before passing it inside
a_int = transform_output_map(a.typ)(a_int)
args.append(a_int)
Expand Down Expand Up @@ -914,6 +916,14 @@ def visit_Dict(self, node: TypedDict) -> plt.AST:
return l

def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
if isinstance(node.typ.typ, UnionType):
body = self.visit(node.body)
orelse = self.visit(node.orelse)
if not isinstance(node.body.typ, UnionType):
body = transform_output_map(node.body.typ)(body)
if not isinstance(node.orelse.typ, UnionType):
orelse = transform_output_map(node.orelse.typ)(orelse)
return plt.Ite(self.visit(node.test), body, orelse)
return plt.Ite(
self.visit(node.test),
self.visit(node.body),
Expand Down
6 changes: 6 additions & 0 deletions opshin/fun_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def type_from_args(self, args: typing.List[Type]) -> FunctionType:
return FunctionType(args, BoolInstanceType)

def impl_from_args(self, args: typing.List[Type]) -> plt.AST:
if not (isinstance(args[0], UnionType) or isinstance(args[0].typ, UnionType)):
if args[0].typ == args[1]:
return OLambda(["x"], plt.Bool(True))
else:
return OLambda(["x"], plt.Bool(False))

if isinstance(args[1], IntegerType):
return OLambda(
["x"],
Expand Down
1 change: 1 addition & 0 deletions opshin/rewrite/rewrite_forbidden_overwrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"List",
"Dict",
"Union",
"Self",
# decorator and class name
"dataclass",
"PlutusData",
Expand Down
10 changes: 10 additions & 0 deletions opshin/rewrite/rewrite_import_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def visit_ClassDef(self, node: ClassDef) -> ClassDef:
and arg.annotation.id == "Self"
):
node.body[i].args.args[j].annotation.idSelf = node.name
if (
isinstance(arg.annotation, Subscript)
and arg.annotation.value.id == "Union"
):
for k, s in enumerate(arg.annotation.slice.elts):
if isinstance(s, Name) and s.id == "Self":
node.body[i].args.args[j].annotation.slice.elts[
k
].idSelf = node.name

if (
isinstance(attribute.returns, Name)
and attribute.returns.id == "Self"
Expand Down
17 changes: 7 additions & 10 deletions opshin/rewrite/rewrite_scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class RewriteScoping(CompilingNodeTransformer):
step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
latest_scope_id: int
scopes: typing.List[typing.Tuple[OrderedSet, int]]
current_Self: typing.Tuple[str, str]

def variable_scope_id(self, name: str) -> int:
"""find the id of the scope in which this variable is defined (closest to its usage)"""
Expand Down Expand Up @@ -86,13 +87,17 @@ def visit_Module(self, node: Module) -> Module:
def visit_Name(self, node: Name) -> Name:
nc = copy(node)
# setting is handled in either enclosing module or function
if node.id == "Self":
assert node.idSelf == self.current_Self[1]
nc.idSelf_new = self.current_Self[0]
nc.id = self.map_name(node.id)
return nc

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
cp_node = RecordScoper.scope(node, self)
for i, attribute in enumerate(cp_node.body):
if isinstance(attribute, FunctionDef):
self.current_Self = (cp_node.name, cp_node.orig_name)
cp_node.body[i] = self.visit_FunctionDef(attribute, method=True)
return cp_node

Expand All @@ -108,17 +113,9 @@ def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> Function
a_cp = copy(a)
self.set_variable_scope(a.arg)
a_cp.arg = self.map_name(a.arg)
a_cp.annotation = (
self.visit(a.annotation)
if not hasattr(a.annotation, "idSelf")
else a.annotation
)
a_cp.annotation = self.visit(a.annotation)
node_cp.args.args.append(a_cp)
node_cp.returns = (
self.visit(node.returns)
if not hasattr(node.returns, "idSelf")
else node.returns
)
node_cp.returns = self.visit(node.returns)
# vars defined in this scope
shallow_node_def_collector = ShallowNameDefCollector()
for s in node.body:
Expand Down
231 changes: 159 additions & 72 deletions opshin/std/fractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,165 @@ class Fraction(PlutusData):
numerator: int
denominator: int


def add_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a + b"""
return Fraction(
(a.numerator * b.denominator) + (b.numerator * a.denominator),
a.denominator * b.denominator,
)


def neg_fraction(a: Fraction) -> Fraction:
"""returns -a"""
return Fraction(-a.numerator, a.denominator)


def sub_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a - b"""
return add_fraction(a, neg_fraction(b))


def mul_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a * b"""
return Fraction(a.numerator * b.numerator, a.denominator * b.denominator)


def div_fraction(a: Fraction, b: Fraction) -> Fraction:
"""returns a / b"""
return Fraction(a.numerator * b.denominator, a.denominator * b.numerator)
def norm(self) -> "Fraction":
"""Restores the invariant that num/denom are in the smallest possible denomination and denominator > 0"""
return _norm_gcd_fraction(_norm_signs_fraction(self))

def ceil(self) -> int:
return (
self.numerator + self.denominator - sign(self.denominator)
) // self.denominator

def __add__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self + other"""
if isinstance(other, Fraction):
return Fraction(
(self.numerator * other.denominator)
+ (other.numerator * self.denominator),
self.denominator * other.denominator,
)
else:
return Fraction(
(self.numerator) + (other * self.denominator),
self.denominator,
)

def __neg__(
self,
) -> "Fraction":
"""returns -self"""
return Fraction(-self.numerator, self.denominator)

def __sub__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self - other"""
if isinstance(other, Fraction):
return Fraction(
(self.numerator * other.denominator)
- (other.numerator * self.denominator),
self.denominator * other.denominator,
)
else:
return Fraction(
self.numerator - (other * self.denominator), self.denominator
)

def __mul__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self * other"""
if isinstance(other, Fraction):
return Fraction(
self.numerator * other.numerator, self.denominator * other.denominator
)
else:
return Fraction(self.numerator * other, self.denominator)

def __truediv__(self, other: Union["Fraction", int]) -> "Fraction":
"""returns self / other"""
if isinstance(other, Fraction):
return Fraction(
self.numerator * other.denominator, self.denominator * other.numerator
)
else:
return Fraction(self.numerator, self.denominator * other)

def __ge__(self, other: Union["Fraction", int]) -> bool:
"""returns self >= other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
>= self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
<= self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator >= self.denominator * other
else:
res = self.numerator <= self.denominator * other
return res

def __le__(self, other: Union["Fraction", int]) -> bool:
"""returns self <= other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
<= self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
>= self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator <= self.denominator * other
else:
res = self.numerator >= self.denominator * other
return res

def __eq__(self, other: Union["Fraction", int]) -> bool:
"""returns self == other"""
if isinstance(other, Fraction):
return (
self.numerator * other.denominator == self.denominator * other.numerator
)
else:
return self.numerator == self.denominator * other

def __lt__(self, other: Union["Fraction", int]) -> bool:
"""returns self < other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
< self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
> self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator < self.denominator * other
else:
res = self.numerator > self.denominator * other
return res

def __gt__(self, other: Union["Fraction", int]) -> bool:
"""returns self > other"""
if isinstance(other, Fraction):
if self.denominator * other.denominator >= 0:
res = (
self.numerator * other.denominator
> self.denominator * other.numerator
)
else:
res = (
self.numerator * other.denominator
< self.denominator * other.numerator
)
return res
else:
if self.denominator >= 0:
res = self.numerator > self.denominator * other
else:
res = self.numerator < self.denominator * other
return res

def __floordiv__(self, other: Union["Fraction", int]) -> int:
if isinstance(other, Fraction):
x = self / other
return x.numerator // x.denominator
else:
return self.numerator // (other * self.denominator)


def _norm_signs_fraction(a: Fraction) -> Fraction:
Expand All @@ -62,50 +194,5 @@ def norm_fraction(a: Fraction) -> Fraction:
return _norm_gcd_fraction(_norm_signs_fraction(a))


def ge_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a >= b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator >= a.denominator * b.numerator
else:
res = a.numerator * b.denominator <= a.denominator * b.numerator
return res


def le_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a <= b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator <= a.denominator * b.numerator
else:
res = a.numerator * b.denominator >= a.denominator * b.numerator
return res


def eq_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a == b"""
return a.numerator * b.denominator == a.denominator * b.numerator


def lt_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a < b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator < a.denominator * b.numerator
else:
res = a.numerator * b.denominator > a.denominator * b.numerator
return res


def gt_fraction(a: Fraction, b: Fraction) -> bool:
"""returns a > b"""
if a.denominator * b.denominator >= 0:
res = a.numerator * b.denominator > a.denominator * b.numerator
else:
res = a.numerator * b.denominator < a.denominator * b.numerator
return res


def floor_fraction(a: Fraction) -> int:
return a.numerator // a.denominator


def ceil_fraction(a: Fraction) -> int:
return (a.numerator + a.denominator - sign(a.denominator)) // a.denominator
Loading
Loading