Skip to content

Commit

Permalink
Added Tests and Wrapped solve()
Browse files Browse the repository at this point in the history
  • Loading branch information
ShikharJ committed Aug 9, 2017
1 parent d243265 commit ada29df
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 37 deletions.
6 changes: 5 additions & 1 deletion symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
ctypedef map[RCP[Integer], unsigned] map_integer_uint "SymEngine::map_integer_uint"
cdef struct RCPIntegerKeyLess
cdef struct RCPBasicKeyLess
ctypedef set[RCP[Basic], RCPBasicKeyLess] set_basic "SymEngine::set_basic"
ctypedef set[RCP[const_Basic], RCPBasicKeyLess] set_basic "SymEngine::set_basic"
ctypedef multiset[RCP[const_Basic], RCPBasicKeyLess] multiset_basic "SymEngine::multiset_basic"
cdef cppclass Basic:
string __str__() nogil except +
Expand Down Expand Up @@ -1022,3 +1022,7 @@ cdef extern from "<symengine/sets.h>" namespace "SymEngine":
cdef RCP[const Set] set_complement(RCP[const Set] &universe, RCP[const Set] &container) nogil except +
cdef RCP[const Set] conditionset(RCP[const Basic] &sym, RCP[const Boolean] &condition) nogil except +
cdef RCP[const Set] imageset(RCP[const Basic] &sym, RCP[const Basic] &expr, RCP[const Set] &base) nogil except +

cdef extern from "<symengine/solve.h>" namespace "SymEngine":
cdef RCP[const Set] solve(RCP[const Basic] &f, RCP[const Symbol] &sym) nogil except +
cdef RCP[const Set] solve(RCP[const Basic] &f, RCP[const Symbol] &sym, RCP[const Set] &domain) nogil except +
83 changes: 50 additions & 33 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -183,25 +183,25 @@ cdef c2py(RCP[const symengine.Basic] o):
elif (symengine.is_a_PyNumber(deref(o))):
r = PyNumber.__new__(PyNumber)
elif (symengine.is_a_Piecewise(deref(o))):
r = Piecewise.__new__(Piecewise)
r = Basic.__new__(Piecewise)
elif (symengine.is_a_Contains(deref(o))):
r = Contains.__new__(Contains)
r = Boolean.__new__(Contains)
elif (symengine.is_a_Interval(deref(o))):
r = Interval.__new__(Interval)
r = Set.__new__(Interval)
elif (symengine.is_a_EmptySet(deref(o))):
r = EmptySet.__new__(EmptySet)
r = Set.__new__(EmptySet)
elif (symengine.is_a_UniversalSet(deref(o))):
r = UniversalSet.__new__(UniversalSet)
r = Set.__new__(UniversalSet)
elif (symengine.is_a_FiniteSet(deref(o))):
r = FiniteSet.__new__(FiniteSet)
r = Set.__new__(FiniteSet)
elif (symengine.is_a_Union(deref(o))):
r = Union.__new__(Union)
r = Set.__new__(Union)
elif (symengine.is_a_Complement(deref(o))):
r = Complement.__new__(Complement)
r = Set.__new__(Complement)
elif (symengine.is_a_ConditionSet(deref(o))):
r = ConditionSet.__new__(ConditionSet)
r = Set.__new__(ConditionSet)
elif (symengine.is_a_ImageSet(deref(o))):
r = ImageSet.__new__(ImageSet)
r = Set.__new__(ImageSet)
elif (symengine.is_a_And(deref(o))):
r = Boolean.__new__(And)
elif (symengine.is_a_Not(deref(o))):
Expand Down Expand Up @@ -380,13 +380,11 @@ def sympy2symengine(a, raise_error=False):
elif isinstance(a, sympy.Not):
return logical_not(a.args[0])
elif isinstance(a, sympy.Nor):
return logical_nor(*a.args)
return Nor(*a.args)
elif isinstance(a, sympy.Nand):
return logical_nand(*a.args)
return Nand(*a.args)
elif isinstance(a, sympy.Xor):
return logical_xor(*a.args)
elif isinstance(a, sympy.Xnor):
return logical_xnor(*a.args)
elif isinstance(a, sympy.gamma):
return gamma(a.args[0])
elif isinstance(a, sympy.Derivative):
Expand All @@ -400,10 +398,8 @@ def sympy2symengine(a, raise_error=False):
return piecewise(*(a.args))
elif isinstance(a, sympy.Interval):
return interval(*(a.args))
elif isinstance(a, sympy.S.EmptySet):
elif isinstance(a, sympy.EmptySet):
return emptyset()
elif isinstance(a, sympy.S.UniversalSet):
return universalset()
elif isinstance(a, sympy.FiniteSet):
return finiteset(*(a.args))
elif isinstance(a, sympy.Contains):
Expand All @@ -414,8 +410,6 @@ def sympy2symengine(a, raise_error=False):
return set_intersection(*(a.args))
elif isinstance(a, sympy.Complement):
return set_complement(*(a.args))
elif isinstance(a, sympy.ConditionSet):
return conditionset(*(a.args))
elif isinstance(a, sympy.ImageSet):
return imageset(*(a.args))
elif isinstance(a, sympy.Function):
Expand All @@ -439,6 +433,8 @@ def sympy2symengine(a, raise_error=False):
return acsch(a.args[0])
elif isinstance(a, sympy.asech):
return asech(a.args[0])
elif isinstance(a, sympy.ConditionSet):
return conditionset(*(a.args))

