Skip to content

Commit

Permalink
Bug fix, Recursion Error #369 (#404)
Browse files Browse the repository at this point in the history
* Bug fix, add int and bool, bool gt/lt

* Unit tests

* Also run sub

---------

Co-authored-by: Niels <niels.muendler@inf.ethz.ch>
Co-authored-by: Niels Mündler <n.muendler@posteo.de>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent fd297fb commit 6f99124
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
18 changes: 18 additions & 0 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3024,3 +3024,21 @@ def validator(_: None) -> None:
"""
res = eval_uplc(source_code, Unit())
self.assertEqual(res, uplc.PlutusConstr(0, []), "Invalid return")

@given(a=st.booleans(), b=st.booleans())
def test_cast_bool_to_int_lt(self, a: bool, b: bool):
source_code = """
def validator(a: bool, b:bool)-> int:
return 5+(a<b)
"""
res = eval_uplc_value(source_code, a, b)
self.assertEquals(res, 5 + (a < b))

@given(a=st.booleans(), b=st.booleans())
def test_cast_bool_to_int_gt(self, a: bool, b: bool):
source_code = """
def validator(a: bool, b:bool)-> int:
return 5-(a>b)
"""
res = eval_uplc_value(source_code, a, b)
self.assertEquals(res, 5 - (a > b))
24 changes: 24 additions & 0 deletions opshin/type_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,9 @@ def _binop_return_type(self, binop: operator, other: "Type") -> "Type":
):
if other == IntegerInstanceType:
return IntegerType()
elif other == BoolInstanceType:
# cast to integer
return IntegerType()
if isinstance(binop, Mult):
if other == IntegerInstanceType:
return IntegerType()
Expand Down Expand Up @@ -1673,6 +1676,23 @@ def _binop_bin_fun(self, binop: operator, other: AST):
PowImpl(x, OVar("y")),
),
)
if other.typ == BoolInstanceType:
if isinstance(binop, Add):
return lambda x, y: OLet(
[("x", x), ("y", y)],
plt.Ite(
OVar("y"), plt.AddInteger(OVar("x"), plt.Integer(1)), OVar("x")
),
)
elif isinstance(binop, Sub):
return lambda x, y: OLet(
[("x", x), ("y", y)],
plt.Ite(
OVar("y"),
plt.SubtractInteger(OVar("x"), plt.Integer(1)),
OVar("x"),
),
)

if isinstance(binop, Mult):
if other.typ == IntegerInstanceType:
Expand Down Expand Up @@ -2308,6 +2328,10 @@ def cmp(self, op: cmpop, o: "Type") -> plt.AST:
return OLambda(["x", "y"], plt.Iff(OVar("x"), OVar("y")))
if isinstance(op, NotEq):
return OLambda(["x", "y"], plt.Not(plt.Iff(OVar("x"), OVar("y"))))
if isinstance(op, Lt):
return OLambda(["x", "y"], plt.And(plt.Not(OVar("x")), OVar("y")))
if isinstance(op, Gt):
return OLambda(["x", "y"], plt.And(OVar("x"), plt.Not(OVar("y"))))
return super().cmp(op, o)

def stringify(self, recursive: bool = False) -> plt.AST:
Expand Down

0 comments on commit 6f99124

Please sign in to comment.