Skip to content

Commit

Permalink
Add support for comma separated strings
Browse files Browse the repository at this point in the history
Signed-off-by: Vibhu Jawa <vibhujawa@gmail.com>
  • Loading branch information
VibhuJawa committed Oct 11, 2024
1 parent 33871ac commit 2517874
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,11 @@ class LocalCUDACluster(LocalCluster):
The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
incompatible with RMM pools and managed memory. Trying to enable both will
result in an exception.
rmm_allocator_external_lib_list: list or None, default None
rmm_allocator_external_lib_list: str, list or None, default None
List of external libraries for which to set RMM as the allocator.
Supported options are: ``["torch", "cupy"]``. If ``None``, no external
libraries will use RMM as their allocator.
Supported options are: ``["torch", "cupy"]``. Can be a comma-separated string
(like ``"torch,cupy"``) or a list of strings (like ``["torch", "cupy"]``).
If ``None``, no external libraries will use RMM as their allocator.
rmm_release_threshold: int, str or None, default None
When ``rmm.async is True`` and the pool size grows beyond this value, unused
memory held by the pool will be released at the next synchronization point.
Expand Down Expand Up @@ -271,15 +272,17 @@ def __init__(
if n_workers < 1:
raise ValueError("Number of workers cannot be less than 1.")

if rmm_allocator_external_lib_list is not None and not isinstance(
rmm_allocator_external_lib_list, list
):
raise ValueError(
"rmm_allocator_external_lib_list must be a list of strings. "
"Valid examples: ['torch'], ['cupy'], or ['torch', 'cupy']. "
f"Received: {type(rmm_allocator_external_lib_list)} "
f"with value: {rmm_allocator_external_lib_list}"
)
if rmm_allocator_external_lib_list is not None:
if isinstance(rmm_allocator_external_lib_list, str):
rmm_allocator_external_lib_list = [
v.strip() for v in rmm_allocator_external_lib_list.split(",")
]
elif not isinstance(rmm_allocator_external_lib_list, list):
raise ValueError(
"rmm_allocator_external_lib_list must be either a comma-separated "
"string or a list of strings. Examples: 'torch,cupy' "
"or ['torch', 'cupy']"
)

# Set nthreads=1 when parsing mem_limit since it only depends on n_workers
logger = logging.getLogger(__name__)
Expand Down

0 comments on commit 2517874

Please sign in to comment.