if raise_error:
raise SympifyError("sympy2symengine: Cannot convert '%r' to a symengine type." % a)
Expand Down Expand Up @@ -1171,7 +1167,8 @@ class Not(Boolean):

def _sympy_(self):
import sympy
return sympy.Not(c2py(<RCP[const symengine.Basic]>(self.args[0])._sympy_()))
s = self.args_as_sympy()[0]
return sympy.Not(s)


class Xor(Boolean):
Expand Down Expand Up @@ -2577,6 +2574,10 @@ class EmptySet(Set):
import sympy
return sympy.EmptySet()

@property
def func(self):
return self.__class__


class UniversalSet(Set):

Expand All @@ -2587,6 +2588,10 @@ class UniversalSet(Set):
import sympy
return sympy.UniversalSet()

@property
def func(self):
return self.__class__


class FiniteSet(Set):

Expand Down Expand Up @@ -3556,29 +3561,27 @@ def logical_or(*args):
s.insert(symengine.rcp_static_cast_Boolean(e_.thisptr))
return c2py(<RCP[const symengine.Basic]>(symengine.logical_or(s)))

def logical_nor(*args):
def Nor(*args):
cdef symengine.set_boolean s
cdef Boolean e_
for e in args:
e_ = sympify(e)
s.insert(symengine.rcp_static_cast_Boolean(e_.thisptr))
return c2py(<RCP[const symengine.Basic]>(symengine.logical_nor(s)))

Nor = logical_nor

def logical_nand(*args):
def Nand(*args):
cdef symengine.set_boolean s
cdef Boolean e_
for e in args:
e_ = sympify(e)
s.insert(symengine.rcp_static_cast_Boolean(e_.thisptr))
return c2py(<RCP[const symengine.Basic]>(symengine.logical_nand(s)))

Nand = logical_nand

def logical_not(x):
cdef Boolean X = sympify(x)
return c2py(<RCP[const symengine.Basic]>(symengine.logical_not(symengine.rcp_static_cast_Boolean(X.thisptr))))
cdef Basic x_ = sympify(x)
require(x_, Boolean)
cdef RCP[const symengine.Boolean] _x = symengine.rcp_static_cast_Boolean(x_.thisptr)
return c2py(<RCP[const symengine.Basic]>(symengine.logical_not(_x)))

def logical_xor(*args):
cdef symengine.vec_boolean v
Expand All @@ -3588,16 +3591,14 @@ def logical_xor(*args):
v.push_back(symengine.rcp_static_cast_Boolean(e_.thisptr))
return c2py(<RCP[const symengine.Basic]>(symengine.logical_xor(v)))

def logical_xnor(*args):
def Xnor(*args):
cdef symengine.vec_boolean v
cdef Boolean e_
for e in args:
e_ = sympify(e)
v.push_back(symengine.rcp_static_cast_Boolean(e_.thisptr))
return c2py(<RCP[const symengine.Basic]>(symengine.logical_xnor(v)))

Xnor = logical_xnor

def eval_double(x):
cdef Basic X = sympify(x)
return c2py(<RCP[const symengine.Basic]>(symengine.real_double(symengine.eval_double(deref(X.thisptr)))))
Expand Down Expand Up @@ -3998,7 +3999,6 @@ def powermod(a, b, m):
cdef RCP[const symengine.Integer] m1 = symengine.rcp_static_cast_Integer(_m.thisptr)
cdef RCP[const symengine.Number] b1 = symengine.rcp_static_cast_Number(_b.thisptr)
cdef RCP[const symengine.Integer] root

cdef cppbool ret_val = symengine.powermod(symengine.outArg_Integer(root), a1, b1, m1)
if ret_val == 0:
return None
Expand Down Expand Up @@ -4405,6 +4405,10 @@ def piecewise(*v):


def interval(start, end, left_open=False, right_open=False):
if isinstance(start, NegativeInfinity):
left_open = True
if isinstance(end, Infinity):
right_open = True
cdef Number start_ = sympify(start)
cdef Number end_ = sympify(end)
cdef cppbool left_open_ = left_open
Expand All @@ -4427,7 +4431,7 @@ def finiteset(*args):
cdef Basic e_
for e in args:
e_ = sympify(e)
s.insert(e_.thisptr)
s.insert(<RCP[symengine.const_Basic]>(e_.thisptr))
return c2py(<RCP[const symengine.Basic]>(symengine.finiteset(s)))


