Closes #1144

This PR adds an RMM-based allocator for PyTorch, `rmm.rmm_torch_allocator`. This enables, e.g., using the same memory pool in code that uses both RAPIDS and PyTorch. It also enables PyTorch to use all of the different memory resources provided by RMM. For example:

```python
import rmm
import torch

torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

base_mr = rmm.mr.CudaMemoryResource()


def allocate_func(size):
    print(f"Allocating {size} bytes")
    return base_mr.allocate(size)


def deallocate_func(ptr, size):
    print(f"Deallocating {size} bytes")
    return base_mr.deallocate(ptr, size)


rmm.mr.set_current_device_resource(
    rmm.mr.CallbackMemoryResource(allocate_func, deallocate_func)
)

x = torch.tensor([1, 2]).cuda()
del x

y = torch.tensor([1, 2, 3]).cuda()
del y
```

Output:

```
Allocating 16 bytes
Deallocating 16 bytes
Allocating 24 bytes
Deallocating 24 bytes
```

Authors:
- Ashwin Srinath (https://github.com/shwina)

Approvers:
- Mark Harris (https://github.com/harrism)
- Bradley Dice (https://github.com/bdice)
- Vyas Ramasubramani (https://github.com/vyasr)

URL: #1168
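A minimal sketch of the shared-pool use case mentioned above (not part of this commit; the pool size, call ordering, and variable names are illustrative assumptions):

```python
import rmm
import torch

# Back all RMM allocations with one pooled memory resource
# (the 1 GiB initial size is arbitrary).
rmm.reinitialize(pool_allocator=True, initial_pool_size=2**30)

# Swap PyTorch's allocator before any CUDA memory has been allocated
# through its default caching allocator.
torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)

t = torch.arange(1024, device="cuda")  # served from the RMM pool
buf = rmm.DeviceBuffer(size=1024)      # served from the same pool
```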
Showing 11 changed files with 189 additions and 48 deletions.
@@ -0,0 +1,23 @@
```cython
from rmm._lib.memory_resource cimport device_memory_resource


cdef extern from "rmm/mr/device/per_device_resource.hpp" namespace "rmm" nogil:
    cdef cppclass cuda_device_id:
        ctypedef int value_type

        cuda_device_id(value_type id)

        value_type value()

cdef extern from "rmm/mr/device/per_device_resource.hpp" \
        namespace "rmm::mr" nogil:
    cdef device_memory_resource* set_current_device_resource(
        device_memory_resource* new_mr
    )
    cdef device_memory_resource* get_current_device_resource()
    cdef device_memory_resource* set_per_device_resource(
        cuda_device_id id, device_memory_resource* new_mr
    )
    cdef device_memory_resource* get_per_device_resource(
        cuda_device_id id
    )
```
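The declarations above are what the new allocator hooks use from Cython; the Python-level API mirrors them. A small sketch (not part of this diff; device id 0 is illustrative):

```python
import rmm

# Install a plain CUDA memory resource for device 0.
rmm.mr.set_per_device_resource(0, rmm.mr.CudaMemoryResource())

# The "current" resource is the per-device resource of the active device.
mr = rmm.mr.get_current_device_resource()
print(type(mr))
```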
@@ -0,0 +1,24 @@
```cython
from cuda.ccudart cimport cudaStream_t

from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.memory_resource cimport device_memory_resource
from rmm._lib.per_device_resource cimport get_current_device_resource


# `public` exposes these functions to C; both route through RMM's current
# device resource on the stream that PyTorch passes in.
cdef public void* allocate(
    ssize_t size, int device, void* stream
) except * with gil:
    cdef device_memory_resource* mr = get_current_device_resource()
    cdef cuda_stream_view stream_view = cuda_stream_view(
        <cudaStream_t>(stream)
    )
    return mr[0].allocate(size, stream_view)

cdef public void deallocate(
    void* ptr, ssize_t size, void* stream
) except * with gil:
    cdef device_memory_resource* mr = get_current_device_resource()
    cdef cuda_stream_view stream_view = cuda_stream_view(
        <cudaStream_t>(stream)
    )
    mr[0].deallocate(ptr, size, stream_view)
```
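These `public` functions are the hooks a PyTorch pluggable allocator needs. A hedged sketch of how such symbols are typically registered, assuming PyTorch's `CUDAPluggableAllocator` API; the shared-library path is hypothetical, and `rmm.rmm_torch_allocator` performs the equivalent step for you:

```python
from torch.cuda.memory import CUDAPluggableAllocator, change_current_allocator

# Hypothetical path to the compiled extension exposing `allocate`/`deallocate`.
allocator = CUDAPluggableAllocator(
    "path/to/_torch_allocator.so",
    "allocate",    # name of the C allocation symbol
    "deallocate",  # name of the C deallocation symbol
)
change_current_allocator(allocator)
```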
@@ -0,0 +1,21 @@
```python
import pytest

import rmm


@pytest.fixture(scope="function", autouse=True)
def rmm_auto_reinitialize():
    # Run the test
    yield

    # Automatically reinitialize the current memory resource after running
    # each test
    rmm.reinitialize()


@pytest.fixture
def stats_mr():
    mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
    rmm.mr.set_current_device_resource(mr)
    return mr
```
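The `stats_mr` fixture is what the new tests use to observe allocations. A small sketch of the adaptor on its own (not part of this diff; the 256-byte size is illustrative):

```python
import rmm

mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.CudaMemoryResource())
rmm.mr.set_current_device_resource(mr)

buf = rmm.DeviceBuffer(size=256)  # allocation is recorded by the adaptor
assert mr.allocation_counts["current_bytes"] >= 256

del buf  # freeing the buffer drops the live-byte count back to zero
assert mr.allocation_counts["current_bytes"] == 0
```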
@@ -0,0 +1,37 @@
```python
import gc

import pytest

import rmm

torch = pytest.importorskip("torch")


@pytest.fixture(scope="session")
def torch_allocator():
    try:
        from torch.cuda.memory import change_current_allocator
    except ImportError:
        pytest.skip("pytorch pluggable allocator not available")
    change_current_allocator(rmm.rmm_torch_allocator)


def test_rmm_torch_allocator(torch_allocator, stats_mr):
    assert stats_mr.allocation_counts["current_bytes"] == 0
    x = torch.tensor([1, 2]).cuda()
    assert stats_mr.allocation_counts["current_bytes"] > 0
    del x
    gc.collect()
    assert stats_mr.allocation_counts["current_bytes"] == 0


def test_rmm_torch_allocator_using_stream(torch_allocator, stats_mr):
    assert stats_mr.allocation_counts["current_bytes"] == 0
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        x = torch.tensor([1, 2]).cuda()
    # Make the default stream wait on `s` before the tensor is checked and freed.
    torch.cuda.current_stream().wait_stream(s)
    assert stats_mr.allocation_counts["current_bytes"] > 0
    del x
    gc.collect()
    assert stats_mr.allocation_counts["current_bytes"] == 0
```