Python bindings for cuda_async_memory_resource (#718)
```diff
@@ -1,4 +1,5 @@
 # Copyright (c) 2019-2020, NVIDIA CORPORATION.
+import filecmp
 import glob
 import os
 import re
@@ -63,22 +64,40 @@ def get_cuda_version_from_header(cuda_include_dir):
 # valid symbols for specific version of CUDA.

 cwd = os.getcwd()
-preprocess_files = ["gpu.pxd"]
-supported_cuda_versions = {"10.1", "10.2", "11.0"}
-
-for file_p in preprocess_files:
-    pxi_file = ".".join(file_p.split(".")[:-1])
-    pxi_file = pxi_file + ".pxi"
-
-    if CUDA_VERSION in supported_cuda_versions:
-        shutil.copyfile(
-            os.path.join(cwd, "rmm/_cuda", CUDA_VERSION, pxi_file),
-            os.path.join(cwd, "rmm/_cuda", file_p),
-        )
+files_to_preprocess = ["gpu.pxd"]
+
+# The .pxi file is unchanged between some CUDA versions
+# (e.g., 11.0 & 11.1), so we keep only a single copy
+# of it
+cuda_version_to_pxi_dir = {
+    "10.1": "10.1",
+    "10.2": "10.2",
+    "11.0": "11.x",
+    "11.1": "11.x",
+    "11.2": "11.x",
+}
+
+for pxd_basename in files_to_preprocess:
+    pxi_basename = os.path.splitext(pxd_basename)[0] + ".pxi"
+    if CUDA_VERSION in cuda_version_to_pxi_dir:
+        pxi_pathname = os.path.join(
+            cwd,
+            "rmm/_cuda",
+            cuda_version_to_pxi_dir[CUDA_VERSION],
+            pxi_basename,
+        )
+        pxd_pathname = os.path.join(cwd, "rmm/_cuda", pxd_basename)
+        try:
+            if filecmp.cmp(pxi_pathname, pxd_pathname):
+                # files are the same, no need to copy
+                continue
+        except FileNotFoundError:
+            # pxd_pathname doesn't exist yet
+            pass
+        shutil.copyfile(pxi_pathname, pxd_pathname)
```
> **Review thread on lines +80 to +97** (marked as resolved by kkraus14):
>
> - Can we move the CUDA version check outside of the loop and invert it to reduce nesting?
>
>   ```python
>   if CUDA_VERSION not in cuda_version_to_pxi_dir:
>       raise TypeError(f"{CUDA_VERSION} is not supported.")
>   ```
>
> - That would mean we always check, regardless of how many files we have to preprocess, so that might need to be accounted for. example:
>
> - Agreed that this is low-hanging fruit to fix and we may as well tackle it now.
>
> - Woops, this merged before fixing this. Will raise an issue to tackle it in a follow-up.
>
> - Ah sorry. I'll put in one tomorrow.
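The refactor the thread proposes could be sketched as follows. This is an illustration of the suggestion, not the merged code; `CUDA_VERSION` is stubbed here, whereas in the real script it is detected from the CUDA headers:

```python
# Sketch of the suggested refactor: validate the CUDA version once,
# before the loop, instead of re-checking it on every iteration.
# CUDA_VERSION is stubbed for illustration.
CUDA_VERSION = "11.2"

cuda_version_to_pxi_dir = {
    "10.1": "10.1",
    "10.2": "10.2",
    "11.0": "11.x",
    "11.1": "11.x",
    "11.2": "11.x",
}

# Hoisted, inverted check: fail fast, and reduce nesting in the loop.
if CUDA_VERSION not in cuda_version_to_pxi_dir:
    raise TypeError(f"{CUDA_VERSION} is not supported.")

# The loop body can now assume the version is valid.
pxi_dir = cuda_version_to_pxi_dir[CUDA_VERSION]
```

As noted in the thread, this trades the per-file check for an unconditional one, which only matters if the version check ever needs to depend on the file being preprocessed.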
```diff
     else:
         raise TypeError(f"{CUDA_VERSION} is not supported.")


 try:
     nthreads = int(os.environ.get("PARALLEL_LEVEL", "0") or "0")
 except Exception:
```
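The `filecmp.cmp` guard in the new code avoids rewriting the `.pxd` when its content already matches the chosen `.pxi`, so an unchanged file is left alone rather than copied again on every build. A self-contained sketch of that copy-if-changed pattern, using a temporary directory instead of the real `rmm/_cuda` tree (the helper name `copy_if_changed` is ours, not part of the PR):

```python
import filecmp
import os
import shutil
import tempfile


def copy_if_changed(src, dst):
    """Copy src to dst only when dst is missing or differs from src."""
    try:
        if filecmp.cmp(src, dst):
            return False  # files are the same, no need to copy
    except FileNotFoundError:
        pass  # dst doesn't exist yet
    shutil.copyfile(src, dst)
    return True


tmp = tempfile.mkdtemp()
src = os.path.join(tmp, "gpu.pxi")
dst = os.path.join(tmp, "gpu.pxd")
with open(src, "w") as f:
    f.write("cdef extern from 'cuda.h':\n    pass\n")

first = copy_if_changed(src, dst)   # dst missing -> copied
second = copy_if_changed(src, dst)  # identical content -> skipped
```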
> **Review thread:**
>
> - What should failure look like here? `try..except` and re-raise with more information? `driverGetVersion()` and duplicate the check for 11.2 in C++ and Python?
>
> - What does the C++ error look like if someone tries to create this on CUDA 11.0?
>
> - https://github.com/rapidsai/rmm/blob/branch-0.19/include/rmm/mr/device/cuda_async_memory_resource.hpp#L44-L54
>
> - We may want to improve the message here (https://github.com/rapidsai/rmm/blob/branch-0.19/include/rmm/mr/device/cuda_async_memory_resource.hpp#L53) to say that it was compiled without support, instead of just the generic error message.
>
> - I think this part of the macro deals specifically with CUDA version < 11.2 -- @harrism, any thoughts here on a possibly more informative error message? This will directly be propagated up to Python users.
>
> - "cudaMallocAsync not supported by the version of the CUDA Toolkit used for compilation"? I don't want to say "... used to compile RMM" since RMM is header-only.
>
> - Improved the error message based on your suggestion.
>
> - You changed the wrong error message. :)
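One of the options raised above is to duplicate the "CUDA >= 11.2" gate on the Python side so the failure reaches users as a clear Python error rather than a generic C++ one. A hypothetical sketch of that idea (the function name and the stubbed `driver_version` argument are illustrative, not RMM's actual API):

```python
# Hypothetical sketch: mirror the C++ availability gate in Python.
# CUDA encodes version 11.2 as the integer 11020; a real implementation
# would obtain driver_version from a call such as cudaDriverGetVersion()
# rather than taking it as a parameter.
def check_async_alloc_supported(driver_version):
    """Raise if cudaMallocAsync (introduced in CUDA 11.2) is unavailable."""
    if driver_version < 11020:
        raise RuntimeError(
            "cudaMallocAsync not supported by the version of the "
            "CUDA Toolkit used for compilation"
        )
```

The error string follows the wording suggested in the thread, which deliberately avoids "used to compile RMM" since RMM is header-only.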