Expand Down Expand Up @@ -4486,5 +4490,18 @@ def imageset(sym, expr, base):
cdef RCP[const symengine.Set] b = symengine.rcp_static_cast_Set(base_.thisptr)
return c2py(<RCP[const symengine.Basic]>(symengine.imageset(sym_.thisptr, expr_.thisptr, b)))


def solve(f, sym, domain=None):
cdef Basic f_ = sympify(f)
cdef Basic sym_ = sympify(sym)
require(sym_, Symbol)
cdef RCP[const symengine.Symbol] x = symengine.rcp_static_cast_Symbol(sym_.thisptr)
if domain is None:
return c2py(<RCP[const symengine.Basic]>(symengine.solve(f_.thisptr, x)))
cdef Set domain_ = sympify(domain)
cdef RCP[const symengine.Set] d = symengine.rcp_static_cast_Set(domain_.thisptr)
return c2py(<RCP[const symengine.Basic]>(symengine.solve(f_.thisptr, x, d)))


# Turn on nice stacktraces:
symengine.print_stack_on_segfault()
2 changes: 2 additions & 0 deletions symengine/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ install(FILES __init__.py
test_printing.py
test_sage.py
test_series_expansion.py
test_sets.py
test_solve.py
test_subs.py
test_symbol.py
test_sympify.py
Expand Down
90 changes: 88 additions & 2 deletions symengine/tests/test_logic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from symengine.utilities import raises
from symengine.lib.symengine_wrapper import (true, false, Eq, Ne,
Ge, Gt, Le, Lt, Symbol, I)
from symengine.lib.symengine_wrapper import (true, false, Eq, Ne, Ge, Gt, Le, Lt, Symbol,
I, And, Or, Not, Nand, Nor, Xor, Xnor, Piecewise,
Contains, Interval, FiniteSet, oo, log)

x = Symbol("x")
y = Symbol("y")
z = Symbol("z")

def test_relationals():
assert Eq(0) == true
Expand All @@ -26,8 +28,92 @@ def test_relationals():
assert Eq(I, 2) == false
assert Ne(I, 2) == true


def test_rich_cmp():
assert (x < y) == Lt(x, y)
assert (x <= y) == Le(x, y)
assert (x > y) == Gt(x, y)
assert (x >= y) == Ge(x, y)


def test_And():
assert And() == true
assert And(True) == true
assert And(False) == false
assert And(True, True ) == true
assert And(True, False) == false
assert And(False, False) == false
assert And(True, True, True) == true


def test_Or():
assert Or() == false
assert Or(True) == true
assert Or(False) == false
assert Or(True, True ) == true
assert Or(True, False) == true
assert Or(False, False) == false
assert Or(True, False, False) == true


def test_Nor():
assert Nor() == true
assert Nor(True) == false
assert Nor(False) == true
assert Nor(True, True ) == false
assert Nor(True, False) == false
assert Nor(False, False) == true
assert Nor(True, True, True) == false


def test_Nand():
assert Nand() == false
assert Nand(True) == false
assert Nand(False) == true
assert Nand(True, True) == false
assert Nand(True, False) == true
assert Nand(False, False) == true
assert Nand(True, True, True) == false


def test_Not():
assert Not(True) == false
assert Not(False) == true


def test_Xor():
assert Xor() == false
assert Xor(True) == true
assert Xor(False) == false
assert Xor(True, True ) == false
assert Xor(True, False) == true
assert Xor(False, False) == false
assert Xor(True, False, False) == true


def test_Xnor():
assert Xnor() == true
assert Xnor(True) == false
assert Xnor(False) == true
assert Xnor(True, True ) == true
assert Xnor(True, False) == false
assert Xnor(False, False) == true
assert Xnor(True, False, False) == false


def test_Piecewise():
assert Piecewise((x, x < 1), (0, True)) == Piecewise((x, x < 1), (0, True))
int1 = Interval(1, 2, True, False)
int2 = Interval(2, 5, True, False)
int3 = Interval(5, 10, True, False)
p = Piecewise((x, Contains(x, int1)), (y, Contains(x, int2)), (x + y, Contains(x, int3)))
q = Piecewise((1, Contains(x, int1)), (0, Contains(x, int2)), (1, Contains(x, int3)))
assert p.diff(x) == q


def test_Contains():
assert Contains(x, FiniteSet(0)) != false
assert Contains(x, Interval(1, 1)) != false
assert Contains(oo, Interval(-oo, oo)) == false
assert Contains(-oo, Interval(-oo, oo)) == false

Loading

0 comments on commit ada29df

Please sign in to comment.