Add RMM PyTorch allocator (#1168)
Closes #1144 

This PR adds an RMM-based allocator for PyTorch, `rmm.rmm_torch_allocator`.

This enables, for example, using the same memory pool in code that uses both RAPIDS and PyTorch, and lets PyTorch allocate through any of the memory resources RMM provides. Here, a `CallbackMemoryResource` traces every PyTorch allocation and deallocation:

```python
import rmm
import torch

torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

base_mr = rmm.mr.CudaMemoryResource()

def allocate_func(size):
    print(f"Allocating {size} bytes")
    return base_mr.allocate(size)

def deallocate_func(ptr, size):
    print(f"Deallocating {size} bytes")
    return base_mr.deallocate(ptr, size)

rmm.mr.set_current_device_resource(
    rmm.mr.CallbackMemoryResource(allocate_func, deallocate_func)
)

x = torch.tensor([1, 2]).cuda()
del x
y = torch.tensor([1, 2, 3]).cuda()
del y
```

Output:

```
Allocating 16 bytes
Deallocating 16 bytes
Allocating 24 bytes
Deallocating 24 bytes
```
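
To illustrate the shared-pool case, here is a minimal sketch (not part of this commit's diff): allocations made through `rmm.DeviceBuffer` and through PyTorch are served by, and counted in, the same resource:

```python
import rmm
import torch

torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

# One pool, wrapped in a statistics adaptor so allocations from both
# libraries can be observed in one place.
mr = rmm.mr.StatisticsResourceAdaptor(
    rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
)
rmm.mr.set_current_device_resource(mr)

buf = rmm.DeviceBuffer(size=1000)   # RMM/RAPIDS-side allocation
x = torch.tensor([1, 2, 3]).cuda()  # PyTorch-side allocation

# Both allocations were served by the same pool and counted by the same
# adaptor (PyTorch may request more than one block internally):
print(mr.allocation_counts["current_count"])  # >= 2
```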

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Mark Harris (https://github.com/harrism)
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #1168
shwina authored Jan 5, 2023
1 parent 7fb45d4 commit e324ace
Showing 11 changed files with 189 additions and 48 deletions.
47 changes: 47 additions & 0 deletions README.md
@@ -732,3 +732,50 @@ This can be done in two ways:
**Note:** This only configures Numba to use the current RMM resource for allocations.
It does not initialize nor change the current resource, e.g., enabling a memory pool.
See [here](#memoryresource-objects) for more information on changing the current memory resource.

### Using RMM with PyTorch

[PyTorch](https://pytorch.org/docs/stable/notes/cuda.html) can use RMM
for memory allocation. For example, to configure PyTorch to use an
RMM-managed pool:

```python
import rmm
import torch

rmm.reinitialize(pool_allocator=True)
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
```

PyTorch and RMM will now share the same memory pool.

You can, of course, use a custom memory resource with PyTorch as well:

```python
import rmm
import torch

# note that you can configure PyTorch to use RMM either before or
# after changing RMM's memory resource. PyTorch will use whatever
# memory resource is configured to be the "current" memory resource at
# the time of allocation.
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

# configure RMM to use a managed memory resource, wrapped with a
# statistics resource adaptor that can report information about the
# amount of memory allocated:
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.ManagedMemoryResource())
rmm.mr.set_current_device_resource(mr)

x = torch.tensor([1, 2]).cuda()

# the memory resource reports information about PyTorch allocations:
mr.allocation_counts
# {'current_bytes': 16,
#  'current_count': 1,
#  'peak_bytes': 16,
#  'peak_count': 1,
#  'total_bytes': 16,
#  'total_count': 1}
```
1 change: 1 addition & 0 deletions python/rmm/__init__.py
@@ -25,6 +25,7 @@
    register_reinitialize_hook,
    reinitialize,
    rmm_cupy_allocator,
    rmm_torch_allocator,
    unregister_reinitialize_hook,
)

3 changes: 2 additions & 1 deletion python/rmm/_lib/CMakeLists.txt
@@ -12,7 +12,8 @@
# the License.
# =============================================================================

set(cython_sources device_buffer.pyx lib.pyx memory_resource.pyx cuda_stream.pyx)
set(cython_sources device_buffer.pyx lib.pyx memory_resource.pyx cuda_stream.pyx
    torch_allocator.pyx)
set(linked_libraries rmm::rmm)

# Build all of the Cython targets
8 changes: 8 additions & 0 deletions python/rmm/_lib/memory_resource.pxd
@@ -17,12 +17,20 @@ from libcpp.memory cimport shared_ptr
from libcpp.string cimport string
from libcpp.vector cimport vector

from rmm._lib.cuda_stream_view cimport cuda_stream_view


cdef extern from "rmm/mr/device/device_memory_resource.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass device_memory_resource:
void* allocate(size_t bytes) except +
void* allocate(size_t bytes, cuda_stream_view stream) except +
void deallocate(void* ptr, size_t bytes) except +
void deallocate(
void* ptr,
size_t bytes,
cuda_stream_view stream
) except +

cdef class DeviceMemoryResource:
cdef shared_ptr[device_memory_resource] c_obj
29 changes: 5 additions & 24 deletions python/rmm/_lib/memory_resource.pyx
@@ -29,6 +29,10 @@ from cuda.cudart import cudaError_t
from rmm._cuda.gpu import CUDARuntimeError, getDevice, setDevice

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.per_device_resource cimport (
    cuda_device_id,
    set_per_device_resource as cpp_set_per_device_resource,
)

# Transparent handle of a C++ exception
ctypedef pair[int, string] CppExcept
@@ -206,29 +210,6 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \
        ) except +


cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:

cdef cppclass cuda_device_id:
ctypedef int value_type

cuda_device_id(value_type id)

value_type value()

cdef device_memory_resource* _set_current_device_resource \
"rmm::mr::set_current_device_resource" (device_memory_resource* new_mr)
cdef device_memory_resource* _get_current_device_resource \
"rmm::mr::get_current_device_resource" ()

cdef device_memory_resource* _set_per_device_resource \
"rmm::mr::set_per_device_resource" (
cuda_device_id id,
device_memory_resource* new_mr
)
cdef device_memory_resource* _get_per_device_resource \
"rmm::mr::get_per_device_resource"(cuda_device_id id)


cdef class DeviceMemoryResource:

    cdef device_memory_resource* get_mr(self):
@@ -967,7 +948,7 @@ cpdef set_per_device_resource(int device, DeviceMemoryResource mr):
    cdef unique_ptr[cuda_device_id] device_id = \
        make_unique[cuda_device_id](device)

    _set_per_device_resource(deref(device_id), mr.get_mr())
    cpp_set_per_device_resource(deref(device_id), mr.get_mr())


cpdef set_current_device_resource(DeviceMemoryResource mr):
23 changes: 23 additions & 0 deletions python/rmm/_lib/per_device_resource.pxd
@@ -0,0 +1,23 @@
from rmm._lib.memory_resource cimport device_memory_resource


cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
    cdef cppclass cuda_device_id:
        ctypedef int value_type

        cuda_device_id(value_type id)

        value_type value()

cdef extern from "rmm/mr/device/per_device_resource.hpp" \
        namespace "rmm::mr" nogil:
    cdef device_memory_resource* set_current_device_resource(
        device_memory_resource* new_mr
    )
    cdef device_memory_resource* get_current_device_resource()
    cdef device_memory_resource* set_per_device_resource(
        cuda_device_id id, device_memory_resource* new_mr
    )
    cdef device_memory_resource* get_per_device_resource(
        cuda_device_id id
    )
24 changes: 24 additions & 0 deletions python/rmm/_lib/torch_allocator.pyx
@@ -0,0 +1,24 @@
from cuda.ccudart cimport cudaStream_t

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.memory_resource cimport device_memory_resource
from rmm._lib.per_device_resource cimport get_current_device_resource


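# These `public` functions are compiled into the extension module with
# C linkage, so PyTorch's CUDAPluggableAllocator can load them by name
# from the module's shared-library file (see python/rmm/rmm.py below).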
cdef public void* allocate(
    ssize_t size, int device, void* stream
) except * with gil:
    cdef device_memory_resource* mr = get_current_device_resource()
    cdef cuda_stream_view stream_view = cuda_stream_view(
        <cudaStream_t>(stream)
    )
    return mr[0].allocate(size, stream_view)

cdef public void deallocate(
    void* ptr, ssize_t size, void* stream
) except * with gil:
    cdef device_memory_resource* mr = get_current_device_resource()
    cdef cuda_stream_view stream_view = cuda_stream_view(
        <cudaStream_t>(stream)
    )
    mr[0].deallocate(ptr, size, stream_view)
15 changes: 15 additions & 0 deletions python/rmm/rmm.py
@@ -237,6 +237,21 @@ def rmm_cupy_allocator(nbytes):
    return ptr


try:
    from torch.cuda.memory import CUDAPluggableAllocator
except ImportError:
    rmm_torch_allocator = None
else:
    import rmm._lib.torch_allocator

    _alloc_free_lib_path = rmm._lib.torch_allocator.__file__
    rmm_torch_allocator = CUDAPluggableAllocator(
        _alloc_free_lib_path,
        alloc_fn_name="allocate",
        free_fn_name="deallocate",
    )


def register_reinitialize_hook(func, *args, **kwargs):
    """
    Add a function to the list of functions ("hooks") that will be
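For context, a minimal usage sketch (not part of this diff; assumes a PyTorch build that ships `CUDAPluggableAllocator`):

```python
import torch

import rmm

# rmm_torch_allocator is None when the installed PyTorch has no
# pluggable-allocator support.
if rmm.rmm_torch_allocator is None:
    raise RuntimeError("PyTorch pluggable allocator support not available")

torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
```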
21 changes: 21 additions & 0 deletions python/rmm/tests/conftest.py
@@ -0,0 +1,21 @@
import pytest

import rmm


@pytest.fixture(scope="function", autouse=True)
def rmm_auto_reinitialize():
    # Run the test
    yield

    # Automatically reinitialize the current memory resource after running
    # each test
    rmm.reinitialize()


@pytest.fixture
def stats_mr():
    mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
    rmm.mr.set_current_device_resource(mr)
    return mr
29 changes: 6 additions & 23 deletions python/rmm/tests/test_rmm.py
@@ -42,17 +42,6 @@
)


@pytest.fixture(scope="function", autouse=True)
def rmm_auto_reinitialize():

    # Run the test
    yield

    # Automatically reinitialize the current memory resource after running each
    # test
    rmm.reinitialize()


def array_tester(dtype, nelem, alloc):
    # data
    h_in = np.full(nelem, 3.2, dtype)
@@ -604,20 +593,14 @@ def test_cuda_async_memory_resource_threshold(nelem, alloc):
    array_tester("u1", 2 * nelem, alloc)  # should trigger release


def test_statistics_resource_adaptor():

    cuda_mr = rmm.mr.CudaMemoryResource()

    mr = rmm.mr.StatisticsResourceAdaptor(cuda_mr)

    rmm.mr.set_current_device_resource(mr)
def test_statistics_resource_adaptor(stats_mr):

    buffers = [rmm.DeviceBuffer(size=1000) for _ in range(10)]

    for i in range(9, 0, -2):
        del buffers[i]

    assert stats_mr.allocation_counts == {
        "current_bytes": 5000,
        "current_count": 5,
        "peak_bytes": 10000,
@@ -627,7 +610,7 @@ def test_statistics_resource_adaptor():
    }

    # Push a new Tracking adaptor
    mr2 = rmm.mr.StatisticsResourceAdaptor(mr)
    mr2 = rmm.mr.StatisticsResourceAdaptor(stats_mr)
    rmm.mr.set_current_device_resource(mr2)

    for _ in range(2):
@@ -641,7 +624,7 @@ def test_statistics_resource_adaptor():
        "total_bytes": 2000,
        "total_count": 2,
    }
    assert mr.allocation_counts == {
    assert stats_mr.allocation_counts == {
        "current_bytes": 7000,
        "current_count": 7,
        "peak_bytes": 10000,
@@ -661,18 +644,18 @@ def test_statistics_resource_adaptor():
        "total_bytes": 2000,
        "total_count": 2,
    }
    assert mr.allocation_counts == {
    assert stats_mr.allocation_counts == {
        "current_bytes": 0,
        "current_count": 0,
        "peak_bytes": 10000,
        "peak_count": 10,
        "total_bytes": 12000,
        "total_count": 12,
    }
    gc.collect()


def test_tracking_resource_adaptor():

    cuda_mr = rmm.mr.CudaMemoryResource()

    mr = rmm.mr.TrackingResourceAdaptor(cuda_mr, capture_stacks=True)
37 changes: 37 additions & 0 deletions python/rmm/tests/test_rmm_pytorch.py
@@ -0,0 +1,37 @@
import gc

import pytest

import rmm

torch = pytest.importorskip("torch")


@pytest.fixture(scope="session")
def torch_allocator():
    try:
        from torch.cuda.memory import change_current_allocator
    except ImportError:
        pytest.skip("pytorch pluggable allocator not available")
    change_current_allocator(rmm.rmm_torch_allocator)


def test_rmm_torch_allocator(torch_allocator, stats_mr):
    assert stats_mr.allocation_counts["current_bytes"] == 0
    x = torch.tensor([1, 2]).cuda()
    assert stats_mr.allocation_counts["current_bytes"] > 0
    del x
    gc.collect()
    assert stats_mr.allocation_counts["current_bytes"] == 0


def test_rmm_torch_allocator_using_stream(torch_allocator, stats_mr):
    assert stats_mr.allocation_counts["current_bytes"] == 0
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        x = torch.tensor([1, 2]).cuda()
    # Make the default stream wait for work on the side stream before x
    # is used or freed on it.
    torch.cuda.current_stream().wait_stream(s)
    assert stats_mr.allocation_counts["current_bytes"] > 0
    del x
    gc.collect()
    assert stats_mr.allocation_counts["current_bytes"] == 0
