From 6f99124f990f47557a65d95c372788e5733e664e Mon Sep 17 00:00:00 2001 From: SCMusson Date: Wed, 4 Sep 2024 14:45:45 +0100 Subject: [PATCH] Bug fix, Recursion Error #369 (#404) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bug fix, add int and bool, bool gt/lt * Unit tests * Also run sub --------- Co-authored-by: Niels Co-authored-by: Niels Mündler --- opshin/tests/test_misc.py | 18 ++++++++++++++++++ opshin/type_impls.py | 24 ++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/opshin/tests/test_misc.py b/opshin/tests/test_misc.py index b8e43471..be934852 100644 --- a/opshin/tests/test_misc.py +++ b/opshin/tests/test_misc.py @@ -3024,3 +3024,21 @@ def validator(_: None) -> None: """ res = eval_uplc(source_code, Unit()) self.assertEqual(res, uplc.PlutusConstr(0, []), "Invalid return") + + @given(a=st.booleans(), b=st.booleans()) + def test_cast_bool_to_int_lt(self, a: bool, b: bool): + source_code = """ +def validator(a: bool, b:bool)-> int: + return 5+(a int: + return 5-(a>b) +""" + res = eval_uplc_value(source_code, a, b) + self.assertEquals(res, 5 - (a > b)) diff --git a/opshin/type_impls.py b/opshin/type_impls.py index 9bf39bba..7aef7016 100644 --- a/opshin/type_impls.py +++ b/opshin/type_impls.py @@ -1645,6 +1645,9 @@ def _binop_return_type(self, binop: operator, other: "Type") -> "Type": ): if other == IntegerInstanceType: return IntegerType() + elif other == BoolInstanceType: + # cast to integer + return IntegerType() if isinstance(binop, Mult): if other == IntegerInstanceType: return IntegerType() @@ -1673,6 +1676,23 @@ def _binop_bin_fun(self, binop: operator, other: AST): PowImpl(x, OVar("y")), ), ) + if other.typ == BoolInstanceType: + if isinstance(binop, Add): + return lambda x, y: OLet( + [("x", x), ("y", y)], + plt.Ite( + OVar("y"), plt.AddInteger(OVar("x"), plt.Integer(1)), OVar("x") + ), + ) + elif isinstance(binop, Sub): + return lambda x, y: OLet( + [("x", x), ("y", y)], + plt.Ite( + OVar("y"), + plt.SubtractInteger(OVar("x"), plt.Integer(1)), + OVar("x"), + ), + ) if isinstance(binop, Mult): if other.typ == IntegerInstanceType: @@ -2308,6 +2328,10 @@ def cmp(self, op: cmpop, o: "Type") -> plt.AST: return OLambda(["x", "y"], plt.Iff(OVar("x"), OVar("y"))) if isinstance(op, NotEq): return OLambda(["x", "y"], plt.Not(plt.Iff(OVar("x"), OVar("y")))) + if isinstance(op, Lt): + return OLambda(["x", "y"], plt.And(plt.Not(OVar("x")), OVar("y"))) + if isinstance(op, Gt): + return OLambda(["x", "y"], plt.And(OVar("x"), plt.Not(OVar("y")))) return super().cmp(op, o) def stringify(self, recursive: bool = False) -> plt.AST: