Skip to content

Commit

Permalink
[SYCL] Refactor address space casts functionality (#15543)
Browse files Browse the repository at this point in the history
This follows the interfaces designed in
https://github.com/intel/llvm/blob/3a1c3cb53566f904a73361d5c57b939d981564b5/sycl/doc/extensions/proposed/sycl_ext_oneapi_address_cast.asciidoc,
but instead of operating on `multi_ptr`, these work on decorated C++
pointers (as that's what we need throughout our implementation,
including `multi_ptr` implementation itself).

Basically, I've moved the implementation of the extension to the new
`detail::static_address_cast`/`detail::dynamic_address_cast` functions and
replaced all uses of the old `detail::cast_AS` (which had inconsistent static
vs. dynamic behavior depending on the address space/backend), as well as all
direct invocations of the SPIR-V builtins/wrappers.

This isn't NFC: by doing that I've changed "dynamic" behavior to "static"
wherever the spec allows it (e.g. where it is already UB for the runtime
pointer not to point to a proper allocation).
  • Loading branch information
aelovikov-intel authored Oct 14, 2024
1 parent ba99338 commit 9ea0f20
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 436 deletions.
184 changes: 0 additions & 184 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,190 +540,6 @@ __SPIRV_ATOMICS(__SPIRV_ATOMIC_MINMAX, Max)
#undef __SPIRV_ATOMIC_UNSIGNED
#undef __SPIRV_ATOMIC_XOR

// Typed front-ends for __spirv_GenericCastToPtrExplicit_ToGlobal.
//
// Every overload hands `Ptr` to the builtin together with the CrossWorkgroup
// storage class and C-style casts whatever comes back to a pointer to `dataT`
// carrying the opencl_global attribute. There is one overload per
// cv-qualification of the incoming void pointer; the same qualifiers appear on
// the returned pointee type.
template <typename dataT>
extern __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(void *Ptr) noexcept {
  using ResultT = __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// const-qualified variant.
template <typename dataT>
extern const __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(const void *Ptr) noexcept {
  using ResultT = const __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// volatile-qualified variant.
template <typename dataT>
extern volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(volatile void *Ptr) noexcept {
  using ResultT = volatile __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// const volatile-qualified variant.
template <typename dataT>
extern const volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(const volatile void *Ptr) noexcept {
  using ResultT = const volatile __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// Typed front-ends for __spirv_GenericCastToPtrExplicit_ToLocal.
//
// Every overload hands `Ptr` to the builtin together with the Workgroup
// storage class and C-style casts whatever comes back to a pointer to `dataT`
// carrying the opencl_local attribute. There is one overload per
// cv-qualification of the incoming void pointer; the same qualifiers appear on
// the returned pointee type.
template <typename dataT>
extern __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(void *Ptr) noexcept {
  using ResultT = __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// const-qualified variant.
template <typename dataT>
extern const __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(const void *Ptr) noexcept {
  using ResultT = const __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// volatile-qualified variant.
template <typename dataT>
extern volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(volatile void *Ptr) noexcept {
  using ResultT = volatile __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// const volatile-qualified variant.
template <typename dataT>
extern const volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(const volatile void *Ptr) noexcept {
  using ResultT = const volatile __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// Typed front-ends for __spirv_GenericCastToPtrExplicit_ToPrivate.
//
// Every overload hands `Ptr` to the builtin together with the Function
// storage class and C-style casts whatever comes back to a pointer to `dataT`
// carrying the opencl_private attribute. There is one overload per
// cv-qualification of the incoming void pointer; the same qualifiers appear on
// the returned pointee type.
template <typename dataT>
extern __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(void *Ptr) noexcept {
  using ResultT = __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// const-qualified variant.
template <typename dataT>
extern const __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(const void *Ptr) noexcept {
  using ResultT = const __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// volatile-qualified variant.
template <typename dataT>
extern volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(volatile void *Ptr) noexcept {
  using ResultT = volatile __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// const volatile-qualified variant.
template <typename dataT>
extern const volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(const volatile void *Ptr) noexcept {
  using ResultT = const volatile __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtrExplicit_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// Typed front-ends for __spirv_GenericCastToPtr_ToGlobal (the non-"Explicit"
// builtin).
//
// Every overload hands `Ptr` to the builtin together with the CrossWorkgroup
// storage class and C-style casts whatever comes back to a pointer to `dataT`
// carrying the opencl_global attribute. There is one overload per
// cv-qualification of the incoming void pointer; the same qualifiers appear on
// the returned pointee type.
template <typename dataT>
extern __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(void *Ptr) noexcept {
  using ResultT = __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// const-qualified variant.
template <typename dataT>
extern const __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(const void *Ptr) noexcept {
  using ResultT = const __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// volatile-qualified variant.
template <typename dataT>
extern volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(volatile void *Ptr) noexcept {
  using ResultT = volatile __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// const volatile-qualified variant.
template <typename dataT>
extern const volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(const volatile void *Ptr) noexcept {
  using ResultT = const volatile __attribute__((opencl_global)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToGlobal(
      Ptr, __spv::StorageClass::CrossWorkgroup);
}

// Typed front-ends for __spirv_GenericCastToPtr_ToLocal (the non-"Explicit"
// builtin).
//
// Every overload hands `Ptr` to the builtin together with the Workgroup
// storage class and C-style casts whatever comes back to a pointer to `dataT`
// carrying the opencl_local attribute. There is one overload per
// cv-qualification of the incoming void pointer; the same qualifiers appear on
// the returned pointee type.
template <typename dataT>
extern __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(void *Ptr) noexcept {
  using ResultT = __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// const-qualified variant.
template <typename dataT>
extern const __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(const void *Ptr) noexcept {
  using ResultT = const __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// volatile-qualified variant.
template <typename dataT>
extern volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(volatile void *Ptr) noexcept {
  using ResultT = volatile __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// const volatile-qualified variant.
template <typename dataT>
extern const volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(const volatile void *Ptr) noexcept {
  using ResultT = const volatile __attribute__((opencl_local)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToLocal(
      Ptr, __spv::StorageClass::Workgroup);
}

// Typed front-ends for __spirv_GenericCastToPtr_ToPrivate (the non-"Explicit"
// builtin).
//
// Every overload hands `Ptr` to the builtin together with the Function
// storage class and C-style casts whatever comes back to a pointer to `dataT`
// carrying the opencl_private attribute. There is one overload per
// cv-qualification of the incoming void pointer; the same qualifiers appear on
// the returned pointee type.
template <typename dataT>
extern __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(void *Ptr) noexcept {
  using ResultT = __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// const-qualified variant.
template <typename dataT>
extern const __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(const void *Ptr) noexcept {
  using ResultT = const __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// volatile-qualified variant.
template <typename dataT>
extern volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(volatile void *Ptr) noexcept {
  using ResultT = volatile __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// const volatile-qualified variant.
template <typename dataT>
extern const volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(const volatile void *Ptr) noexcept {
  using ResultT = const volatile __attribute__((opencl_private)) dataT *;
  return (ResultT)__spirv_GenericCastToPtr_ToPrivate(
      Ptr, __spv::StorageClass::Function);
}

// Declaration only (no body in this view): a function template over the data
// type `dataT` that takes a value `Data` plus a 32-bit `InvocationId` and
// returns a `dataT`.
// NOTE(review): presumably this is the INTEL sub-group shuffle SPIR-V builtin
// and `InvocationId` selects the source work-item - confirm against the
// SPV_INTEL_subgroups documentation.
template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
__spirv_SubgroupShuffleINTEL(dataT Data, uint32_t InvocationId) noexcept;
Expand Down
190 changes: 143 additions & 47 deletions sycl/include/sycl/access/access.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,58 +325,154 @@ template <typename T>
using remove_decoration_t = typename remove_decoration<T>::type;

namespace detail {

// Helper function for selecting appropriate casts between address spaces.
template <typename ToT, typename FromT> inline ToT cast_AS(FromT from) {
#ifdef __SYCL_DEVICE_ONLY__
constexpr access::address_space ToAS = deduce_AS<ToT>::value;
constexpr access::address_space FromAS = deduce_AS<FromT>::value;
if constexpr (FromAS == access::address_space::generic_space) {
#if defined(__NVPTX__) || defined(__AMDGCN__) || defined(__SYCL_NATIVE_CPU__)
// TODO: NVPTX and AMDGCN backends do not currently support the
// __spirv_GenericCastToPtrExplicit_* builtins, so to work around this
// we do C-style casting. This may produce warnings when targetting
// these backends.
return (ToT)from;
// Compile-time predicate: can a pointer in address space `Src` be converted to
// address space `Dst` at all?
//
// Allowed combinations:
//   * Src == Dst (including constant_space -> constant_space);
//   * either side is generic_space, as long as neither side is constant_space;
//   * global_space on one side and ext_intel_global_device_space or
//     ext_intel_global_host_space on the other.
// Everything else yields false.
inline constexpr bool
address_space_cast_is_possible(access::address_space Src,
                               access::address_space Dst) {
  using AS = access::address_space;

  // An identity conversion is always fine.
  if (Src == Dst)
    return true;

  // constant_space is unique: it converts to/from nothing but itself, and the
  // identity case has already been handled above.
  if (Src == AS::constant_space || Dst == AS::constant_space)
    return false;

  // Any remaining space may go to or come from generic.
  if (Src == AS::generic_space || Dst == AS::generic_space)
    return true;

  // The only pairs left are global <-> global_device / global_host.
  const bool SrcIsGlobal = Src == AS::global_space;
  const bool DstIsGlobal = Dst == AS::global_space;
  if (!SrcIsGlobal && !DstIsGlobal)
    return false;

  const AS Counterpart = SrcIsGlobal ? Dst : Src;
  return Counterpart == AS::ext_intel_global_device_space ||
         Counterpart == AS::ext_intel_global_host_space;
}

// Converts `Ptr` to a pointer to the same (undecorated) element type decorated
// with the address space `Space`.
//
// Template parameters:
//   Space       - destination address space.
//   ElementType - pointee type of the incoming pointer; its address space is
//                 obtained through deduce_AS<ElementType *>.
//
// Instantiation fails (static_assert) when address_space_cast_is_possible
// reports that the source/destination pair is not convertible.
//
// NOTE(review): the non-"Explicit" __spirv_GenericCastToPtr_* builtins used
// below are presumably unchecked conversions (no run-time test of where the
// pointer really points) - confirm against the SPIR-V specification.
template <access::address_space Space, typename ElementType>
auto static_address_cast(ElementType *Ptr) {
constexpr auto generic_space = access::address_space::generic_space;
constexpr auto global_space = access::address_space::global_space;
constexpr auto local_space = access::address_space::local_space;
constexpr auto private_space = access::address_space::private_space;
constexpr auto global_device =
access::address_space::ext_intel_global_device_space;
constexpr auto global_host =
access::address_space::ext_intel_global_host_space;

// Address space the incoming pointer type is decorated with.
constexpr auto SrcAS = deduce_AS<ElementType *>::value;
static_assert(address_space_cast_is_possible(SrcAS, Space));

// Result type: strip the decoration from the source pointer type, then
// re-decorate the element type with the requested space.
using dst_type = typename DecoratedType<
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;

// Note: reinterpret_cast isn't enough for some of the casts between different
// address spaces, use C-style cast instead.
// When __SPIR__ is not defined, a plain C-style cast is all that is done.
#if !defined(__SPIR__)
return (dst_type)Ptr;
#else
// NOTE(review): the lines from here down to
// `return reinterpret_cast<ToT>(from);` use `ToT`, `ToAS`, `FromAS` and `from`,
// none of which are declared in this function. They look like removed-side
// lines of the previous `cast_AS` implementation that the diff capture
// interleaved with the new code - confirm against the upstream header.
using ToElemT = std::remove_pointer_t<remove_decoration_t<ToT>>;
if constexpr (ToAS == access::address_space::global_space)
return __SYCL_GenericCastToPtrExplicit_ToGlobal<ToElemT>(from);
else if constexpr (ToAS == access::address_space::local_space)
return __SYCL_GenericCastToPtrExplicit_ToLocal<ToElemT>(from);
else if constexpr (ToAS == access::address_space::private_space)
return __SYCL_GenericCastToPtrExplicit_ToPrivate<ToElemT>(from);
#ifdef __ENABLE_USM_ADDR_SPACE__
else if constexpr (ToAS == access::address_space::
ext_intel_global_device_space ||
ToAS ==
access::address_space::ext_intel_global_host_space)
// For extended address spaces we do not currently have a SPIR-V
// conversion function, so we do a C-style cast. This may produce
// warnings.
return (ToT)from;
#endif // __ENABLE_USM_ADDR_SPACE__
else
return reinterpret_cast<ToT>(from);
#endif // defined(__NVPTX__) || defined(__AMDGCN__)
} else
#ifdef __ENABLE_USM_ADDR_SPACE__
if constexpr (FromAS == access::address_space::global_space &&
(ToAS ==
access::address_space::ext_intel_global_device_space ||
ToAS ==
access::address_space::ext_intel_global_host_space)) {
// Casting from global address space to the global device and host address
// spaces is allowed.
return (ToT)from;
} else
#endif // __ENABLE_USM_ADDR_SPACE__
#endif // __SYCL_DEVICE_ONLY__
{
return reinterpret_cast<ToT>(from);
// __SPIR__ path of the new implementation: a source pointer that is not in the
// generic space is converted with a C-style cast; a generic source pointer is
// routed to the builtin that matches the destination space.
if constexpr (SrcAS != generic_space) {
return (dst_type)Ptr;
} else if constexpr (Space == global_space) {
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
} else if constexpr (Space == local_space) {
return (dst_type)__spirv_GenericCastToPtr_ToLocal(
Ptr, __spv::StorageClass::Workgroup);
} else if constexpr (Space == private_space) {
return (dst_type)__spirv_GenericCastToPtr_ToPrivate(
Ptr, __spv::StorageClass::Function);
#if !defined(__ENABLE_USM_ADDR_SPACE__)
} else if constexpr (Space == global_device || Space == global_host) {
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
// global_device/global_host are just aliases for global_space.
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
#endif
} else {
// Remaining destinations (e.g. generic -> generic) need no builtin.
return (dst_type)Ptr;
}
#endif
}

// Converts `Ptr` to a pointer to the same (undecorated) element type decorated
// with address space `Space`, going through the "Explicit" SPIR-V cast
// builtins where a run-time decision is needed.
//
// Template parameters:
//   Space                       - destination address space.
//   SupressNotImplementedAssert - when false (the default), combinations that
//                                 have no dynamic implementation yet fail to
//                                 compile with "Not supported yet!"; when
//                                 true, they fall back to
//                                 static_address_cast<Space>.
//   ElementType                 - pointee type of the incoming pointer; its
//                                 address space comes from
//                                 deduce_AS<ElementType *>.
//
// Returns a null pointer of the destination type when
// address_space_cast_is_possible rejects the source/destination pair.
//
// The previous implementation (`cast_AS`, used in `multi_ptr` ctors among
// other places) used a C-style cast instead of a proper dynamic check for some
// backends/spaces. `SupressNotImplementedAssert = true` emulates that previous
// behavior, for compatibility reasons, until proper support is added.
//
// NOTE(review): per the SPIR-V specification, OpGenericCastToPtrExplicit is
// expected to yield a null pointer when the generic pointer does not point
// into the requested storage class - confirm for the targeted backends.
template <access::address_space Space, bool SupressNotImplementedAssert = false,
          typename ElementType>
auto dynamic_address_cast(ElementType *Ptr) {
  constexpr auto generic_space = access::address_space::generic_space;
  constexpr auto global_space = access::address_space::global_space;
  constexpr auto local_space = access::address_space::local_space;
  constexpr auto private_space = access::address_space::private_space;
  constexpr auto global_device =
      access::address_space::ext_intel_global_device_space;
  constexpr auto global_host =
      access::address_space::ext_intel_global_host_space;

  // Address space the incoming pointer type is decorated with.
  constexpr auto SrcAS = deduce_AS<ElementType *>::value;
  // Result type: undecorated element type re-decorated with `Space`.
  using dst_type = typename DecoratedType<
      std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;

  if constexpr (!address_space_cast_is_possible(SrcAS, Space)) {
    // Statically impossible conversion: the dynamic answer is always "no".
    return (dst_type) nullptr;
  } else if constexpr (Space == generic_space || Space == SrcAS) {
    // Widening to generic, or an identity cast: the pointer is already known
    // at compile time to be valid in the destination space, so neither a
    // run-time query nor the "not implemented" assert below is needed.
    return (dst_type)Ptr;
  } else if constexpr (Space == global_space &&
                       (SrcAS == global_device || SrcAS == global_host)) {
    return (dst_type)Ptr;
  } else if constexpr (SrcAS == global_space &&
                       (Space == global_device || Space == global_host)) {
#if defined(__ENABLE_USM_ADDR_SPACE__)
    // `Space != Space` is always false but depends on a template parameter, so
    // the assert only fires when this branch is actually instantiated.
    static_assert(SupressNotImplementedAssert || Space != Space,
                  "Not supported yet!");
    return static_address_cast<Space>(Ptr);
#else
    // If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
    // global_device/global_host are just aliases for global_space.
    static_assert(std::is_same_v<dst_type, ElementType *>);
    return (dst_type)Ptr;
#endif
#if defined(__SPIR__)
  } else if constexpr (Space == global_space) {
    return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
        Ptr, __spv::StorageClass::CrossWorkgroup);
  } else if constexpr (Space == local_space) {
    return (dst_type)__spirv_GenericCastToPtrExplicit_ToLocal(
        Ptr, __spv::StorageClass::Workgroup);
  } else if constexpr (Space == private_space) {
    return (dst_type)__spirv_GenericCastToPtrExplicit_ToPrivate(
        Ptr, __spv::StorageClass::Function);
#if !defined(__ENABLE_USM_ADDR_SPACE__)
  } else if constexpr (SrcAS == generic_space &&
                       (Space == global_device || Space == global_host)) {
    return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
        Ptr, __spv::StorageClass::CrossWorkgroup);
#endif
#endif
  } else {
    // No dynamic implementation for this target/space combination yet.
    static_assert(SupressNotImplementedAssert || Space != Space,
                  "Not supported yet!");
    return static_address_cast<Space>(Ptr);
  }
}
#else // __SYCL_DEVICE_ONLY__
// Variant compiled when __SYCL_DEVICE_ONLY__ is not defined: the pointer is
// handed back untouched, with its original type.
template <access::address_space Space, typename ElementType>
auto static_address_cast(ElementType *Ptr) {
  // `Space` is not consulted in this variant.
  return Ptr;
}
// Variant compiled when __SYCL_DEVICE_ONLY__ is not defined: the pointer is
// handed back untouched, with its original type.
template <access::address_space Space, bool SupressNotImplementedAssert = false,
          typename ElementType>
auto dynamic_address_cast(ElementType *Ptr) {
  // Neither `Space` nor `SupressNotImplementedAssert` is consulted in this
  // variant.
  return Ptr;
}
#endif // __SYCL_DEVICE_ONLY__
} // namespace detail

#undef __OPENCL_GLOBAL_AS__
Expand Down
Loading

0 comments on commit 9ea0f20

Please sign in to comment.