diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 8f19a90eb6..b02267ff39 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -46,6 +46,7 @@ add_executable(MIOpenDriver dm_groupnorm.cpp dm_layernorm.cpp dm_lrn.cpp + dm_outer.cpp dm_pool.cpp dm_reduce.cpp dm_reduceextreme.cpp diff --git a/driver/dm_outer.cpp b/driver/dm_outer.cpp new file mode 100644 index 0000000000..c4b5791ab9 --- /dev/null +++ b/driver/dm_outer.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#include "registry_driver_maker.hpp"
+#include "outer_driver.hpp"
+
+static Driver* makeDriver(const std::string& base_arg) // maps a base argument to its driver
+{
+    if(base_arg == "outer")
+        return new OuterDriver(); // NOTE(review): template arguments look stripped - confirm
+    if(base_arg == "outerfp16")
+        return new OuterDriver();
+    if(base_arg == "outerbfp16")
+        return new OuterDriver();
+    return nullptr; // not an "outer" base argument
+}
+
+REGISTER_DRIVER_MAKER(makeDriver);
diff --git a/driver/driver.hpp b/driver/driver.hpp
index 5bb0a29042..11201f3771 100644
--- a/driver/driver.hpp
+++ b/driver/driver.hpp
@@ -174,7 +174,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
            "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
            "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], "
            "groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], "
-           "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16]\n");
+           "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
+           "outer[bfp16|fp16]\n");
     exit(0); // NOLINT (concurrency-mt-unsafe)
 }
 
@@ -202,7 +203,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
        arg != "addlayernorm" && arg != "addlayernormfp16" && arg != "addlayernormbfp16" &&
        arg != "t5layernorm" && arg != "t5layernormfp16" && arg != "t5layernormbfp16" &&
        arg != "adam" && arg != "adamfp16" && arg != "ampadam" && arg != "reduceextreme" &&
