From a7136d7890d8fbd439719bccfbc6c7b4ede70e9d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 22 Jan 2024 21:38:36 -0600 Subject: [PATCH 1/2] Raise error on wrong number of arguments to Function --- symengine/lib/symengine_wrapper.in.pyx | 11 +++++++++-- symengine/tests/test_functions.py | 9 +++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/symengine/lib/symengine_wrapper.in.pyx b/symengine/lib/symengine_wrapper.in.pyx index 8974b7d9..4fdc0f9c 100644 --- a/symengine/lib/symengine_wrapper.in.pyx +++ b/symengine/lib/symengine_wrapper.in.pyx @@ -2412,8 +2412,15 @@ class Pow(Expr): class Function(Expr): def __new__(cls, *args, **kwargs): - if cls == Function and len(args) == 1: - return UndefFunction(args[0]) + if cls == Function: + nargs = len(args) + if nargs == 0: + raise TypeError("Required at least one argument to Function") + elif nargs == 1: + return UndefFunction(args[0]) + elif nargs > 1: + raise TypeError(f"Unexpected extra arguments {args[1:]}.") + return super(Function, cls).__new__(cls) @property diff --git a/symengine/tests/test_functions.py b/symengine/tests/test_functions.py index 3a19b122..c5241ad4 100644 --- a/symengine/tests/test_functions.py +++ b/symengine/tests/test_functions.py @@ -103,6 +103,15 @@ def test_derivative(): assert i == fxy.diff(y, 1, x) +def test_function(): + x = Symbol("x") + assert Function("f")(x) == function_symbol("f", x) + + raises(TypeError, lambda: Function("f", "x")) + raises(TypeError, lambda: Function("f", x)) + raises(TypeError, lambda: Function()) + + def test_abs(): x = Symbol("x") e = abs(x) From a165060bd2fb2d3e64770fa68ad6968689f06f38 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 22 Jan 2024 21:43:00 -0600 Subject: [PATCH 2/2] add name property to FunctionSymbol --- symengine/lib/symengine_wrapper.in.pyx | 4 ++++ symengine/tests/test_functions.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/symengine/lib/symengine_wrapper.in.pyx b/symengine/lib/symengine_wrapper.in.pyx index 4fdc0f9c..06b85266 100644 --- a/symengine/lib/symengine_wrapper.in.pyx +++ b/symengine/lib/symengine_wrapper.in.pyx @@ -2841,6 +2841,10 @@ class FunctionSymbol(Function): name = deref(X).get_name().decode("utf-8") return str(name) + @property + def name(Basic self): + return self.get_name() + def _sympy_(self): import sympy name = self.get_name() diff --git a/symengine/tests/test_functions.py b/symengine/tests/test_functions.py index c5241ad4..207add98 100644 --- a/symengine/tests/test_functions.py +++ b/symengine/tests/test_functions.py @@ -105,12 +105,15 @@ def test_derivative(): def test_function(): x = Symbol("x") - assert Function("f")(x) == function_symbol("f", x) + fx = Function("f")(x) + assert fx == function_symbol("f", x) raises(TypeError, lambda: Function("f", "x")) raises(TypeError, lambda: Function("f", x)) raises(TypeError, lambda: Function()) + assert fx.name == "f" + def test_abs(): x = Symbol("x")