Skip to content

Commit

Permalink
Merge pull request #366 from isuruf/matrix-diff
Browse files Browse the repository at this point in the history
Fix Matrix.diff
  • Loading branch information
isuruf authored Sep 9, 2021
2 parents 0529d31 + 95e5321 commit 8808166
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
24 changes: 16 additions & 8 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ cdef class Basic(object):
if (len(f) != 1):
raise RuntimeError("Variable w.r.t should be given")
return self._diff(f.pop())
return diff(self, *args)
return _diff(self, *args)

def subs_dict(Basic self not None, *args):
warnings.warn("subs_dict() is deprecated. Use subs() instead", DeprecationWarning)
Expand Down Expand Up @@ -3687,7 +3687,7 @@ cdef class DenseMatrixBase(MatrixBase):
return R

def diff(self, *args):
return diff(self, *args)
return _diff(self, *args)

#TODO: implement this in C++
def subs(self, *args):
Expand Down Expand Up @@ -4063,15 +4063,23 @@ def module_cleanup():
import atexit
atexit.register(module_cleanup)


def diff(expr, *args):
cdef Basic ex = sympify(expr)
if isinstance(expr, MatrixBase):
# Don't sympify matrices so that mutable matrices
# return mutable matrices
return _diff(expr, *args)
return _diff(sympify(expr), *args)


def _diff(expr, *args):
cdef Basic prev
cdef Basic b
cdef size_t i
cdef size_t length = len(args)

if not length:
return ex
return expr

cdef size_t l = 0
cdef Basic cur_arg, next_arg
Expand All @@ -4083,20 +4091,20 @@ def diff(expr, *args):

if l + 1 == length:
# No next argument, differentiate with no integer argument
return ex._diff(cur_arg)
return expr._diff(cur_arg)

next_arg = sympify(args[l + 1])
# Check if the next arg was derivative order
if isinstance(next_arg, Integer):
i = int(next_arg)
for _ in range(i):
ex = ex._diff(cur_arg)
expr = expr._diff(cur_arg)
l += 2
if l == length:
return ex
return expr
cur_arg = sympify(args[l])
else:
ex = ex._diff(cur_arg)
expr = expr._diff(cur_arg)
l += 1
cur_arg = next_arg

Expand Down
8 changes: 8 additions & 0 deletions symengine/tests/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,14 @@ def test_cross():
DenseMatrix(1, 2, [1, 1]).cross(DenseMatrix(1, 2, [1, 1])))


def test_diff():
x = symbols("x")
M = DenseMatrix(1, 2, [x**2, x])
result = M.diff(x)
assert isinstance(result, DenseMatrix)
assert result == DenseMatrix(1, 2, [2*x, 1])


def test_immutablematrix():
A = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
assert A.shape == (3, 3)
Expand Down

0 comments on commit 8808166

Please sign in to comment.