-       arg != "reduceextremefp16" && arg != "reduceextremebfp16" && arg != "--version")
+       arg != "reduceextremefp16" && arg != "reduceextremebfp16" && arg != "outer" &&
+       arg != "outerfp16" && arg != "outerbfp16" && arg != "--version")
     {
         printf("FAILED: Invalid Base Input Argument\n");
         Usage();
diff --git a/driver/outer_driver.hpp b/driver/outer_driver.hpp
new file mode 100644
index 0000000000..e960c46747
--- /dev/null
+++ b/driver/outer_driver.hpp
@@ -0,0 +1,515 @@
+/*******************************************************************************
+ *
+ *
MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_OUTER_DRIVER_HPP +#define GUARD_MIOPEN_OUTER_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "random.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> + +template +int32_t mloOuterForwardRunHost(miopenTensorDescriptor_t input1Desc, + miopenTensorDescriptor_t input2Desc, + miopenTensorDescriptor_t outputDesc, + Tgpu* input1, + Tgpu* input2, + Tcheck* outputhost) +{ + auto input1_dims = miopen::deref(input1Desc).GetLengths(); + auto input2_dims = miopen::deref(input2Desc).GetLengths(); + auto output_dims = miopen::deref(outputDesc).GetLengths(); + + size_t in_n = input1_dims[0]; + size_t in_m = input2_dims[0]; + + int32_t ret = 0; + + size_t cnt = 0; + for(size_t i = 0; i < in_n; i++) + { + for(size_t j = 0; j < in_m; j++) + { + outputhost[cnt] = 0; + outputhost[cnt++] = input1[i] * input2[j]; + } + } + return ret; +} + +template +int32_t mloOuterBackwardRunHost(miopenTensorDescriptor_t input1Desc, + miopenTensorDescriptor_t input2Desc, + miopenTensorDescriptor_t input1GradDesc, + miopenTensorDescriptor_t input2GradDesc, + miopenTensorDescriptor_t outputGradDesc, + Tgpu* input1, + Tgpu* input2, + Tgpu* outGrad, + Tcheck* in1Gradhost, + Tcheck* in2Gradhost) +{ + auto input1_dims = miopen::deref(input1Desc).GetLengths(); + auto input2_dims = miopen::deref(input2Desc).GetLengths(); + auto output_dims = miopen::deref(outputGradDesc).GetLengths(); + + size_t in_n = input1_dims[0]; + size_t in_m = input2_dims[0]; + + int32_t ret = 0; + + for(size_t i = 0; i < in_n; i++) + { + Tcheck sum = static_cast(0.0f); + for(size_t j = 0; j < in_m; j++) + { + sum += static_cast(input2[j]) * static_cast(outGrad[i * in_m + j]); + } + in1Gradhost[i] = sum; + } + + for(size_t j = 0; j < in_m; j++) + { + 
Tcheck sum = static_cast(0.0f);
+        for(size_t i = 0; i < in_n; i++)
+        {
+            sum += static_cast(input1[i]) * static_cast(outGrad[i * in_m + j]);
+        }
+        in2Gradhost[j] = sum;
+    }
+    return ret;
+}
+
+template
+class OuterDriver : public Driver
+{
+public:
+    OuterDriver() : Driver()
+    {
+        miopenCreateTensorDescriptor(&input1Desc);
+        miopenCreateTensorDescriptor(&input2Desc);
+        miopenCreateTensorDescriptor(&outputDesc);
+
+        miopenCreateTensorDescriptor(&input1GradDesc);
+        miopenCreateTensorDescriptor(&input2GradDesc);
+        miopenCreateTensorDescriptor(&outputGradDesc);
+
+        data_type = miopen_type{};
+    }
+
+    int AddCmdLineArgs() override;
+    int ParseCmdLineArgs(int argc, char* argv[]) override;
+    InputFlags& GetInputFlags() override { return inflags; }
+
+    int GetandSetData() override;
+    std::vector GetInputTensorLengthsFromCmdLine();
+
+    int AllocateBuffersAndCopy() override;
+
+    int RunForwardGPU() override;
+    int RunForwardCPU();
+
+    int RunBackwardGPU() override;
+    int RunBackwardCPU();
+
+    Tref GetTolerance();
+    int VerifyBackward() override;
+    int VerifyForward() override;
+    ~OuterDriver() override // destroys each of the six descriptors created above exactly once
+    {
+        miopenDestroyTensorDescriptor(input1Desc);
+        miopenDestroyTensorDescriptor(input2Desc);
+        miopenDestroyTensorDescriptor(outputDesc);
+
+        miopenDestroyTensorDescriptor(input1GradDesc);
+        miopenDestroyTensorDescriptor(input2GradDesc);
+        miopenDestroyTensorDescriptor(outputGradDesc);
+    }
+
+private:
+    InputFlags inflags;
+
+    int forw;
+
+    miopenTensorDescriptor_t input1Desc;
+    miopenTensorDescriptor_t input2Desc;
+    miopenTensorDescriptor_t input1GradDesc;
+    miopenTensorDescriptor_t input2GradDesc;
+    miopenTensorDescriptor_t outputDesc;
+    miopenTensorDescriptor_t outputGradDesc;
+
+    std::unique_ptr in1_dev;
+    std::unique_ptr in2_dev;
+    std::unique_ptr in1Grad_dev;
+    std::unique_ptr in2Grad_dev;
+    std::unique_ptr out_dev;
+    std::unique_ptr outGrad_dev;
+
+    std::vector in1;
+    std::vector in2;
+    std::vector in1Grad;
+    std::vector in2Grad;
+    std::vector out;
+    std::vector
outGrad;
+
+    std::vector in1Gradhost;
+    std::vector in2Gradhost;
+    std::vector outhost;
+};
+
+template
+int OuterDriver::ParseCmdLineArgs(int argc, char* argv[])
+{
+    inflags.Parse(argc, argv);
+
+    if(inflags.GetValueInt("time") == 1)
+    {
+        miopenEnableProfiling(GetHandle(), true);
+    }
+    return miopenStatusSuccess;
+}
+
+template
+int OuterDriver::GetandSetData()
+{
+    std::vector in_lens = GetInputTensorLengthsFromCmdLine();
+    // On invalid sizes the helper returns {0}; reading in_lens[1] would then be out of bounds.
+    if(in_lens.size() < 2)
+        return miopenStatusBadParm;
+
+    auto lens1 = std::vector({in_lens[0]});
+    auto lens2 = std::vector({in_lens[1]});
+
+    SetTensorNd(input1Desc, lens1, data_type);
+    SetTensorNd(input2Desc, lens2, data_type);
+
+    SetTensorNd(input1GradDesc, lens1, data_type);
+    SetTensorNd(input2GradDesc, lens2, data_type);
+
+    std::vector out_len({in_lens[0], in_lens[1]}); // output and its gradient are n x m
+
+    SetTensorNd(outputDesc, out_len, data_type);
+    SetTensorNd(outputGradDesc, out_len, data_type);
+
+    return 0;
+}
+
+template
+int OuterDriver::AddCmdLineArgs()
+{
+    inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Outer (Default=1)", "int");
+    inflags.AddInputFlag("in_n", 'N', "32", "n size (Default=32)", "int");
+    inflags.AddInputFlag("in_m", 'M', "32", "m size (Default=32)", "int");
+
+    inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int");
+    inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int");
+    inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int");
+    inflags.AddInputFlag(
+        "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int");
+
+    return miopenStatusSuccess;
+}
+
+template
+std::vector OuterDriver::GetInputTensorLengthsFromCmdLine()
+{
+    int in_n = inflags.GetValueInt("in_n");
+    int in_m = inflags.GetValueInt("in_m");
+
+    if((in_n > 0) && (in_m > 0)) // both lengths must be positive
+    {
+        return std::vector({in_n, in_m});
+    }
+    else
+    {
+        std::cerr << "Error Input Tensor Lengths\n" << std::endl;
+        return std::vector({0}); // one-element vector signals the error to the caller
+    }
+}
+
+template
+int
OuterDriver::AllocateBuffersAndCopy()
+{
+    size_t in1_sz = GetTensorSize(input1Desc);
+    size_t in2_sz = GetTensorSize(input2Desc);
+    size_t out_sz = GetTensorSize(outputDesc);
+
+    uint32_t ctx = 0;
+
+    in1_dev     = std::unique_ptr(new GPUMem(ctx, in1_sz, sizeof(Tgpu)));
+    in2_dev     = std::unique_ptr(new GPUMem(ctx, in2_sz, sizeof(Tgpu)));
+    in1Grad_dev = std::unique_ptr(new GPUMem(ctx, in1_sz, sizeof(Tgpu)));
+    in2Grad_dev = std::unique_ptr(new GPUMem(ctx, in2_sz, sizeof(Tgpu)));
+    out_dev     = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu)));
+    outGrad_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu)));
+
+    in1     = std::vector(in1_sz, static_cast(0));
+    in2     = std::vector(in2_sz, static_cast(0));
+    in1Grad = std::vector(in1_sz, static_cast(0));
+    in2Grad = std::vector(in2_sz, static_cast(0));
+    out     = std::vector(out_sz, static_cast(0));
+    outGrad = std::vector(out_sz, static_cast(0));
+
+    in1Gradhost = std::vector(in1_sz, static_cast(0));
+    in2Gradhost = std::vector(in2_sz, static_cast(0));
+    outhost     = std::vector(out_sz, static_cast(0));
+
+    for(size_t i = 0; i < in1_sz; i++) // index type matches the size_t bound
+    {
+        in1[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0));
+    }
+
+    for(size_t i = 0; i < in2_sz; i++)
+    {
+        in2[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0));
+    }
+
+    for(size_t i = 0; i < out_sz; i++)
+    {
+        outGrad[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0));
+    }
+
+    if(in1_dev->ToGPU(GetStream(), in1.data()) != 0)
+        std::cerr << "Error copying (in1) to GPU, size: " << in1_dev->GetSize() << std::endl;
+
+    if(in2_dev->ToGPU(GetStream(), in2.data()) != 0)
+        std::cerr << "Error copying (in2) to GPU, size: " << in2_dev->GetSize() << std::endl;
+
+    if(outGrad_dev->ToGPU(GetStream(), outGrad.data()) != 0) // report the buffer that failed
+        std::cerr << "Error copying (outGrad) to GPU, size: " << outGrad_dev->GetSize() << std::endl;
+
+    return miopenStatusSuccess;
+}
+
+template
+int OuterDriver::RunForwardGPU()
+{
+    float kernel_total_time = 0.0;
+    float kernel_first_time = 0.0;
+
+    Timer t;
+    START_TIME
+ + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenOuterForward(GetHandle(), + input1Desc, + in1_dev->GetMem(), + input2Desc, + in2_dev->GetMem(), + outputDesc, + out_dev->GetMem()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Forward Outer Elapsed: " << t.gettime_ms() / iter + << " ms\n"; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Forward Outer Elapsed: " << kernel_average_time << " ms\n"; + } + + if(out_dev->FromGPU(GetStream(), out.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << out_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int OuterDriver::RunForwardCPU() +{ + mloOuterForwardRunHost( + input1Desc, input2Desc, outputDesc, in1.data(), in2.data(), outhost.data()); + + return miopenStatusSuccess; +} + +template +int OuterDriver::RunBackwardGPU() +{ + float kernel_total_time = 0.0; + float kernel_first_time = 0.0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + float time_sum = 0.0f; + float time_tmp = 0.0f; + + miopenOuterBackwardGrad1(GetHandle(), + input2Desc, + in2_dev->GetMem(), + input1GradDesc, + in1Grad_dev->GetMem(), + outputGradDesc, + outGrad_dev->GetMem()); + + miopenGetKernelTime(GetHandle(), &time_tmp); + time_sum += time_tmp; + + miopenOuterBackwardGrad2(GetHandle(), + input1Desc, + in1_dev->GetMem(), + input2GradDesc, + in2Grad_dev->GetMem(), + outputGradDesc, + outGrad_dev->GetMem()); + + miopenGetKernelTime(GetHandle(), &time_tmp); + time_sum += time_tmp; + + kernel_total_time += time_sum; + if(i == 0) + kernel_first_time = time_sum; + } + + if(inflags.GetValueInt("time") == 1) + { + 
STOP_TIME
+        int iter = inflags.GetValueInt("iter");
+        if(WALL_CLOCK)
+            std::cout << "Wall-clock Time Backward Outer Elapsed: " << t.gettime_ms() / iter
+                      << " ms\n";
+
+        float kernel_average_time = // the first iteration is excluded when iter > 1
+            iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time;
+        std::cout << "GPU Kernel Time Backward Outer Elapsed: " << kernel_average_time << " ms\n";
+    }
+
+    if(in1Grad_dev->FromGPU(GetStream(), in1Grad.data()) != 0)
+        std::cerr << "Error copying (in1Grad_dev) from GPU, size: " << in1Grad_dev->GetSize()
+                  << std::endl;
+
+    if(in2Grad_dev->FromGPU(GetStream(), in2Grad.data()) != 0)
+        std::cerr << "Error copying (in2Grad_dev) from GPU, size: " << in2Grad_dev->GetSize()
+                  << std::endl;
+
+    return miopenStatusSuccess;
+}
+
+template
+int OuterDriver::RunBackwardCPU()
+{
+    mloOuterBackwardRunHost(input1Desc,
+                            input2Desc,
+                            input1GradDesc,
+                            input2GradDesc,
+                            outputGradDesc,
+                            in1.data(),
+                            in2.data(),
+                            outGrad.data(),
+                            in1Gradhost.data(),
+                            in2Gradhost.data());
+
+    return miopenStatusSuccess;
+}
+
+template
+Tref OuterDriver::GetTolerance()
+{
+    auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3;
+
+    // bf16 mantissa has 7 bits, by 3 bits shorter than fp16.
+    if(std::is_same::value)
+        tolerance *= 8.0;
+    return tolerance;
+}
+
+template
+int OuterDriver::VerifyForward()
+{
+    RunForwardCPU();
+    const Tref tolerance = GetTolerance();
+    auto error           = miopen::rms_range(outhost, out);
+
+    if(!std::isfinite(error) || error > tolerance)
+    {
+        std::cout << "Forward Outer FAILED: " << error << " > " << tolerance << std::endl;
+        return EC_VerifyFwd;
+    }
+    else
+    {
+        std::cout << "Forward Outer Verifies OK on CPU reference (" << error << " < " << tolerance
+                  << ')' << std::endl;
+    }
+
+    return miopenStatusSuccess;
+}
+
+template
+int OuterDriver::VerifyBackward()
+{
+    RunBackwardCPU();
+    const Tref tolerance = GetTolerance();
+    auto error1          = miopen::rms_range(in1Gradhost, in1Grad);
+    auto error2          = miopen::rms_range(in2Gradhost, in2Grad);
+
+    if(!std::isfinite(error1) || error1 > tolerance)
+    {
+        std::cout << "Backward Outer FAILED with in1: " << error1 << " > " << tolerance
+                  << std::endl;
+        return EC_VerifyBwd; // backward failures use the backward exit code, not EC_VerifyFwd
+    }
+    else if(!std::isfinite(error2) || error2 > tolerance)
+    {
+        std::cout << "Backward Outer FAILED with in2: " << error2 << " > " << tolerance
+                  << std::endl;
+        return EC_VerifyBwd;
+    }
+    else
+    {
+        std::cout << "Backward Outer Verifies OK on CPU reference (" << error1 << " < " << tolerance
+                  << ')' << " and "
+                  << "(" << error2 << " < " << tolerance << ')' << std::endl;
+    }
+    return miopenStatusSuccess;
+}
+
+#endif // GUARD_MIOPEN_OUTER_DRIVER_HPP
diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h
index 9821b94912..58e9556a5a 100644
--- a/include/miopen/miopen.h
+++ b/include/miopen/miopen.h
@@ -6155,6 +6155,71 @@ MIOPEN_EXPORT miopenStatus_t miopenT5LayerNormBackward(miopenHandle_t handle,
 // CLOSEOUT LAYERNORM DOXYGEN GROUP
 #endif
 
+#ifdef MIOPEN_BETA_API
+// Outer APIs
+/** @addtogroup outer
+ *
+ *  @{
+ */
+/*!
@brief Execute an outer forward layer
+ *
+ * @param handle         MIOpen handle (input)
+ * @param x1Desc         Tensor descriptor of input tensor x1 (input)
+ * @param x1             Source data tensor x1 (input)
+ * @param x2Desc         Tensor descriptor of input tensor x2 (input)
+ * @param x2             Source data tensor x2 (input)
+ * @param yDesc          Tensor descriptor of output tensor y (input)
+ * @param y              Data tensor y (output)
+ * @return               miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenOuterForward(miopenHandle_t handle,
+                                                const miopenTensorDescriptor_t x1Desc,
+                                                const void* x1,
+                                                const miopenTensorDescriptor_t x2Desc,
+                                                const void* x2,
+                                                const miopenTensorDescriptor_t yDesc,
+                                                void* y);
+/*! @brief Execute an outer backwardGrad1 layer (gradient with respect to x1)
+ *
+ * @param handle         MIOpen handle (input)
+ * @param x2Desc         Tensor descriptor of input tensor x2 (input)
+ * @param x2             Source data tensor x2 (input)
+ * @param x1GradDesc     Tensor descriptor of gradient tensor x1Grad (input)
+ * @param x1Grad         Gradient data tensor x1Grad (output)
+ * @param yGradDesc      Tensor descriptor of output-gradient tensor yGrad (input)
+ * @param yGrad          Output-gradient data tensor yGrad (input)
+ * @return               miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenOuterBackwardGrad1(miopenHandle_t handle,
+                                                      const miopenTensorDescriptor_t x2Desc,
+                                                      const void* x2,
+                                                      const miopenTensorDescriptor_t x1GradDesc,
+                                                      void* x1Grad,
+                                                      const miopenTensorDescriptor_t yGradDesc,
+                                                      const void* yGrad);
+/*!
@brief Execute an outer backwardGrad2 layer (gradient with respect to x2)
+ *
+ * @param handle         MIOpen handle (input)
+ * @param x1Desc         Tensor descriptor of input tensor x1 (input)
+ * @param x1             Source data tensor x1 (input)
+ * @param x2GradDesc     Tensor descriptor of gradient tensor x2Grad (input)
+ * @param x2Grad         Gradient data tensor x2Grad (output)
+ * @param yGradDesc      Tensor descriptor of output-gradient tensor yGrad (input)
+ * @param yGrad          Output-gradient data tensor yGrad (input)
+ * @return               miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenOuterBackwardGrad2(miopenHandle_t handle,
+                                                      const miopenTensorDescriptor_t x1Desc,
+                                                      const void* x1,
+                                                      const miopenTensorDescriptor_t x2GradDesc,
+                                                      void* x2Grad,
+                                                      const miopenTensorDescriptor_t yGradDesc,
+                                                      const void* yGrad);
+
+/** @} */
+// CLOSEOUT OUTER DOXYGEN GROUP
+#endif
+
 #ifdef MIOPEN_BETA_API
 // Graph API
 /** @addtogroup GraphAPI
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1cb8a1fb0c..b8b102eac5 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -156,6 +156,8 @@ set( MIOpen_Source
     mha/problem_description.cpp
     op_args.cpp
     operator.cpp
+    outer_api.cpp
+    outer/problem_description.cpp
     performance_config.cpp
     pooling/problem_description.cpp
     pooling_api.cpp
@@ -284,6 +286,9 @@ set( MIOpen_Source
     solver/layernorm/forward_t5layernorm.cpp
     solver/mha/mha_solver_backward.cpp
     solver/mha/mha_solver_forward.cpp
+    solver/outer/backwardgrad1_outer.cpp
+    solver/outer/backwardgrad2_outer.cpp
+    solver/outer/forward_outer.cpp
     solver/pooling/forward2d.cpp
     solver/pooling/forwardNaive.cpp
     solver/pooling/forwardNd.cpp
@@ -488,6 +493,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN
         kernels/MIOpenLRNBwd.cl
         kernels/MIOpenLRNFwd.cl
         kernels/MIOpenNeuron.cl
+        kernels/MIOpenOuter.cpp
         kernels/MIOpenPooling.cl
         kernels/MIOpenPoolingBwd.cl
         kernels/MIOpenPoolingBwdND.cl
@@ -629,6 +635,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN
         ocl/dropoutocl.cpp
         ocl/gcn_asm_utils.cpp
         ocl/rnn_util_ocl.cpp
+
outer.cpp hip/hip_build_utils.cpp hip/batched_transpose_sol.cpp hip/general_tensor_reorder_sol.cpp diff --git a/src/include/miopen/outer.hpp b/src/include/miopen/outer.hpp new file mode 100644 index 0000000000..fe67bd22b8 --- /dev/null +++ b/src/include/miopen/outer.hpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_OUTER_HPP_ +#define MIOPEN_OUTER_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +MIOPEN_INTERNALS_EXPORT miopenStatus_t OuterForward(Handle& handle, + const TensorDescriptor& x1Desc, + ConstData_t x1, + const TensorDescriptor& x2Desc, + ConstData_t x2, + const TensorDescriptor& yDesc, + Data_t y); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t OuterBackwardGrad1(Handle& handle, + const TensorDescriptor& x2Desc, + ConstData_t x2, + const TensorDescriptor& x1GradDesc, + ConstData_t x1Grad, + const TensorDescriptor& yGradDesc, + ConstData_t yGrad); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t OuterBackwardGrad2(Handle& handle, + const TensorDescriptor& x1Desc, + ConstData_t x1, + const TensorDescriptor& x2GradDesc, + ConstData_t x2Grad, + const TensorDescriptor& yGradDesc, + ConstData_t yGrad); + +} // namespace miopen +#endif // _MIOPEN_OUTER_HPP_ diff --git a/src/include/miopen/outer/invoke_params.hpp b/src/include/miopen/outer/invoke_params.hpp new file mode 100644 index 0000000000..dea70169d5 --- /dev/null +++ b/src/include/miopen/outer/invoke_params.hpp @@ -0,0 +1,116 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +namespace miopen { + +namespace outer { + +struct InvokeParamsForward : public miopen::InvokeParams +{ + InvokeParamsForward(const TensorDescriptor& x1Desc_, + ConstData_t x1_, + const TensorDescriptor& x2Desc_, + ConstData_t x2_, + const TensorDescriptor& yDesc_, + Data_t y_) + : x1Desc(x1Desc_), x1(x1_), x2Desc(x2Desc_), x2(x2_), yDesc(yDesc_), y(y_) + { + } + + TensorDescriptor x1Desc{}; + ConstData_t x1 = nullptr; + TensorDescriptor x2Desc{}; + ConstData_t x2 = nullptr; + TensorDescriptor yDesc{}; + Data_t y = nullptr; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +struct InvokeParamsBackwardGrad1 : public miopen::InvokeParams +{ + InvokeParamsBackwardGrad1(const TensorDescriptor& x2Desc_, + ConstData_t x2_, + const TensorDescriptor& x1GradDesc_, + ConstData_t x1Grad_, + const TensorDescriptor& yGradDesc_, + ConstData_t yGrad_) + : x2Desc(x2Desc_), + x2(x2_), + x1GradDesc(x1GradDesc_), + x1Grad(x1Grad_), + yGradDesc(yGradDesc_), + yGrad(yGrad_) + { + } + + TensorDescriptor x2Desc{}; + ConstData_t x2 = nullptr; + TensorDescriptor x1GradDesc{}; + ConstData_t x1Grad = nullptr; + TensorDescriptor yGradDesc{}; + ConstData_t yGrad = nullptr; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +struct InvokeParamsBackwardGrad2 : public 
miopen::InvokeParams +{ + InvokeParamsBackwardGrad2(const TensorDescriptor& x1Desc_, + ConstData_t x1_, + const TensorDescriptor& x2GradDesc_, + ConstData_t x2Grad_, + const TensorDescriptor& yGradDesc_, + ConstData_t yGrad_) + : x1Desc(x1Desc_), + x1(x1_), + x2GradDesc(x2GradDesc_), + x2Grad(x2Grad_), + yGradDesc(yGradDesc_), + yGrad(yGrad_) + { + } + + TensorDescriptor x1Desc{}; + ConstData_t x1 = nullptr; + TensorDescriptor x2GradDesc{}; + ConstData_t x2Grad = nullptr; + TensorDescriptor yGradDesc{}; + ConstData_t yGrad = nullptr; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace outer +} // namespace miopen diff --git a/src/include/miopen/outer/problem_description.hpp b/src/include/miopen/outer/problem_description.hpp new file mode 100644 index 0000000000..3d37a47d01 --- /dev/null +++ b/src/include/miopen/outer/problem_description.hpp @@ -0,0 +1,106 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+
+namespace miopen {
+
+struct NetworkConfig;
+
+namespace outer {
+
+enum gradType
+{
+    NONE,
+    ONE,
+    TWO
+};
+
+struct ProblemDescription : ProblemDescriptionBase
+{
+    ProblemDescription(const bool is_fwd_,
+                       gradType grad_,
+                       const TensorDescriptor& x1Desc_,
+                       const TensorDescriptor& x2Desc_,
+                       const TensorDescriptor& yDesc_)
+        : is_fwd(is_fwd_), grad(grad_), x1Desc(x1Desc_), x2Desc(x2Desc_), yDesc(yDesc_)
+    {
+        const auto dtype = yDesc.GetType(); // x1 and x2 must both match y's data type
+        if(x1Desc.GetType() != dtype)
+        {
+            MIOPEN_THROW(miopenStatusBadParm, "Outer: Tensor types do not match.");
+        }
+        if(x2Desc.GetType() != dtype)
+        {
+            MIOPEN_THROW(miopenStatusBadParm, "Outer: Tensor types do not match.");
+        }
+        if(is_fwd && (grad_ == ONE || grad_ == TWO)) // forward requires grad_ == NONE
+        {
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "Outer: the direction and the gradient type do not match");
+        }
+        if(!is_fwd && grad_ == NONE) // backward requires grad_ == ONE or TWO
+        {
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "Outer: the direction and the gradient type do not match");
+        }
+    }
+
+    bool isForward() const { return is_fwd; }
+    const TensorDescriptor& GetX1Desc() const { return x1Desc; }
+    const TensorDescriptor& GetX2Desc() const { return x2Desc; }
+    const TensorDescriptor& GetYDesc() const { return yDesc; }
+
+    bool IsAllPacked() const // true only when x1, x2 and y all report IsPacked()
+    {
+        if(!x1Desc.IsPacked())
+            return false;
+        if(!x2Desc.IsPacked())
+            return false;
+        if(!yDesc.IsPacked())
+            return false;
+        return true;
+    }
+
+    NetworkConfig MakeNetworkConfig() const override;
+
+private:
+    bool is_fwd;
+    gradType grad;
+    TensorDescriptor x1Desc;
+    TensorDescriptor x2Desc;
+
TensorDescriptor yDesc; +}; + +} // namespace outer +} // namespace miopen diff --git a/src/include/miopen/outer/solvers.hpp b/src/include/miopen/outer/solvers.hpp new file mode 100644 index 0000000000..aa2321ad85 --- /dev/null +++ b/src/include/miopen/outer/solvers.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace outer { + +using OuterSolver = NonTunableSolverBase; + +struct OuterForward final : OuterSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + std::size_t GetWorkspaceSize(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + bool MayNeedWorkspace() const override { return true; } +}; + +struct OuterBackwardGrad1 final : OuterSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + std::size_t GetWorkspaceSize(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + bool MayNeedWorkspace() const override { return true; } +}; + +struct OuterBackwardGrad2 final : OuterSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + std::size_t GetWorkspaceSize(const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const override; + bool MayNeedWorkspace() const override { return true; } +}; + +} // namespace outer + +} // 
namespace solver + +} // namespace miopen diff --git a/src/kernels/MIOpenOuter.cpp b/src/kernels/MIOpenOuter.cpp new file mode 100644 index 0000000000..ad05ea0f0f --- /dev/null +++ b/src/kernels/MIOpenOuter.cpp @@ -0,0 +1,88 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" + +extern "C" __global__ void OuterForward(const FLOAT* input1, + const FLOAT* input2, + FLOAT* output, + const size_t n, + const size_t m, + const size_t nm) +{ + size_t gid = blockIdx.x * blockDim.x + threadIdx.x; + if(gid >= nm) + return; + + size_t ix[2]; + + ix[0] = gid / m; + ix[1] = gid % m; + + output[gid] = CVT_ACCUM2FLOAT(CVT_FLOAT2ACCUM(input1[ix[0]]) * CVT_FLOAT2ACCUM(input2[ix[1]])); +} + +extern "C" __global__ void OuterBackwardGrad1(const FLOAT* input2, + FLOAT* input1_grad, + const FLOAT* output_grad, + const size_t n, + const size_t m) +{ + size_t gid = blockIdx.x * blockDim.x + threadIdx.x; + if(gid >= n) + return; + + FLOAT_ACCUM sum = 0; + for(size_t j = 0; j < m; ++j) + { + sum += CVT_FLOAT2ACCUM(input2[j]) * CVT_FLOAT2ACCUM(output_grad[gid * m + j]); + } + + input1_grad[gid] = CVT_ACCUM2FLOAT(sum); +} + +extern "C" __global__ void OuterBackwardGrad2(const FLOAT* input1, + FLOAT* input2_grad, + const FLOAT* output_grad, + const size_t n, + const size_t m) +{ + size_t gid = blockIdx.x * blockDim.x + threadIdx.x; + if(gid >= m) + return; + + FLOAT_ACCUM sum = 0; + for(size_t i = 0; i < n; ++i) + { + sum += CVT_FLOAT2ACCUM(input1[i]) * CVT_FLOAT2ACCUM(output_grad[i * m + gid]); + } + + input2_grad[gid] = CVT_ACCUM2FLOAT(sum); +} diff --git a/src/outer.cpp b/src/outer.cpp new file mode 100644 index 0000000000..9e6fe29cef --- /dev/null +++ b/src/outer.cpp @@ -0,0 +1,93 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t OuterForward(Handle& handle, + const TensorDescriptor& x1Desc, + ConstData_t x1, + const TensorDescriptor& x2Desc, + ConstData_t x2, + const TensorDescriptor& yDesc, + Data_t y) +{ + const auto problem = outer::ProblemDescription(true, outer::NONE, x1Desc, x2Desc, yDesc); + const auto invoke_params = outer::InvokeParamsForward{x1Desc, x1, x2Desc, x2, yDesc, y}; + const auto algo = AlgorithmName{"OuterForward"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +miopenStatus_t OuterBackwardGrad1(Handle& handle, + const TensorDescriptor& x2Desc, + ConstData_t x2, + const TensorDescriptor& x1GradDesc, + ConstData_t x1Grad, + const TensorDescriptor& yGradDesc, + ConstData_t yGrad) +{ + const auto problem = + outer::ProblemDescription(false, outer::ONE, x1GradDesc, x2Desc, yGradDesc); + const auto invoke_params = + outer::InvokeParamsBackwardGrad1{x2Desc, x2, x1GradDesc, x1Grad, yGradDesc, yGrad}; + const auto algo = AlgorithmName{"OuterBackwardGrad1"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +miopenStatus_t OuterBackwardGrad2(Handle& handle, + const TensorDescriptor& x1Desc, + ConstData_t x1, + const TensorDescriptor& x2GradDesc, + ConstData_t x2Grad, + const TensorDescriptor& yGradDesc, + ConstData_t yGrad) +{ + const auto problem = + outer::ProblemDescription(false, outer::TWO, x1Desc, x2GradDesc, yGradDesc); + const auto invoke_params = + outer::InvokeParamsBackwardGrad2{x1Desc, x1, x2GradDesc, x2Grad, yGradDesc, yGrad}; + const auto algo = AlgorithmName{"OuterBackwardGrad2"}; + const auto solvers = solver::SolverContainer{}; + 
solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/outer/problem_description.cpp b/src/outer/problem_description.cpp new file mode 100644 index 0000000000..f69a8dec03 --- /dev/null +++ b/src/outer/problem_description.cpp @@ -0,0 +1,68 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include + +#include + +namespace miopen { + +namespace outer { + +NetworkConfig ProblemDescription::MakeNetworkConfig() const +{ + std::ostringstream ss; + if(is_fwd == true) + { + ss << "outerfwd"; + } + else + { + if(grad == ONE) + { + ss << "outerbwdgrad1"; + } + else if(grad == TWO) + { + ss << "outerbwdgrad2"; + } + } + auto x1length = x1Desc.GetLengths(); + auto x2length = x2Desc.GetLengths(); + auto ylength = yDesc.GetLengths(); + auto dtype = x1Desc.GetType(); + ss << "dtype" << dtype; + ss << "x1len" << x1length[0]; + ss << "x2len" << x2length[0]; + + return NetworkConfig{ss.str()}; +} + +} // namespace outer + +} // namespace miopen diff --git a/src/outer_api.cpp b/src/outer_api.cpp new file mode 100644 index 0000000000..9dac3d9c22 --- /dev/null +++ b/src/outer_api.cpp @@ -0,0 +1,128 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include + +static void LogCmdOuter(const miopenTensorDescriptor_t Desc, bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(Desc).GetType(); + if(dtype == miopenHalf) + { + ss << "sumfp16"; + } + else if(dtype == miopenFloat) + { + ss << "sumfp32"; + } + else if(dtype == miopenBFloat16) + { + ss << "sumbfp16"; + } + + ss << " -N " << miopen::deref(Desc).GetLengths()[0] << " -M " + << miopen::deref(Desc).GetLengths()[1]; + + ss << " -F " << ((is_fwd) ? "1" : "0"); + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + +extern "C" miopenStatus_t miopenOuterForward(miopenHandle_t handle, + const miopenTensorDescriptor_t x1Desc, + const void* x1, + const miopenTensorDescriptor_t x2Desc, + const void* x2, + const miopenTensorDescriptor_t yDesc, + void* y) +{ + MIOPEN_LOG_FUNCTION(handle, x1Desc, x1, x2Desc, x2, yDesc, y); + + LogCmdOuter(yDesc, true); + + return miopen::try_([&] { + miopen::OuterForward(miopen::deref(handle), + miopen::deref(x1Desc), + DataCast(x1), + miopen::deref(x2Desc), + DataCast(x2), + miopen::deref(yDesc), + DataCast(y)); + }); +} + +extern "C" miopenStatus_t miopenOuterBackwardGrad1(miopenHandle_t handle, + const miopenTensorDescriptor_t x2Desc, + const void* x2, + const miopenTensorDescriptor_t x1GradDesc, + void* x1Grad, + const miopenTensorDescriptor_t yGradDesc, + const void* yGrad) +{ + MIOPEN_LOG_FUNCTION(handle, x2Desc, x2, x1GradDesc, x1Grad, yGradDesc, yGrad); + + LogCmdOuter(yGradDesc, false); + + return miopen::try_([&] { + miopen::OuterBackwardGrad1(miopen::deref(handle), + 
miopen::deref(x2Desc), + DataCast(x2), + miopen::deref(x1GradDesc), + DataCast(x1Grad), + miopen::deref(yGradDesc), + DataCast(yGrad)); + }); +} + +extern "C" miopenStatus_t miopenOuterBackwardGrad2(miopenHandle_t handle, + const miopenTensorDescriptor_t x1Desc, + const void* x1, + const miopenTensorDescriptor_t x2GradDesc, + void* x2Grad, + const miopenTensorDescriptor_t yGradDesc, + const void* yGrad) +{ + MIOPEN_LOG_FUNCTION(handle, x1Desc, x1, x2GradDesc, x2Grad, yGradDesc, yGrad); + + LogCmdOuter(yGradDesc, true); + + return miopen::try_([&] { + miopen::OuterBackwardGrad2(miopen::deref(handle), + miopen::deref(x1Desc), + DataCast(x1), + miopen::deref(x2GradDesc), + DataCast(x2Grad), + miopen::deref(yGradDesc), + DataCast(yGrad)); + }); +} diff --git a/src/solver/outer/backwardgrad1_outer.cpp b/src/solver/outer/backwardgrad1_outer.cpp new file mode 100644 index 0000000000..ff46e7fd66 --- /dev/null +++ b/src/solver/outer/backwardgrad1_outer.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +namespace miopen { + +namespace solver { + +namespace outer { + +static bool IsImprovementOverROCm(const miopen::outer::ProblemDescription& problem) +{ + auto ydims = problem.GetYDesc().GetLengths(); + if(ydims[0] <= 32 && ydims[1] <= 128) + return true; + else + return false; +} + +bool OuterBackwardGrad1::IsApplicable([[maybe_unused]] const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const +{ + if(!problem.IsAllPacked()) + return false; + if(!IsImprovementOverROCm(problem)) + return false; + return true; +} + +ConvSolution OuterBackwardGrad1::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const +{ + static const size_t LOCAL_SIZE = 256; + auto result = ConvSolution{miopenStatusSuccess}; + + auto dtype = problem.GetX1Desc().GetType(); + auto x1dims = problem.GetX1Desc().GetLengths(); + + auto input_dtype = miopen::GetDataType(problem.GetX1Desc().GetType()); + auto output_dtype = miopen::GetDataType(problem.GetYDesc().GetType()); + + size_t xlocalsize = LOCAL_SIZE; + size_t ylocalsize = 1; + size_t zlocalsize = 1; + + size_t xgridsize = x1dims[0]; + if(xgridsize % LOCAL_SIZE != 0) + { + xgridsize = (xgridsize / LOCAL_SIZE + 1) * LOCAL_SIZE; + } + size_t ygridsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenOuter.cpp"; + kernel.kernel_name = "OuterBackwardGrad1"; + + const auto build_params = + KernelBuildParameters{{"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", 
static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}}; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto yGradDims = params.yGradDesc.GetLengths(); + + kernel(params.x2, params.x1Grad, params.yGrad, yGradDims[0], yGradDims[1]); + }; + }; + + result.construction_params.push_back(kernel); + + return result; +} + +std::size_t OuterBackwardGrad1::GetWorkspaceSize( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::outer::ProblemDescription& problem) const +{ + return 0; +} + +} // namespace outer + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/outer/backwardgrad2_outer.cpp b/src/solver/outer/backwardgrad2_outer.cpp new file mode 100644 index 0000000000..8901ad16e3 --- /dev/null +++ b/src/solver/outer/backwardgrad2_outer.cpp @@ -0,0 +1,131 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +namespace miopen { + +namespace solver { + +namespace outer { + +static bool IsImprovementOverROCm(const miopen::outer::ProblemDescription& problem) +{ + auto ydims = problem.GetYDesc().GetLengths(); + if(ydims[0] <= 32 && ydims[1] <= 128) + return true; + else + return false; +} + +bool OuterBackwardGrad2::IsApplicable([[maybe_unused]] const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const +{ + if(!problem.IsAllPacked()) + return false; + if(!IsImprovementOverROCm(problem)) + return false; + return true; +} + +ConvSolution OuterBackwardGrad2::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const +{ + static const size_t LOCAL_SIZE = 256; + auto result = ConvSolution{miopenStatusSuccess}; + + auto dtype = problem.GetX1Desc().GetType(); + auto x2dims = problem.GetX2Desc().GetLengths(); + + auto input_dtype = miopen::GetDataType(problem.GetX1Desc().GetType()); + auto output_dtype = miopen::GetDataType(problem.GetYDesc().GetType()); + + size_t xlocalsize = LOCAL_SIZE; + size_t ylocalsize = 1; + size_t zlocalsize = 1; + + size_t xgridsize = x2dims[0]; + if(xgridsize % LOCAL_SIZE != 0) + { + xgridsize = (xgridsize / LOCAL_SIZE + 1) * LOCAL_SIZE; + } + size_t ygridsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenOuter.cpp"; + kernel.kernel_name = "OuterBackwardGrad2"; + + const auto build_params = + KernelBuildParameters{{"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}}; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + 
kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto yGradDims = params.yGradDesc.GetLengths(); + + kernel(params.x1, params.x2Grad, params.yGrad, yGradDims[0], yGradDims[1]); + }; + }; + + result.construction_params.push_back(kernel); + + return result; +} + +std::size_t OuterBackwardGrad2::GetWorkspaceSize( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::outer::ProblemDescription& problem) const +{ + return 0; +} + +} // namespace outer + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/outer/forward_outer.cpp b/src/solver/outer/forward_outer.cpp new file mode 100644 index 0000000000..16617340b7 --- /dev/null +++ b/src/solver/outer/forward_outer.cpp @@ -0,0 +1,144 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +namespace miopen { + +namespace solver { + +namespace outer { + +static bool IsImprovementOverROCm(const miopen::outer::ProblemDescription& problem) +{ + auto dtype = problem.GetX1Desc().GetType(); + auto ydims = problem.GetYDesc().GetLengths(); + + if((ydims[0] <= 512) || (2048 < ydims[0] && ydims[1] <= 128) || + ((ydims[0] <= 2048 && ydims[1] <= 2048) && + (dtype == miopenHalf || dtype == miopenBFloat16)) || + ((ydims[0] <= 2048 && ydims[1] <= 512) && dtype == miopenFloat)) + return true; + else + return false; +} + +bool OuterForward::IsApplicable( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::outer::ProblemDescription& problem) const +{ + if(!problem.IsAllPacked()) + return false; + if(!IsImprovementOverROCm(problem)) + return false; + return true; +} + +ConvSolution OuterForward::GetSolution([[maybe_unused]] const ExecutionContext& context, + const miopen::outer::ProblemDescription& problem) const +{ + static const size_t LOCAL_SIZE = 256; + auto result = ConvSolution{miopenStatusSuccess}; + + auto dtype = problem.GetX1Desc().GetType(); + auto x1dims = problem.GetX1Desc().GetLengths(); + auto x2dims = problem.GetX2Desc().GetLengths(); + auto ydims = problem.GetYDesc().GetLengths(); + + auto input_dtype = miopen::GetDataType(problem.GetX1Desc().GetType()); + auto output_dtype = 
miopen::GetDataType(problem.GetYDesc().GetType()); + + size_t xlocalsize = LOCAL_SIZE; + size_t ylocalsize = 1; + size_t zlocalsize = 1; + + size_t xgridsize = ydims[0] * ydims[1]; + if(xgridsize % LOCAL_SIZE != 0) + { + xgridsize = (xgridsize / LOCAL_SIZE + 1) * LOCAL_SIZE; + } + size_t ygridsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenOuter.cpp"; + kernel.kernel_name = "OuterForward"; + + const auto build_params = + KernelBuildParameters{{"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}}; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto yGradDims = params.yDesc.GetLengths(); + + kernel(params.x1, + params.x2, + params.y, + yGradDims[0], + yGradDims[1], + yGradDims[0] * yGradDims[1]); + }; + }; + + result.construction_params.push_back(kernel); + + return result; +} + +std::size_t OuterForward::GetWorkspaceSize( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::outer::ProblemDescription& problem) const +{ + return 0; +} + +} // namespace outer + +} // namespace solver + +} // namespace miopen diff --git a/test/cpu_outer.hpp b/test/cpu_outer.hpp new file mode 100644 index 0000000000..28372a3ce3 --- /dev/null +++ b/test/cpu_outer.hpp @@ -0,0 +1,85 @@ 
+/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef GUARD_CPU_OUTER_HPP +#define GUARD_CPU_OUTER_HPP + +#include "tensor_holder.hpp" + +template +void cpu_outer_forward(tensor input1, tensor input2, tensor& ref_output) +{ + auto input1_dims = input1.desc.GetLengths(); + auto input2_dims = input2.desc.GetLengths(); + auto output_dims = ref_output.desc.GetLengths(); + + size_t in_n = input1_dims[0]; + size_t in_m = input2_dims[0]; + + size_t cnt = 0; + + for(size_t i = 0; i < in_n; i++) + { + for(size_t j = 0; j < in_m; j++) + { + ref_output[cnt++] = input1[i] * input2[j]; + } + } +} + +template +void cpu_outer_backward(tensor input1, + tensor input2, + tensor outputGrad, + tensor& input1Grad, + tensor& input2Grad) +{ + auto input1_dims = input1.desc.GetLengths(); + auto input2_dims = input2.desc.GetLengths(); + auto output_dims = outputGrad.desc.GetLengths(); + + size_t in_n = input1_dims[0]; + size_t in_m = input2_dims[0]; + + for(size_t i = 0; i < in_n; i++) + { + float sum = 0; + for(size_t j = 0; j < in_m; j++) + { + sum += static_cast(outputGrad[i * in_m + j]) * static_cast(input2[j]); + } + input1Grad[i] = sum; + } + for(size_t j = 0; j < in_m; j++) + { + float sum = 0; + for(size_t i = 0; i < in_n; i++) + { + sum += static_cast(input1[i]) * static_cast(outputGrad[i * in_m + j]); + } + input2Grad[j] = sum; + } +} +#endif diff --git a/test/gtest/outer.cpp b/test/gtest/outer.cpp new file mode 100644 index 0000000000..b09893f34f --- /dev/null +++ b/test/gtest/outer.cpp @@ -0,0 +1,161 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include "outer.hpp" +#include +
// NOTE(review): the second include target above and the template arguments of
// the OuterFwdTest / OuterBwdTest base classes below were lost when this patch
// was flattened; restore them from the original sources — TODO confirm.

MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG)
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL)

namespace env = miopen::env;

namespace outer {

// Value of MIOPEN_TEST_FLOAT_ARG, or an empty string when that value is empty.
std::string GetFloatArg() { return env::value(MIOPEN_TEST_FLOAT_ARG); }

// One fixture per direction and data type, so each can be instantiated and
// filtered independently.
struct OuterFwdTestFloat : OuterFwdTest
{
};

struct OuterBwdTestFloat : OuterBwdTest
{
};

struct OuterFwdTestHalf : OuterFwdTest
{
};

struct OuterBwdTestHalf : OuterBwdTest
{
};

struct OuterFwdTestBFloat16 : OuterFwdTest
{
};

struct OuterBwdTestBFloat16 : OuterBwdTest
{
};

} // namespace outer
using namespace outer;

// Every test below runs only when MIOPEN_TEST_ALL is enabled and
// MIOPEN_TEST_FLOAT_ARG names its data type; otherwise it is skipped.

TEST_P(OuterFwdTestFloat, OuterFwdTest)
{
    if(!env::enabled(MIOPEN_TEST_ALL) || (GetFloatArg() != "--float"))
    {
        GTEST_SKIP();
    }
    RunTest();
    Verify();
}

TEST_P(OuterBwdTestFloat, OuterBwdTest)
{
    if(!env::enabled(MIOPEN_TEST_ALL) || (GetFloatArg() != "--float"))
    {
        GTEST_SKIP();
    }
    RunTest();
    Verify();
}

TEST_P(OuterFwdTestHalf, OuterFwdTest)
{
    if(!env::enabled(MIOPEN_TEST_ALL) || (GetFloatArg() != "--half"))
    {
        GTEST_SKIP();
    }
    RunTest();
    Verify();
}

TEST_P(OuterBwdTestHalf, OuterBwdTest)
{
    if(!env::enabled(MIOPEN_TEST_ALL) || (GetFloatArg() != "--half"))
    {
        GTEST_SKIP();
    }
    RunTest();
    Verify();
}

TEST_P(OuterFwdTestBFloat16, OuterFwdTest)
{
    if(!env::enabled(MIOPEN_TEST_ALL) || (GetFloatArg() != "--bfloat16"))
    {
        GTEST_SKIP();
    }
    RunTest();
    Verify();
}

TEST_P(OuterBwdTestBFloat16, OuterBwdTest)
{
    if(!env::enabled(MIOPEN_TEST_ALL) || (GetFloatArg() != "--bfloat16"))
    {
        GTEST_SKIP();
    }
    RunTest();
    Verify();
}

// Forward fixtures use the forward config list, backward fixtures the backward one.
INSTANTIATE_TEST_SUITE_P(OuterTestSet, OuterFwdTestFloat, testing::ValuesIn(OuterFwdTestConfigs()));
INSTANTIATE_TEST_SUITE_P(OuterTestSet, OuterBwdTestFloat, testing::ValuesIn(OuterBwdTestConfigs()));
INSTANTIATE_TEST_SUITE_P(OuterTestSet, OuterFwdTestHalf, testing::ValuesIn(OuterFwdTestConfigs()));
INSTANTIATE_TEST_SUITE_P(OuterTestSet, OuterBwdTestHalf, testing::ValuesIn(OuterBwdTestConfigs()));
INSTANTIATE_TEST_SUITE_P(OuterTestSet,
                         OuterFwdTestBFloat16,
                         testing::ValuesIn(OuterFwdTestConfigs()));
INSTANTIATE_TEST_SUITE_P(OuterTestSet,
                         OuterBwdTestBFloat16,
                         testing::ValuesIn(OuterBwdTestConfigs()));
diff --git a/test/gtest/outer.hpp b/test/gtest/outer.hpp new file mode 100644 index 0000000000..00b22d2314 --- /dev/null +++ b/test/gtest/outer.hpp @@ -0,0 +1,237 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE.
+ * + *******************************************************************************/ + +#include "../driver/tensor_driver.hpp" +#include "cpu_outer.hpp" +#include "get_handle.hpp" +#include "random.hpp" +#include "tensor_holder.hpp" +#include "verify.hpp" +#include +#include +#include + +struct OuterTestCase +{ + size_t N; + size_t M; + friend std::ostream& operator<<(std::ostream& os, const OuterTestCase& tc) + { + return os << " N:" << tc.N << " M:" << tc.M; + } + + std::vector GetInput() { return std::vector({N, M}); } +}; + +std::vector OuterFwdTestConfigs() +{ + return {{512, 128}, + {512, 32768}, + {2048, 128}, + {2048, 256}, + {2048, 512}, + {32768, 32}, + {32768, 64}, + {32768, 128}}; +} + +std::vector OuterBwdTestConfigs() +{ + return {{16, 16}, {16, 32}, {16, 64}, {16, 128}, {32, 16}, {32, 32}, {32, 64}, {32, 128}}; +} + +template +struct OuterFwdTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + outer_config = GetParam(); + auto gen_value = [](auto...) 
{ return prng::gen_descreet_uniform_sign(1, 10); }; + + auto in_dims = outer_config.GetInput(); + + input1 = tensor{std::vector({in_dims[0]})}.generate(gen_value); + input2 = tensor{std::vector({in_dims[1]})}.generate(gen_value); + + std::vector out_dims; + + for(int i = 0; i < in_dims.size(); i++) + { + out_dims.push_back(in_dims[i]); + } + + output = tensor{out_dims}; + std::fill(output.begin(), output.end(), std::numeric_limits::quiet_NaN()); + + ref_output = tensor{out_dims}; + std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits::quiet_NaN()); + + input1_dev = handle.Write(input1.data); + input2_dev = handle.Write(input2.data); + output_dev = handle.Write(output.data); + } + void RunTest() + { + auto&& handle = get_handle(); + + cpu_outer_forward(input1, input2, ref_output); + miopenStatus_t status; + + status = miopen::OuterForward(handle, + input1.desc, + input1_dev.get(), + input2.desc, + input2_dev.get(), + output.desc, + output_dev.get()); + + EXPECT_EQ(status, miopenStatusSuccess); + + output.data = handle.Read(output_dev, output.data.size()); + } + + void Verify() + { + double threshold = std::numeric_limits::epsilon(); + auto error = miopen::rms_range(ref_output, output); + + EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); + EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error:" << error + << ", Thresholdx10: " << threshold * 10; + } + OuterTestCase outer_config; + + tensor input1; + tensor input2; + tensor output; + + tensor ref_output; + + miopen::Allocator::ManageDataPtr input1_dev; + miopen::Allocator::ManageDataPtr input2_dev; + miopen::Allocator::ManageDataPtr output_dev; +}; + +template +struct OuterBwdTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + outer_config = GetParam(); + auto gen_value = [](auto...) 
{ return prng::gen_descreet_uniform_sign(1, 10); }; + + auto in_dims = outer_config.GetInput(); + + input1 = tensor{std::vector({in_dims[0]})}.generate(gen_value); + input2 = tensor{std::vector({in_dims[1]})}.generate(gen_value); + + std::vector out_dims; + for(int i = 0; i < in_dims.size(); i++) + { + out_dims.push_back(in_dims[i]); + } + outputGrad = tensor{out_dims}.generate(gen_value); + + input1Grad = tensor{std::vector({in_dims[0]})}; + input2Grad = tensor{std::vector({in_dims[1]})}; + + std::fill(input1Grad.begin(), input1Grad.end(), std::numeric_limits::quiet_NaN()); + std::fill(input2Grad.begin(), input2Grad.end(), std::numeric_limits::quiet_NaN()); + + ref_input1Grad = tensor{std::vector({in_dims[0]})}; + ref_input2Grad = tensor{std::vector({in_dims[1]})}; + + input1_dev = handle.Write(input1.data); + input2_dev = handle.Write(input2.data); + outputGrad_dev = handle.Write(outputGrad.data); + + input1Grad_dev = handle.Write(input1Grad.data); + input2Grad_dev = handle.Write(input2Grad.data); + } + void RunTest() + { + auto&& handle = get_handle(); + + cpu_outer_backward(input1, input2, outputGrad, ref_input1Grad, ref_input2Grad); + miopenStatus_t status1, status2; + + status1 = miopen::OuterBackwardGrad1(handle, + input2.desc, + input2_dev.get(), + input1Grad.desc, + input1Grad_dev.get(), + outputGrad.desc, + outputGrad_dev.get()); + + status2 = miopen::OuterBackwardGrad2(handle, + input1.desc, + input1_dev.get(), + input2Grad.desc, + input2Grad_dev.get(), + outputGrad.desc, + outputGrad_dev.get()); + + EXPECT_EQ(status1, miopenStatusSuccess); + EXPECT_EQ(status2, miopenStatusSuccess); + + input1Grad.data = handle.Read(input1Grad_dev, input1Grad.data.size()); + input2Grad.data = handle.Read(input2Grad_dev, input2Grad.data.size()); + } + + void Verify() + { + double threshold = std::numeric_limits::epsilon(); + auto error1 = miopen::rms_range(ref_input1Grad, input1Grad); + auto error2 = miopen::rms_range(ref_input1Grad, input1Grad); + + 
EXPECT_TRUE(miopen::range_distance(ref_input1Grad) == miopen::range_distance(input1Grad)); + EXPECT_TRUE(miopen::range_distance(ref_input2Grad) == miopen::range_distance(input2Grad)); + EXPECT_TRUE(error1 < threshold * 10) << "Error1 output beyond tolerance Error:" << error1 + << ", Thresholdx10: " << threshold * 10; + EXPECT_TRUE(error2 < threshold * 10) << "Error2 output beyond tolerance Error:" << error2 + << ", Thresholdx10: " << threshold * 10; + } + OuterTestCase outer_config; + + tensor input1; + tensor input2; + tensor input1Grad; + tensor input2Grad; + tensor outputGrad; + + tensor ref_input1Grad; + tensor ref_input2Grad; + + miopen::Allocator::ManageDataPtr input1_dev; + miopen::Allocator::ManageDataPtr input2_dev; + miopen::Allocator::ManageDataPtr input1Grad_dev; + miopen::Allocator::ManageDataPtr input2Grad_dev; + miopen::Allocator::ManageDataPtr outputGrad_dev; +}; \ No newline at end of file