Skip to content

Commit

Permalink
Merge pull request #307 from isuruf/linsolve
Browse files Browse the repository at this point in the history
Linsolve
  • Loading branch information
isuruf authored Jan 20, 2020
2 parents 740b26f + c2597bd commit a25ccb4
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 38 deletions.
1 change: 1 addition & 0 deletions symengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
LessThan, StrictGreaterThan, StrictLessThan, Eq, Ne, Ge, Le,
Gt, Lt, And, Or, Not, Nand, Nor, Xor, Xnor, perfect_power, integer_nthroot,
isprime, sqrt_mod, Expr, cse, count_ops, ccode, Piecewise, Contains, Interval, FiniteSet,
EmptySet, linsolve,
FunctionSymbol as AppliedUndef,
golden_ratio as GoldenRatio,
catalan as Catalan,
Expand Down
2 changes: 2 additions & 0 deletions symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
void insert(iterator, iterator) except +

ctypedef vector[rcp_const_basic] vec_basic "SymEngine::vec_basic"
ctypedef vector[RCP[Symbol]] vec_sym "SymEngine::vec_sym"
ctypedef vector[RCP[Integer]] vec_integer "SymEngine::vec_integer"
ctypedef map[RCP[Integer], unsigned] map_integer_uint "SymEngine::map_integer_uint"
cdef struct RCPIntegerKeyLess
Expand Down Expand Up @@ -1047,6 +1048,7 @@ cdef extern from "<symengine/sets.h>" namespace "SymEngine":
cdef extern from "<symengine/solve.h>" namespace "SymEngine":
cdef RCP[const Set] solve(rcp_const_basic &f, RCP[const Symbol] &sym) nogil except +
cdef RCP[const Set] solve(rcp_const_basic &f, RCP[const Symbol] &sym, RCP[const Set] &domain) nogil except +
cdef vec_basic linsolve(const vec_basic &eqs, const vec_sym &syms) nogil except +

cdef extern from "<symengine/printers.h>" namespace "SymEngine":
string ccode(const Basic &x) nogil except +
Expand Down
63 changes: 39 additions & 24 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2068,11 +2068,7 @@ class Add(AssocOp):
identity = 0

def __new__(cls, *args, **kwargs):
cdef symengine.vec_basic v_
cdef Basic e
for e_ in args:
e = _sympify(e_)
v_.push_back(e.thisptr)
cdef symengine.vec_basic v_ = iter_to_vec_basic(args)
return c2py(symengine.add(v_))

@classmethod
Expand Down Expand Up @@ -2123,11 +2119,7 @@ class Mul(AssocOp):
identity = 1

def __new__(cls, *args, **kwargs):
cdef symengine.vec_basic v_
cdef Basic e
for e_ in args:
e = _sympify(e_)
v_.push_back(e.thisptr)
cdef symengine.vec_basic v_ = iter_to_vec_basic(args)
return c2py(symengine.mul(v_))

@classmethod
Expand Down Expand Up @@ -2296,11 +2288,7 @@ class KroneckerDelta(Function):

class LeviCivita(Function):
def __new__(cls, *args):
cdef symengine.vec_basic v
cdef Basic e_
for e in args:
e_ = sympify(e)
v.push_back(e_.thisptr)
cdef symengine.vec_basic v = iter_to_vec_basic(args)
return c2py(symengine.levi_civita(v))

def _sympy_(self):
Expand Down Expand Up @@ -2710,11 +2698,7 @@ class PyFunction(FunctionSymbol):
def __init__(Basic self, pyfunction = None, args = None, pyfunction_class=None, module=None):
if pyfunction is None:
return
cdef symengine.vec_basic v
cdef Basic arg_
for arg in args:
arg_ = sympify(arg)
v.push_back(arg_.thisptr)
cdef symengine.vec_basic v = iter_to_vec_basic(args)
cdef PyFunctionClass _pyfunction_class = get_function_class(pyfunction_class, module)
cdef PyObject* _pyfunction = <PyObject*>pyfunction
Py_XINCREF(_pyfunction)
Expand Down Expand Up @@ -3785,42 +3769,53 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):

ImmutableMatrix = ImmutableDenseMatrix


cdef matrix_to_vec(DenseMatrixBase d, symengine.vec_basic& v):
cdef Basic e_
for i in range(d.nrows()):
for j in range(d.ncols()):
e_ = d._get(i, j)
v.push_back(e_.thisptr)


def eye(n):
cdef DenseMatrixBase d = DenseMatrix(n, n)
symengine.eye(deref(symengine.static_cast_DenseMatrix(d.thisptr)), 0)
return d

def diag(*values):
cdef DenseMatrixBase d = DenseMatrix(len(values), len(values))
cdef symengine.vec_basic V

