diff --git a/CMakeLists.txt b/CMakeLists.txt
index 50d3ee169dd..3c963b59e5e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,7 @@ option(BUILD_PYTHON "" ON)
 option(BUILD_MONOLITHIC_LIBONEFLOW "" ON)
 option(BUILD_RDMA "" OFF)
 option(BUILD_CUDA "" ON)
+option(WITH_ONEDNN "" OFF)
 option(BUILD_TESTING "" OFF)
 option(WITH_XLA "Option to build with XLA" OFF)
 option(WITH_TENSORRT "Option to build with TensorRT" OFF)
@@ -49,6 +50,7 @@ if (APPLE)
   set(RPC_BACKEND "LOCAL")
   set(BUILD_CUDA OFF)
   set(WITH_COCOAPI OFF)
+  set(WITH_ONEDNN OFF)
 endif()
 
 if (CMAKE_BUILD_TYPE MATCHES Debug)
@@ -65,6 +67,12 @@ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
     message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG})
   endif()
 elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
+  # Reference for the Clang version gate on oneDNN below:
+  # https://releases.llvm.org/11.0.0/tools/clang/docs/OpenMPSupport.html
+  if(WITH_ONEDNN AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 11)
+    message(WARNING "WITH_ONEDNN requires Clang >= 11 (found ${CMAKE_CXX_COMPILER_VERSION}); building without oneDNN")
+    set(WITH_ONEDNN OFF)
+  endif()
   if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 5)
     message(FATAL_ERROR ${COMPILER_VERSION_ERROR_MSG})
   endif()
@@ -145,6 +153,9 @@ if (RPC_BACKEND MATCHES "GRPC")
   message(STATUS "RPC backend enabled: gRPC")
   set(SUPPORTED_RPC_BACKEND_FOUND 1)
 endif()
+if (WITH_ONEDNN)
+  add_definitions(-DWITH_ONEDNN)
+endif()
 add_definitions(-DRPC_BACKEND_LOCAL)
 message(STATUS "RPC backend enabled: local")
 enable_testing()
@@ -180,6 +191,9 @@ else()
   if (APPLE)
     set(EXTRA_CXX_FLAGS "${EXTRA_CXX_FLAGS} -Wno-deprecated-declarations -Wno-mismatched-tags")
   endif()
+  if(WITH_ONEDNN)
+    set(EXTRA_CXX_FLAGS "${EXTRA_CXX_FLAGS} -fopenmp")
+  endif()
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${EXTRA_CXX_FLAGS}")
   set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} ${EXTRA_CXX_FLAGS}")
   set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${EXTRA_CXX_FLAGS}")
diff --git a/cmake/caches/ci/cpu.cmake b/cmake/caches/ci/cpu.cmake
index f8ff7413600..cb8547b061d 100644
--- a/cmake/caches/ci/cpu.cmake
+++ b/cmake/caches/ci/cpu.cmake
@@ -1,6 +1,7 @@
 set(BUILD_CUDA NO CACHE BOOL "")
