Add RMM PyTorch allocator #1168

Merged: 18 commits, Jan 5, 2023
47 changes: 47 additions & 0 deletions README.md
@@ -732,3 +732,50 @@ This can be done in two ways:
**Note:** This only configures Numba to use the current RMM resource for allocations.
It does not initialize or change the current resource (e.g., it does not enable a memory pool).
See [here](#memoryresource-objects) for more information on changing the current memory resource.

### Using RMM with PyTorch

[PyTorch](https://pytorch.org/docs/stable/notes/cuda.html) can use RMM
for memory allocation. For example, to configure PyTorch to use an
RMM-managed pool:

```python
import rmm
import torch

rmm.reinitialize(pool_allocator=True)
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
```

PyTorch and RMM will now share the same memory pool.

You can, of course, use a custom memory resource with PyTorch as well:

```python
import rmm
import torch

# note that you can configure PyTorch to use RMM either before or
# after changing RMM's memory resource. PyTorch will use whatever
# memory resource is configured to be the "current" memory resource at
# the time of allocation.
torch.cuda.change_current_allocator(rmm.rmm_torch_allocator)

# configure RMM to use a managed memory resource, wrapped with a
# statistics resource adaptor that can report information about the
# amount of memory allocated:
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.ManagedMemoryResource())
rmm.mr.set_current_device_resource(mr)

x = torch.tensor([1, 2]).cuda()

# the memory resource reports information about PyTorch allocations:
mr.allocation_counts
Out[6]:
{'current_bytes': 16,
'current_count': 1,
'peak_bytes': 16,
'peak_count': 1,
'total_bytes': 16,
'total_count': 1}
```
1 change: 1 addition & 0 deletions python/rmm/__init__.py
@@ -25,6 +25,7 @@
register_reinitialize_hook,
reinitialize,
rmm_cupy_allocator,
rmm_torch_allocator,
unregister_reinitialize_hook,
)

3 changes: 2 additions & 1 deletion python/rmm/_lib/CMakeLists.txt
@@ -12,7 +12,8 @@
# the License.
# =============================================================================

set(cython_sources device_buffer.pyx lib.pyx memory_resource.pyx cuda_stream.pyx)
set(cython_sources device_buffer.pyx lib.pyx memory_resource.pyx cuda_stream.pyx
torch_allocator.pyx)
set(linked_libraries rmm::rmm)

# Build all of the Cython targets
8 changes: 8 additions & 0 deletions python/rmm/_lib/memory_resource.pxd
@@ -17,12 +17,20 @@ from libcpp.memory cimport shared_ptr
from libcpp.string cimport string
from libcpp.vector cimport vector

from rmm._lib.cuda_stream_view cimport cuda_stream_view


cdef extern from "rmm/mr/device/device_memory_resource.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass device_memory_resource:
void* allocate(size_t bytes) except +
void* allocate(size_t bytes, cuda_stream_view stream) except +
void deallocate(void* ptr, size_t bytes) except +
void deallocate(
void* ptr,
size_t bytes,
cuda_stream_view stream
) except +

cdef class DeviceMemoryResource:
cdef shared_ptr[device_memory_resource] c_obj
29 changes: 5 additions & 24 deletions python/rmm/_lib/memory_resource.pyx
@@ -29,6 +29,10 @@ from cuda.cudart import cudaError_t
from rmm._cuda.gpu import CUDARuntimeError, getDevice, setDevice

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.per_device_resource cimport (
cuda_device_id,
set_per_device_resource as cpp_set_per_device_resource,
)

# Transparent handle of a C++ exception
ctypedef pair[int, string] CppExcept
@@ -206,29 +210,6 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \
) except +


cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:

cdef cppclass cuda_device_id:
ctypedef int value_type

cuda_device_id(value_type id)

value_type value()

cdef device_memory_resource* _set_current_device_resource \
"rmm::mr::set_current_device_resource" (device_memory_resource* new_mr)
cdef device_memory_resource* _get_current_device_resource \
"rmm::mr::get_current_device_resource" ()

cdef device_memory_resource* _set_per_device_resource \
"rmm::mr::set_per_device_resource" (
cuda_device_id id,
device_memory_resource* new_mr
)
cdef device_memory_resource* _get_per_device_resource \
"rmm::mr::get_per_device_resource"(cuda_device_id id)


cdef class DeviceMemoryResource:

cdef device_memory_resource* get_mr(self):
@@ -967,7 +948,7 @@ cpdef set_per_device_resource(int device, DeviceMemoryResource mr):
cdef unique_ptr[cuda_device_id] device_id = \
make_unique[cuda_device_id](device)

_set_per_device_resource(deref(device_id), mr.get_mr())
cpp_set_per_device_resource(deref(device_id), mr.get_mr())


cpdef set_current_device_resource(DeviceMemoryResource mr):
23 changes: 23 additions & 0 deletions python/rmm/_lib/per_device_resource.pxd
@@ -0,0 +1,23 @@
from rmm._lib.memory_resource cimport device_memory_resource


cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
cdef cppclass cuda_device_id:
ctypedef int value_type

cuda_device_id(value_type id)

value_type value()

cdef extern from "rmm/mr/device/per_device_resource.hpp" \
namespace "rmm::mr" nogil:
cdef device_memory_resource* set_current_device_resource(
device_memory_resource* new_mr
)
cdef device_memory_resource* get_current_device_resource()
cdef device_memory_resource* set_per_device_resource(
cuda_device_id id, device_memory_resource* new_mr
)
cdef device_memory_resource* get_per_device_resource (
cuda_device_id id
)
24 changes: 24 additions & 0 deletions python/rmm/_lib/torch_allocator.pyx
@@ -0,0 +1,24 @@
from cuda.ccudart cimport cudaStream_t

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.memory_resource cimport device_memory_resource
from rmm._lib.per_device_resource cimport get_current_device_resource


cdef public void* allocate(
ssize_t size, int device, void* stream
Member:

Q: device is ignored by design? (I was reviewing cupy/cupy#7210 and noticed this.)

@shwina (Contributor Author), Dec 21, 2022:

Ah, great catch! This brings out a subtle problem:

In RMM, each device has its own memory resource. Thus, to do the allocation on a specified device with RMM, I would write the torch allocate function like this:

    cdef public void* allocate(ssize_t size, int device, void* stream) except * with gil:
        cdef device_memory_resource* mr = get_per_device_resource(device)
        return mr[0].allocate(size, <cudaStream_t> stream)

Unfortunately, the deallocation function does not accept a device argument, so we cannot retrieve the memory resource that was used for allocation:

    void deallocate(void* ptr, ssize_t size, void* stream)

I don't really see a way around this other than for the deallocate signature to include the device argument. cc: @emcastillo

Member:

I would love to know too. TBH I am puzzled by PyTorch's (long-time) behavior of asking for device. It should just honor the current device...

Contributor Author:

+1 to that

Contributor Author:

I'll submit a follow-up PR adding support for device, once pytorch/pytorch#91398 is merged.
) except * with gil:
cdef device_memory_resource* mr = get_current_device_resource()
cdef cuda_stream_view stream_view = cuda_stream_view(
<cudaStream_t>(stream)
)
return mr[0].allocate(size, stream_view)

cdef public void deallocate(
void* ptr, ssize_t size, void* stream
) except * with gil:
cdef device_memory_resource* mr = get_current_device_resource()
cdef cuda_stream_view stream_view = cuda_stream_view(
<cudaStream_t>(stream)
)
mr[0].deallocate(ptr, size, stream_view)
15 changes: 15 additions & 0 deletions python/rmm/rmm.py
@@ -237,6 +237,21 @@ def rmm_cupy_allocator(nbytes):
return ptr


try:
from torch.cuda.memory import CUDAPluggableAllocator
except ImportError:
rmm_torch_allocator = None
Contributor:

Could this bite a user that accesses this attribute and passes it around thinking that it's fine when it's really just None? It might be safer to override __getattr__ for the module and have it raise an error to prevent the user from accessing this attribute when CUDAPluggableAllocator failed to import.

Contributor Author:

Could we alternatively `pass` on ImportError to achieve the same effect as defining that module __getattr__?

Contributor:

I think you'd get close to the same effect, just a slightly less user-friendly version. With a __getattr__ override you could provide a friendlier error message indicating that this happened because the torch allocator failed to import, whereas if you just avoid defining it the user will see an AttributeError without any additional diagnostics and may think it's a bug in rmm.

It's a very minor point though; I'm fine leaving this as is for now and only revisiting in the future if we get a lot of user questions about why the allocator is None.

else:
import rmm._lib.torch_allocator

_alloc_free_lib_path = rmm._lib.torch_allocator.__file__
Member:

This is neat!

rmm_torch_allocator = CUDAPluggableAllocator(
_alloc_free_lib_path,
alloc_fn_name="allocate",
free_fn_name="deallocate",
)
Comment on lines +248 to +252

Member:

Q: Would this honor rmm.reinitialize() if a user changes the MR?

Contributor Author:

rmm.reinitialize() resets the default memory resource used by RMM. Each call to allocate() and deallocate() queries the default memory resource via a call to get_current_device_resource(), so -- yes.
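
To make the behavior described above concrete, here is a minimal sketch (assuming a CUDA-capable environment and a PyTorch build that provides CUDAPluggableAllocator; it is an illustration, not part of this PR's diff):

```python
import rmm
import torch

# Route PyTorch's CUDA allocations through RMM. This only needs to be done
# once; it does not need to be repeated when the memory resource changes.
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

x = torch.tensor([1, 2]).cuda()  # served by the current RMM resource
del x                            # freed through that same resource

# Switch RMM to a pooled resource. Because allocate()/deallocate() call
# get_current_device_resource() on every request, subsequent PyTorch
# allocations automatically come from the new pool.
rmm.reinitialize(pool_allocator=True)

y = torch.tensor([3, 4]).cuda()  # served by the new pool resource
```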



def register_reinitialize_hook(func, *args, **kwargs):
"""
Add a function to the list of functions ("hooks") that will be
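For reference, the module-level `__getattr__` approach floated earlier in the review (raising a descriptive error instead of exposing `None`) could look roughly like the sketch below. It is a hypothetical alternative, not what this PR implements:

```python
# Hypothetical sketch of the reviewer's suggestion; not part of this PR.
try:
    from torch.cuda.memory import CUDAPluggableAllocator
except ImportError:
    _torch_allocator_import_failed = True
else:
    _torch_allocator_import_failed = False
    # ... define rmm_torch_allocator via CUDAPluggableAllocator, as in the
    # diff above ...


def __getattr__(name):
    # PEP 562 module __getattr__: reached only when normal lookup fails,
    # i.e. when rmm_torch_allocator was never defined because the torch
    # import failed.
    if name == "rmm_torch_allocator" and _torch_allocator_import_failed:
        raise AttributeError(
            "rmm_torch_allocator is unavailable because this PyTorch build "
            "does not provide torch.cuda.memory.CUDAPluggableAllocator"
        )
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```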
21 changes: 21 additions & 0 deletions python/rmm/tests/conftest.py
@@ -0,0 +1,21 @@
import pytest

import rmm


@pytest.fixture(scope="function", autouse=True)
def rmm_auto_reinitialize():
# Run the test
yield

# Automatically reinitialize the current memory resource after running each
# test

rmm.reinitialize()


@pytest.fixture
def stats_mr():
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
rmm.mr.set_current_device_resource(mr)
return mr
29 changes: 6 additions & 23 deletions python/rmm/tests/test_rmm.py
@@ -42,17 +42,6 @@
)


@pytest.fixture(scope="function", autouse=True)
def rmm_auto_reinitialize():

# Run the test
yield

# Automatically reinitialize the current memory resource after running each
# test
rmm.reinitialize()


def array_tester(dtype, nelem, alloc):
# data
h_in = np.full(nelem, 3.2, dtype)
@@ -604,20 +593,14 @@ def test_cuda_async_memory_resource_threshold(nelem, alloc):
array_tester("u1", 2 * nelem, alloc) # should trigger release


def test_statistics_resource_adaptor():

cuda_mr = rmm.mr.CudaMemoryResource()

mr = rmm.mr.StatisticsResourceAdaptor(cuda_mr)

rmm.mr.set_current_device_resource(mr)
def test_statistics_resource_adaptor(stats_mr):

buffers = [rmm.DeviceBuffer(size=1000) for _ in range(10)]

for i in range(9, 0, -2):
del buffers[i]

assert mr.allocation_counts == {
assert stats_mr.allocation_counts == {
"current_bytes": 5000,
"current_count": 5,
"peak_bytes": 10000,
@@ -627,7 +610,7 @@ def test_statistics_resource_adaptor():
}

# Push a new Tracking adaptor
mr2 = rmm.mr.StatisticsResourceAdaptor(mr)
mr2 = rmm.mr.StatisticsResourceAdaptor(stats_mr)
rmm.mr.set_current_device_resource(mr2)

for _ in range(2):
@@ -641,7 +624,7 @@ def test_statistics_resource_adaptor():
"total_bytes": 2000,
"total_count": 2,
}
assert mr.allocation_counts == {
assert stats_mr.allocation_counts == {
"current_bytes": 7000,
"current_count": 7,
"peak_bytes": 10000,
@@ -661,18 +644,18 @@ def test_statistics_resource_adaptor():
"total_bytes": 2000,
"total_count": 2,
}
assert mr.allocation_counts == {
assert stats_mr.allocation_counts == {
"current_bytes": 0,
"current_count": 0,
"peak_bytes": 10000,
"peak_count": 10,
"total_bytes": 12000,
"total_count": 12,
}
gc.collect()


def test_tracking_resource_adaptor():

cuda_mr = rmm.mr.CudaMemoryResource()

mr = rmm.mr.TrackingResourceAdaptor(cuda_mr, capture_stacks=True)
37 changes: 37 additions & 0 deletions python/rmm/tests/test_rmm_pytorch.py
@@ -0,0 +1,37 @@
import gc

import pytest

import rmm

torch = pytest.importorskip("torch")


@pytest.fixture(scope="session")
def torch_allocator():
try:
from torch.cuda.memory import change_current_allocator
except ImportError:
pytest.skip("pytorch pluggable allocator not available")
Comment on lines +12 to +15

Contributor:

There's a pytest utility for this if you want to use it: pytest.importorskip.

Member:

Does that handle importing specific functions?

This is using importorskip for torch generally above. The torch.cuda.memory module has been around for a while, though the functionality we need from it is pretty new.

Maybe in the future this could require a specific PyTorch version. There doesn't seem to be one yet that has what we need, though.

@bdice (Contributor), Jan 4, 2023:

No, pytest.importorskip only handles modules and you have to use attribute accessors to get functions. It's kind of messy. The current solution is probably easier to read; let's keep it as-is.

change_current_allocator(rmm.rmm_torch_allocator)
Contributor:

How does this function behave if passed None (the case where the torch allocator hasn't been defined)?

Contributor Author:

Hmm - I wouldn't expect it to be None if we were able to import change_current_allocator, since the existence of change_current_allocator implies that rmm_torch_allocator was defined (although somewhat implicitly: change_current_allocator and CUDAPluggableAllocator in PyTorch were introduced together).

Should we also skip this test if rmm.rmm_torch_allocator is None for some reason?

Contributor:

Ah, OK - I wasn't sure about that; I thought they weren't introduced entirely concurrently. Up to you on the extra skip; it sounds like it would be pedantically correct but not practically necessary.



def test_rmm_torch_allocator(torch_allocator, stats_mr):
assert stats_mr.allocation_counts["current_bytes"] == 0
x = torch.tensor([1, 2]).cuda()
assert stats_mr.allocation_counts["current_bytes"] > 0
del x
gc.collect()
assert stats_mr.allocation_counts["current_bytes"] == 0


def test_rmm_torch_allocator_using_stream(torch_allocator, stats_mr):
assert stats_mr.allocation_counts["current_bytes"] == 0
s = torch.cuda.Stream()
with torch.cuda.stream(s):
x = torch.tensor([1, 2]).cuda()
torch.cuda.current_stream().wait_stream(s)
assert stats_mr.allocation_counts["current_bytes"] > 0
del x
gc.collect()
assert stats_mr.allocation_counts["current_bytes"] == 0
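
As an aside on the pytest.importorskip discussion above, a variant of the torch_allocator fixture that goes through importorskip for the submodule might look like the sketch below; the PR keeps the plain try/except, which the reviewers agreed is easier to read:

```python
import pytest

import rmm

torch = pytest.importorskip("torch")


@pytest.fixture(scope="session")
def torch_allocator():
    # importorskip returns the module object, so the function still has to be
    # pulled off it by attribute access afterwards.
    memory = pytest.importorskip("torch.cuda.memory")
    if not hasattr(memory, "change_current_allocator"):
        pytest.skip("pytorch pluggable allocator not available")
    memory.change_current_allocator(rmm.rmm_torch_allocator)
```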