cdef symengine.vec_basic iter_to_vec_basic(iter):
cdef Basic B
for b in values:
cdef symengine.vec_basic V
for b in iter:
B = sympify(b)
V.push_back(B.thisptr)
return V


def diag(*values):
cdef DenseMatrixBase d = DenseMatrix(len(values), len(values))
cdef symengine.vec_basic V = iter_to_vec_basic(values)
symengine.diag(deref(symengine.static_cast_DenseMatrix(d.thisptr)), V, 0)
return d


def ones(r, c = None):
if c is None:
c = r
cdef DenseMatrixBase d = DenseMatrix(r, c)
symengine.ones(deref(symengine.static_cast_DenseMatrix(d.thisptr)))
return d


def zeros(r, c = None):
if c is None:
c = r
cdef DenseMatrixBase d = DenseMatrix(r, c)
symengine.zeros(deref(symengine.static_cast_DenseMatrix(d.thisptr)))
return d


cdef class Sieve:
@staticmethod
def generate_primes(n):
Expand All @@ -3831,6 +3826,7 @@ cdef class Sieve:
s.append(primes[i])
return s


cdef class Sieve_iterator:
cdef symengine.sieve_iterator *thisptr
cdef unsigned limit
Expand Down Expand Up @@ -5000,6 +4996,25 @@ def solve(f, sym, domain=None):
return c2py(<rcp_const_basic>(symengine.solve(f_.thisptr, x, d)))


def linsolve(eqs, syms):
"""
Solve a set of linear equations given as an iterable `eqs`
which are linear w.r.t the symbols given as an iterable `syms`
"""
cdef symengine.vec_basic eqs_ = iter_to_vec_basic(eqs)
cdef symengine.vec_sym syms_
cdef RCP[const symengine.Symbol] sym_
cdef Symbol B
for sym in syms:
B = sympify(sym)
sym_ = symengine.rcp_static_cast_Symbol(B.thisptr)
syms_.push_back(sym_)
if syms_.size() != eqs_.size():
raise RuntimeError("Number of equations and symbols do not match")
cdef symengine.vec_basic ret = symengine.linsolve(eqs_, syms_)
return vec_basic_to_tuple(ret)


def cse(exprs):
cdef symengine.vec_basic vec
cdef symengine.vec_pair replacements
Expand Down
34 changes: 21 additions & 13 deletions symengine/tests/test_solve.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from symengine.utilities import raises
from symengine.lib.symengine_wrapper import (Interval, EmptySet, FiniteSet,
I, oo, solve, Eq, Symbol)
from symengine import (Interval, EmptySet, FiniteSet, I, oo, Eq, Symbol,
linsolve)
from symengine.lib.symengine_wrapper import solve

def test_solve():
x = Symbol("x")
reals = Interval(-oo, oo)
x = Symbol("x")
reals = Interval(-oo, oo)

assert solve(1, x, reals) == EmptySet()
assert solve(0, x, reals) == reals
assert solve(x + 3, x, reals) == FiniteSet(-3)
assert solve(x + 3, x, Interval(0, oo)) == EmptySet()
assert solve(x, x, reals) == FiniteSet(0)
assert solve(x**2 + 1, x) == FiniteSet(-I, I)
assert solve(x**2 - 2*x + 1, x) == FiniteSet(1)
assert solve(Eq(x**3 + 3*x**2 + 3*x, -1), x, reals) == FiniteSet(-1)
assert solve(x**3 - x, x) == FiniteSet(0, 1, -1)
assert solve(1, x, reals) == EmptySet()
assert solve(0, x, reals) == reals
assert solve(x + 3, x, reals) == FiniteSet(-3)
assert solve(x + 3, x, Interval(0, oo)) == EmptySet()
assert solve(x, x, reals) == FiniteSet(0)
assert solve(x**2 + 1, x) == FiniteSet(-I, I)
assert solve(x**2 - 2*x + 1, x) == FiniteSet(1)
assert solve(Eq(x**3 + 3*x**2 + 3*x, -1), x, reals) == FiniteSet(-1)
assert solve(x**3 - x, x) == FiniteSet(0, 1, -1)

def test_linsolve():
x = Symbol("x")
y = Symbol("y")
assert linsolve([x - 2], [x]) == (2,)
assert linsolve([x - 2, y - 3], [x, y]) == (2, 3)
assert linsolve([x + y - 3, x + 2*y - 4], [x, y]) == (2, 1)
2 changes: 1 addition & 1 deletion symengine_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v0.5.0
fc05f8d55915c2de956e7797d764eb1116b61711

0 comments on commit a25ccb4

Please sign in to comment.