diff --git a/symengine/lib/symengine_wrapper.in.pyx b/symengine/lib/symengine_wrapper.in.pyx index 8974b7d9..06b85266 100644 --- a/symengine/lib/symengine_wrapper.in.pyx +++ b/symengine/lib/symengine_wrapper.in.pyx @@ -2412,8 +2412,15 @@ class Pow(Expr): class Function(Expr): def __new__(cls, *args, **kwargs): - if cls == Function and len(args) == 1: - return UndefFunction(args[0]) + if cls == Function: + nargs = len(args) + if nargs == 0: + raise TypeError("Required at least one argument to Function") + elif nargs == 1: + return UndefFunction(args[0]) + elif nargs > 1: + raise TypeError(f"Unexpected extra arguments {args[1:]}.") + return super(Function, cls).__new__(cls) @property @@ -2834,6 +2841,10 @@ class FunctionSymbol(Function): name = deref(X).get_name().decode("utf-8") return str(name) + @property + def name(Basic self): + return self.get_name() + def _sympy_(self): import sympy name = self.get_name() diff --git a/symengine/tests/test_functions.py b/symengine/tests/test_functions.py index 3a19b122..207add98 100644 --- a/symengine/tests/test_functions.py +++ b/symengine/tests/test_functions.py @@ -103,6 +103,18 @@ def test_derivative(): assert i == fxy.diff(y, 1, x) +def test_function(): + x = Symbol("x") + fx = Function("f")(x) + assert fx == function_symbol("f", x) + + raises(TypeError, lambda: Function("f", "x")) + raises(TypeError, lambda: Function("f", x)) + raises(TypeError, lambda: Function()) + + assert fx.name == "f" + + def test_abs(): x = Symbol("x") e = abs(x)