Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workaround for symbol class leak #403

Merged
merged 1 commit into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions symengine/lib/pywrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,47 +281,61 @@ inline PyObject* get_pickle_module() {
return module;
}

PyObject* pickle_loads(const std::string &pickle_str) {
PyObject *module = get_pickle_module();
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
Py_XDECREF(pickle_bytes);
if (obj == NULL) {
throw SerializationError("error when loading pickled symbol subclass object");
}
return obj;
}

RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Symbol> &)
{
bool is_pysymbol;
bool store_pickle;
std::string name;
ar(is_pysymbol);
ar(name);
if (is_pysymbol) {
std::string pickle_str;
ar(pickle_str);
PyObject *module = get_pickle_module();
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
if (obj == NULL) {
throw SerializationError("error when loading pickled symbol subclass object");
}
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
Py_XDECREF(pickle_bytes);
ar(store_pickle);
PyObject *obj = pickle_loads(pickle_str);
RCP<const Basic> result = make_rcp<PySymbol>(name, obj, store_pickle);
Py_XDECREF(obj);
return result;
} else {
return symbol(name);
}
}

std::string pickle_dumps(const PyObject * obj) {
PyObject *module = get_pickle_module();
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", obj);
if (pickle_bytes == NULL) {
throw SerializationError("error when pickling symbol subclass object");
}
Py_ssize_t size;
char* buffer;
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
return std::string(buffer, size);
}

void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
{
bool is_pysymbol = is_a_sub<PySymbol>(b);
ar(is_pysymbol);
ar(b.__str__());
if (is_pysymbol) {
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
PyObject *module = get_pickle_module();
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
if (pickle_bytes == NULL) {
throw SerializationError("error when pickling symbol subclass object");
}
Py_ssize_t size;
char* buffer;
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
std::string pickle_str(buffer, size);
PyObject *obj = p->get_py_object();
std::string pickle_str = pickle_dumps(obj);
ar(pickle_str);
Py_XDECREF(pickle_bytes);
ar(p->store_pickle);
Py_XDECREF(obj);
}
}

Expand Down
27 changes: 22 additions & 5 deletions symengine/lib/pywrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

namespace SymEngine {

std::string pickle_dumps(const PyObject *);
PyObject* pickle_loads(const std::string &);

/*
* PySymbol is a subclass of Symbol that keeps a reference to a Python object.
* When subclassing a Symbol from Python, the information stored in subclassed
Expand All @@ -27,16 +30,30 @@ namespace SymEngine {
class PySymbol : public Symbol {
private:
PyObject* obj;
std::string bytes;
public:
PySymbol(const std::string& name, PyObject* obj) : Symbol(name), obj(obj) {
Py_INCREF(obj);
const bool store_pickle;
PySymbol(const std::string& name, PyObject* obj, bool store_pickle) :
Symbol(name), obj(obj), store_pickle(store_pickle) {
if (store_pickle) {
bytes = pickle_dumps(obj);
} else {
Py_INCREF(obj);
}
}
PyObject* get_py_object() const {
return obj;
if (store_pickle) {
return pickle_loads(bytes);
} else {
Py_INCREF(obj);
return obj;
}
}
virtual ~PySymbol() {
// TODO: This is never called because of the cyclic reference.
Py_DECREF(obj);
if (not store_pickle) {
// TODO: This is never called because of the cyclic reference.
Py_DECREF(obj);
}
}
};

Expand Down
8 changes: 4 additions & 4 deletions symengine/lib/symengine.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
bool neq(const Basic &a, const Basic &b) nogil except +

RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil except +
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(rcp_const_basic &b) nogil
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(rcp_const_basic &b) nogil
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(rcp_const_basic &b) nogil
Expand Down Expand Up @@ -367,8 +367,8 @@ cdef extern from "pywrapper.h" namespace "SymEngine":

cdef extern from "pywrapper.h" namespace "SymEngine":
cdef cppclass PySymbol(Symbol):
PySymbol(string name, PyObject* pyobj)
PyObject* get_py_object()
PySymbol(string name, PyObject* pyobj, bool use_pickle) except +
PyObject* get_py_object() except +

string wrapper_dumps(const Basic &x) nogil except +
rcp_const_basic wrapper_loads(const string &s) nogil except +
Expand Down Expand Up @@ -477,7 +477,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj) nogil
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil
rcp_const_basic make_rcp_NaN "SymEngine::make_rcp<const SymEngine::NaN>"() nogil
Expand Down
20 changes: 17 additions & 3 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ cpdef void assign_to_capsule(object capsule, object value):

cdef object c2py(rcp_const_basic o):
cdef Basic r
cdef PyObject *obj
if (symengine.is_a_Add(deref(o))):
r = Expr.__new__(Add)
elif (symengine.is_a_Mul(deref(o))):
Expand Down Expand Up @@ -74,7 +75,10 @@ cdef object c2py(rcp_const_basic o):
r = Dummy.__new__(Dummy)
elif (symengine.is_a_Symbol(deref(o))):
if (symengine.is_a_PySymbol(deref(o))):
return <object>(deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object())
obj = deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object()
result = <object>(obj)
Py_XDECREF(obj);
return result
r = Symbol.__new__(Symbol)
elif (symengine.is_a_Constant(deref(o))):
r = S.Pi
Expand Down Expand Up @@ -1216,16 +1220,26 @@ cdef class Expr(Basic):


cdef class Symbol(Expr):

"""
Symbol is a class to store a symbolic variable with a given name.
Subclassing Symbol leads to a memory leak due to a cycle in reference counting.
To avoid this with a performance penalty, set the kwarg store_pickle=True
in the constructor and support the pickle protocol in the subclass by
implmenting __reduce__.
"""

def __init__(Basic self, name, *args, **kwargs):
cdef cppbool store_pickle;
if type(self) == Symbol:
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
else:
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self)
store_pickle = kwargs.pop("store_pickle", False)
if store_pickle:
# First set the pointer to a regular symbol so that when pickle.dumps
# is called when the PySymbol is created, methods like name works.
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self,
store_pickle)

def _sympy_(self):
import sympy
Expand Down