set(BUILD_GIT_VERSION YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") +set(WITH_ONEDNN YES CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index e8f05a07b7f..e24a255a063 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -39,6 +39,10 @@ if (WITH_TENSORRT) endif() include(hwloc) +if (WITH_ONEDNN) + include(oneDNN) +endif() + option(CUDA_STATIC "" ON) @@ -143,6 +147,9 @@ set(oneflow_third_party_libs ${FLATBUFFERS_STATIC_LIBRARIES} ${LZ4_STATIC_LIBRARIES} ) +if (WITH_ONEDNN) + set(oneflow_third_party_libs ${oneflow_third_party_libs} ${ONEDNN_STATIC_LIBRARIES}) +endif() if (NOT WITH_XLA) list(APPEND oneflow_third_party_libs ${RE2_LIBRARIES}) @@ -171,6 +178,10 @@ set(oneflow_third_party_dependencies lz4_copy_libs_to_destination lz4_copy_headers_to_destination ) +if (WITH_ONEDNN) + list(APPEND oneflow_third_party_dependencies onednn) +endif() + if (WITH_COCOAPI) list(APPEND oneflow_third_party_dependencies cocoapi_copy_headers_to_destination) @@ -201,6 +212,10 @@ list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${FLATBUFFERS_INCLUDE_DIR} ${LZ4_INCLUDE_DIR} ) +if (WITH_ONEDNN) + list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${ONEDNN_INCLUDE_DIR}) +endif() + if (NOT WITH_XLA) list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${RE2_INCLUDE_DIR}) diff --git a/cmake/third_party/oneDNN.cmake b/cmake/third_party/oneDNN.cmake new file mode 100644 index 00000000000..e56ed55bbf4 --- /dev/null +++ b/cmake/third_party/oneDNN.cmake @@ -0,0 +1,62 @@ +include (ExternalProject) +include(GNUInstallDirs) + +set(ONEDNN_INSTALL_DIR ${THIRD_PARTY_DIR}/onednn) +set(ONEDNN_INCLUDE_DIR ${ONEDNN_INSTALL_DIR}/include) +set(ONEDNN_LIBRARY_DIR ${ONEDNN_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}) + +set(ONEDNN_URL https://github.com/oneapi-src/oneDNN/archive/refs/tags/v2.4.3.tar.gz) 
+use_mirror(VARIABLE ONEDNN_URL URL ${ONEDNN_URL}) + +if(WIN32) + message(FATAL_ERROR "Windows system does not support onednn") +else() + if(BUILD_SHARED_LIBS) + if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") + set(ONEDNN_LIBRARY_NAMES libdnnl.dylib) + elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") + set(ONEDNN_LIBRARY_NAMES libdnnl.so) + set(DNNL_LIBRARY_TYPE SHARED) + else() + message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for onednn") + endif() + else() + set(ONEDNN_LIBRARY_NAMES libdnnl.a ) + set(DNNL_LIBRARY_TYPE STATIC) + endif() +endif() + +foreach(LIBRARY_NAME ${ONEDNN_LIBRARY_NAMES}) + list(APPEND ONEDNN_STATIC_LIBRARIES ${ONEDNN_LIBRARY_DIR}/${LIBRARY_NAME}) +endforeach() + + +if(THIRD_PARTY) + +ExternalProject_Add(onednn + PREFIX onednn + URL ${ONEDNN_URL} + URL_MD5 c60ea96acbaccec053be7e3fa81c6184 + UPDATE_COMMAND "" + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${ONEDNN_STATIC_LIBRARIES} + CMAKE_CACHE_ARGS + -DCMAKE_INSTALL_PREFIX:STRING=${ONEDNN_INSTALL_DIR} + -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} + -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} + -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW + -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} + -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} + -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} + -DCMAKE_C_FLAGS_DEBUG:STRING=${CMAKE_C_FLAGS_DEBUG} + -DCMAKE_C_FLAGS_RELEASE:STRING=${CMAKE_C_FLAGS_RELEASE} + -DDNNL_IS_MAIN_PROJECT:BOOL=OFF + -DDNNL_BUILD_EXAMPLES:BOOL=OFF + -DDNNL_BUILD_TESTS:BOOL=OFF + -DDNNL_LIBRARY_TYPE:STRING=${DNNL_LIBRARY_TYPE} + -DDNNL_CPU_RUNTIME:STRING=OMP +) + +endif(THIRD_PARTY) +add_library(onednn_imported UNKNOWN IMPORTED) +set_property(TARGET onednn_imported PROPERTY IMPORTED_LOCATION "${ONEDNN_STATIC_LIBRARIES}") diff --git a/oneflow/core/ep/cpu/cpu_stream.h b/oneflow/core/ep/cpu/cpu_stream.h index 6fa6055bff1..f4fe78788ea 100644 --- a/oneflow/core/ep/cpu/cpu_stream.h +++ 
b/oneflow/core/ep/cpu/cpu_stream.h @@ -17,6 +17,9 @@ limitations under the License. #define ONEFLOW_CORE_EP_CPU_CPU_STREAM_H_ #include "oneflow/core/ep/include/stream.h" +#ifdef WITH_ONEDNN +#include +#endif namespace oneflow { @@ -25,12 +28,27 @@ namespace ep { class CpuStream : public Stream { public: OF_DISALLOW_COPY_AND_MOVE(CpuStream); - CpuStream() = default; + CpuStream() { +#ifdef WITH_ONEDNN + onednn_engine_.reset(new dnnl::engine(dnnl::engine::kind::cpu, 0)); + onednn_stream_.reset(new dnnl::stream(*onednn_engine_)); +#endif + } + ~CpuStream() override = default; DeviceType device_type() const override; Maybe Sync() override; void RecordEvent(Event* event) override; + +#ifdef WITH_ONEDNN + dnnl::engine* onednn_engine() const { return onednn_engine_.get(); } + dnnl::stream* onednn_stream() const { return onednn_stream_.get(); } + + private: + std::unique_ptr onednn_engine_; + std::unique_ptr onednn_stream_; +#endif }; } // namespace ep diff --git a/oneflow/core/ep/cpu/primitive/add.cpp b/oneflow/core/ep/cpu/primitive/add.cpp index 27989abd199..6b77945e343 100644 --- a/oneflow/core/ep/cpu/primitive/add.cpp +++ b/oneflow/core/ep/cpu/primitive/add.cpp @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/ep/cpu/primitive/type_seq.h" +#include "oneflow/core/ep/cpu/cpu_stream.h" namespace oneflow { @@ -42,11 +43,11 @@ void AddCpu(const T* const* srcs, size_t arity, T* dst, size_t count) { } template -class AddImpl : public Add { +class AddDefaultImpl : public Add { public: - OF_DISALLOW_COPY_AND_MOVE(AddImpl); - AddImpl() = default; - ~AddImpl() override = default; + OF_DISALLOW_COPY_AND_MOVE(AddDefaultImpl); + AddDefaultImpl() = default; + ~AddDefaultImpl() override = default; using Add::Launch; void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst, @@ -73,11 +74,80 @@ class AddImpl : public Add { } }; +#ifdef WITH_ONEDNN + +class AddOneDnnImpl : public Add { + public: + OF_DISALLOW_COPY_AND_MOVE(AddOneDnnImpl); + AddOneDnnImpl(dnnl::memory::data_type type) : type_onednn_(type){}; + ~AddOneDnnImpl() override = default; + + using Add::Launch; + void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst, + size_t count) override { + for (int i = 1; i < arity; i++) { + if (srcs[i] == dst) { LOG(FATAL) << "Only the first parameter can be operated inplace"; } + } + dnnl::engine* onednn_engine = stream->As()->onednn_engine(); + dnnl::stream* onednn_stream = stream->As()->onednn_stream(); + + dnnl::memory::dims src_dims = {static_cast(count)}; + std::vector src_md; + std::vector src_mem; + src_md.reserve(arity); + src_mem.reserve(arity); + + for (int i = 0; i < arity; i++) { + auto md = dnnl::memory::desc(src_dims, type_onednn_, dnnl::memory::format_tag::x); + auto mem = dnnl::memory(md, *onednn_engine, (void*)(srcs)[i]); + src_md.emplace_back(md); + src_mem.emplace_back(mem); + } + + std::vector scales(arity, 1.0); + auto sum_pd = dnnl::sum::primitive_desc(scales, src_md, *onednn_engine); + auto sum_prim = dnnl::sum(sum_pd); + auto dst_mem = dnnl::memory(sum_pd.dst_desc(), *onednn_engine, dst); + std::unordered_map sum_args{{DNNL_ARG_DST, dst_mem}}; + for 
(int i = 0; i < arity; ++i) { sum_args.insert({DNNL_ARG_MULTIPLE_SRC + i, src_mem[i]}); } + + sum_prim.execute(*onednn_stream, sum_args); + onednn_stream->wait(); + } + + private: + dnnl::memory::data_type type_onednn_; +}; + +#endif + template std::unique_ptr NewAdd() { - return std::unique_ptr(new AddImpl()); + return std::unique_ptr(new AddDefaultImpl()); +} + +#ifdef WITH_ONEDNN + +template +std::unique_ptr NewOneDnnAdd() { + return std::unique_ptr(new AddOneDnnImpl(type_onednn)); } +#endif + +#define CPU_PRIMITIVE_ADD_ONEDNN_TYPE_SEQ \ + CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \ + CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \ + CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ \ + CPU_PRIMITIVE_ONEDNN_FLOAT_TYPE_SEQ \ + CPU_PRIMITIVE_ONEDNN_FLOAT16_TYPE_SEQ \ + CPU_PRIMITIVE_ONEDNN_BFLOAT16_TYPE_SEQ + +#define CPU_PRIMITIVE_ADD_DEFAULT_TYPE_SEQ \ + CPU_PRIMITIVE_CHAR_TYPE_SEQ \ + CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ + CPU_PRIMITIVE_INT64_TYPE_SEQ + class AddFactoryImpl : public AddFactory { public: OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl); @@ -87,8 +157,17 @@ class AddFactoryImpl : public AddFactory { std::unique_ptr New(DataType data_type) override { #define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd}, +#ifdef WITH_ONEDNN +#define MAKE_NEW_ONEDNN_ADD_ENTRY(type_onednn, type_proto) {type_proto, NewOneDnnAdd}, + + static const std::map()>> new_add_handle{ + OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ONEDNN_ADD_ENTRY, CPU_PRIMITIVE_ADD_ONEDNN_TYPE_SEQ) + OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CPU_PRIMITIVE_ADD_DEFAULT_TYPE_SEQ)}; +#else static const std::map()>> new_add_handle{ OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CPU_PRIMITIVE_ALL_TYPE_SEQ)}; +#endif +#undef MAKE_NEW_ONEDNN_ADD_ENTRY #undef MAKE_NEW_ADD_ENTRY const auto it = new_add_handle.find(data_type); if (it != new_add_handle.end()) { diff --git a/oneflow/core/ep/cpu/primitive/type_seq.h b/oneflow/core/ep/cpu/primitive/type_seq.h index 802ac81b6b6..d5bef212a77 100644 --- a/oneflow/core/ep/cpu/primitive/type_seq.h +++ 
b/oneflow/core/ep/cpu/primitive/type_seq.h @@ -20,6 +20,10 @@ limitations under the License. #include "oneflow/core/common/data_type.h" #include +#ifdef WITH_ONEDNN +#include "oneapi/dnnl/dnnl.hpp" +#endif + #define CPU_PRIMITIVE_CHAR_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(char, DataType::kChar) #define CPU_PRIMITIVE_INT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(int8_t, DataType::kInt8) #define CPU_PRIMITIVE_UINT8_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(uint8_t, DataType::kUInt8) @@ -29,6 +33,19 @@ limitations under the License. #define CPU_PRIMITIVE_DOUBLE_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(double, DataType::kDouble) #define CPU_PRIMITIVE_FLOAT16_TYPE_SEQ OF_PP_MAKE_TUPLE_SEQ(float16, DataType::kFloat16) +#define CPU_PRIMITIVE_ONEDNN_INT8_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s8, DataType::kInt8) +#define CPU_PRIMITIVE_ONEDNN_UINT8_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::u8, DataType::kUInt8) +#define CPU_PRIMITIVE_ONEDNN_INT32_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::s32, DataType::kInt32) +#define CPU_PRIMITIVE_ONEDNN_FLOAT_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f32, DataType::kFloat) +#define CPU_PRIMITIVE_ONEDNN_FLOAT16_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::f16, DataType::kFloat16) +#define CPU_PRIMITIVE_ONEDNN_BFLOAT16_TYPE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(dnnl::memory::data_type::bf16, DataType::kBFloat16) + #define CPU_PRIMITIVE_NATIVE_TYPE_SEQ \ CPU_PRIMITIVE_CHAR_TYPE_SEQ \ CPU_PRIMITIVE_INT8_TYPE_SEQ \ diff --git a/oneflow/core/kernel/slice_boxing_kernel.cpp b/oneflow/core/kernel/slice_boxing_kernel.cpp index e487245e542..a4e0b9cccb3 100644 --- a/oneflow/core/kernel/slice_boxing_kernel.cpp +++ b/oneflow/core/kernel/slice_boxing_kernel.cpp @@ -107,12 +107,12 @@ void SliceBoxingAddKernel::ForwardDataContent(KernelContext* ctx) const { } } else { if (in_i->shape() == out->shape()) { - primitive->Launch(ctx->stream(), in_i->dptr(), out->dptr(), out->mut_dptr(), + primitive->Launch(ctx->stream(), 
out->dptr(), in_i->dptr(), out->mut_dptr(), out->shape().elem_cnt()); } else { Blob* buf = ctx->BnInOp2Blob("buf"); this->tensor_slice_copier_vec().at(i)->Copy(ctx->stream(), buf, in_i); - primitive->Launch(ctx->stream(), buf->dptr(), out->dptr(), out->mut_dptr(), + primitive->Launch(ctx->stream(), out->dptr(), buf->dptr(), out->mut_dptr(), out->shape().elem_cnt()); } } diff --git a/oneflow/user/kernels/acc_kernel.cpp b/oneflow/user/kernels/acc_kernel.cpp index 3843ae659cd..1773bc5d1bd 100644 --- a/oneflow/user/kernels/acc_kernel.cpp +++ b/oneflow/user/kernels/acc_kernel.cpp @@ -36,7 +36,7 @@ class AccKernel final : public user_op::OpKernel { std::unique_ptr primitive = ep::primitive::NewPrimitive(ctx->device_type(), in->data_type()); CHECK(primitive); - primitive->Launch(ctx->stream(), in->dptr(), out->dptr(), out->mut_dptr(), + primitive->Launch(ctx->stream(), out->dptr(), in->dptr(), out->mut_dptr(), in->shape().elem_cnt()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp index 900ebda7f69..ac792deeaf4 100644 --- a/oneflow/user/kernels/conv_kernels.cpp +++ b/oneflow/user/kernels/conv_kernels.cpp @@ -532,7 +532,7 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { ep::primitive::NewPrimitive(DeviceType::kCPU, add_to_output->data_type()); CHECK(primitive); - primitive->Launch(ctx->stream(), add_to_output->dptr(), dx->dptr(), dx->mut_dptr(), + primitive->Launch(ctx->stream(), dx->dptr(), add_to_output->dptr(), dx->mut_dptr(), add_to_output->shape().elem_cnt()); } } diff --git a/oneflow/user/kernels/dropout_kernel.cpp b/oneflow/user/kernels/dropout_kernel.cpp index 52aa69a1e70..9486fa7b822 100644 --- a/oneflow/user/kernels/dropout_kernel.cpp +++ b/oneflow/user/kernels/dropout_kernel.cpp @@ -50,7 +50,7 @@ class DropoutKernelCPU final : public user_op::OpKernel { ep::primitive::NewPrimitive(DeviceType::kCPU, add_to_output->data_type()); 
CHECK(primitive); - primitive->Launch(ctx->stream(), add_to_output->dptr(), out->dptr(), out->mut_dptr(), + primitive->Launch(ctx->stream(), out->dptr(), add_to_output->dptr(), out->mut_dptr(), add_to_output->shape().elem_cnt()); } } diff --git a/oneflow/user/kernels/eager_p_to_b_kernel.cpp b/oneflow/user/kernels/eager_p_to_b_kernel.cpp index f4ef7bb7efa..fdf7c0450d2 100644 --- a/oneflow/user/kernels/eager_p_to_b_kernel.cpp +++ b/oneflow/user/kernels/eager_p_to_b_kernel.cpp @@ -104,7 +104,7 @@ class EagerPToBKernel final : public user_op::OpKernel { if (GlobalProcessCtx::Rank() == dst) { CHECK_JUST(Recv(tmp_buffer_ptr, total_elem_cnt, out->data_type(), src, ctx->stream())); - add_primitive->Launch(ctx->stream(), tmp_buffer_ptr, out->dptr(), out->mut_dptr(), + add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(), total_elem_cnt); } } diff --git a/oneflow/user/kernels/eager_p_to_s_kernel.cpp b/oneflow/user/kernels/eager_p_to_s_kernel.cpp index 683ae4e56e5..82d55d8cd3f 100644 --- a/oneflow/user/kernels/eager_p_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_p_to_s_kernel.cpp @@ -166,7 +166,7 @@ class EagerPToSKernel final : public user_op::OpKernel { if (GlobalProcessCtx::Rank() == dst) { CHECK_JUST(Recv(tmp_buffer_ptr, elem_cnt_per_chunk, out->data_type(), src, ctx->stream())); - add_primitive->Launch(ctx->stream(), tmp_buffer_ptr, out->dptr(), out->mut_dptr(), + add_primitive->Launch(ctx->stream(), out->dptr(), tmp_buffer_ptr, out->mut_dptr(), elem_cnt_per_chunk); } } diff --git a/oneflow/user/kernels/group_conv_kernel.cpp b/oneflow/user/kernels/group_conv_kernel.cpp index 742cf1f7f50..ab60efe4708 100644 --- a/oneflow/user/kernels/group_conv_kernel.cpp +++ b/oneflow/user/kernels/group_conv_kernel.cpp @@ -578,7 +578,7 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { ep::primitive::NewPrimitive(DeviceType::kCPU, add_to_output->data_type()); CHECK(primitive); - primitive->Launch(ctx->stream(), add_to_output->dptr(), 
dx->dptr(), dx->mut_dptr(), + primitive->Launch(ctx->stream(), dx->dptr(), add_to_output->dptr(), dx->mut_dptr(), add_to_output->shape().elem_cnt()); } }