diff --git a/docs/apireference.rst b/docs/apireference.rst index f463954b45..38a0c58bec 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -24,4 +24,5 @@ API Reference layernorm sum argmax + groupnorm diff --git a/docs/groupnorm.rst b/docs/groupnorm.rst new file mode 100644 index 0000000000..d61eb1ac9e --- /dev/null +++ b/docs/groupnorm.rst @@ -0,0 +1,20 @@ + +GroupNorm Layer(experimental) +============================= + +The groupnorm types and functions. +It splits input channels into num_group groups and do normalize for each group. + +To enable this, define MIOPEN_BETA_API before including miopen.h. + + +miopenNormMode_t +----------------------- + +.. doxygenenum:: miopenNormMode_t + +miopenGroupNormForward +---------------------------------- + +.. doxygenfunction:: miopenGroupNormForward + diff --git a/docs/layernorm.rst b/docs/layernorm.rst index 89f1a3cc2d..86743637a3 100644 --- a/docs/layernorm.rst +++ b/docs/layernorm.rst @@ -6,10 +6,10 @@ The layernorm types and functions. To enable this, define MIOPEN_BETA_API before including miopen.h. -miopenLayerNormMode_t +miopenNormMode_t ----------------------- -.. doxygenenum:: miopenLayerNormMode_t +.. doxygenenum:: miopenNormMode_t miopenLayerNormForward ---------------------------------- diff --git a/driver/driver.hpp b/driver/driver.hpp index d450dec545..48783bd5c9 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -151,7 +151,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "pool[fp16], lrn[fp16], " "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], " "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], " - "argmax[bfp16|fp16]\n"); + "argmax[bfp16|fp16], groupnorm[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -175,6 +175,7 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" && arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" && arg != "sumbfp16" && arg != "argmax" && arg != "argmaxfp16" && arg != "argmaxbfp16" && + arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); diff --git a/driver/groupnorm_driver.hpp b/driver/groupnorm_driver.hpp new file mode 100644 index 0000000000..c143496cdd --- /dev/null +++ b/driver/groupnorm_driver.hpp @@ -0,0 +1,417 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#ifndef GUARD_MIOPEN_GROUPNORM_DRIVER_HPP +#define GUARD_MIOPEN_GROUPNORM_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "mloGroupNormHost.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include <../test/verify.hpp> +#include +#include +#include +#include +#include +#include +#include +#include <../test/tensor_holder.hpp> +#include "random.hpp" + +template +class GroupNormDriver : public Driver +{ +public: + GroupNormDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&weightDesc); + miopenCreateTensorDescriptor(&biasDesc); + miopenCreateTensorDescriptor(&outputDesc); + miopenCreateTensorDescriptor(&meanDesc); + miopenCreateTensorDescriptor(&rstdDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + std::vector GetInputTensorLengthsFromCmdLine(); + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + + Tref GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~GroupNormDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(weightDesc); + miopenDestroyTensorDescriptor(biasDesc); + miopenDestroyTensorDescriptor(outputDesc); + miopenDestroyTensorDescriptor(meanDesc); + miopenDestroyTensorDescriptor(rstdDesc); + } + +private: + InputFlags inflags; + + int forw; + int dim_size; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t weightDesc; + miopenTensorDescriptor_t biasDesc; + miopenTensorDescriptor_t outputDesc; + miopenTensorDescriptor_t meanDesc; + miopenTensorDescriptor_t rstdDesc; + + std::unique_ptr in_dev; + std::unique_ptr weight_dev; + std::unique_ptr bias_dev; + std::unique_ptr out_dev; + std::unique_ptr mean_dev; + std::unique_ptr rstd_dev; + + std::vector in; + std::vector weight; + std::vector bias; + std::vector out; + std::vector mean; + std::vector rstd; + std::vector outhost; + std::vector meanhost; + std::vector rstdhost; + + int num_groups; + float eps; + miopenNormMode_t mode; +}; + +template +int GroupNormDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int GroupNormDriver::GetandSetData() +{ + num_groups = inflags.GetValueInt("num_groups"); + eps = static_cast(inflags.GetValueDouble("eps")); + mode = miopenNormMode_t(inflags.GetValueInt("mode")); + + std::vector in_len = GetInputTensorLengthsFromCmdLine(); + std::vector weight_bias_len = {in_len[1]}; + std::vector mean_rstd_len = {in_len[0], num_groups}; + + SetTensorNd(inputDesc, in_len, data_type); + SetTensorNd(weightDesc, weight_bias_len, data_type); + SetTensorNd(biasDesc, weight_bias_len, data_type); + SetTensorNd(outputDesc, in_len, data_type); + SetTensorNd(meanDesc, mean_rstd_len, data_type); + SetTensorNd(rstdDesc, mean_rstd_len, data_type); + + return 0; +} + +template +int GroupNormDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward GroupNorm (Default=1)", "int"); + inflags.AddInputFlag("batchsize", 'n', "100", "Mini-batch size (Default=100)", "int"); + inflags.AddInputFlag("in_channels", 'c', "6", "Number of Input Channels (Default=6)", "int"); + inflags.AddInputFlag("in_d", 'D', "0", "Input Depth (Default=0)", "int"); + inflags.AddInputFlag("in_h", 'H', "32", "Input Height (Default=32)", "int"); + inflags.AddInputFlag("in_w", 'W', "32", "Input Width (Default=32)", "int"); + + inflags.AddInputFlag("eps", 'e', "0.00001", "Alpha (Default=0.00001)", "double"); + inflags.AddInputFlag("num_groups", 'g', "3", "num_groups", "int"); + inflags.AddInputFlag( + "mode", 'm', "0", "elemwise affine mode (0), weight and bias mode (1) (Default=0)", "int"); + + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +std::vector GroupNormDriver::GetInputTensorLengthsFromCmdLine() +{ + int in_n = inflags.GetValueInt("batchsize"); + int in_c = inflags.GetValueInt("in_channels"); + int in_w = inflags.GetValueInt("in_w"); + int in_h = inflags.GetValueInt("in_h"); + int in_d = inflags.GetValueInt("in_d"); + + if((in_n != 0) && (in_c != 0) && (in_d != 0) && (in_h != 0) && (in_w != 0)) + { + dim_size = 5; + return std::vector({in_n, in_c, in_d, in_h, in_w}); + } + else if((in_n != 0) && (in_c != 0) && (in_h != 0) && (in_w != 0)) + { + dim_size = 4; + return std::vector({in_n, in_c, in_h, in_w}); + } + else if((in_n != 0) && (in_c != 0) && (in_w != 0)) + { + dim_size = 3; + return std::vector({in_n, in_c, in_w}); + } + else + { + MIOPEN_THROW("Error Input Tensor Lengths"); + } +} + +template +int GroupNormDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = GetTensorSize(inputDesc); + size_t weight_sz = GetTensorSize(weightDesc); + size_t bias_sz = GetTensorSize(biasDesc); + size_t out_sz = GetTensorSize(outputDesc); + size_t mean_sz = GetTensorSize(meanDesc); + size_t rstd_sz = GetTensorSize(rstdDesc); + + uint32_t ctx = 0; + + in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); + weight_dev = std::unique_ptr(new GPUMem(ctx, weight_sz, sizeof(Tgpu))); + bias_dev = std::unique_ptr(new GPUMem(ctx, bias_sz, sizeof(Tgpu))); + out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); + mean_dev = std::unique_ptr(new GPUMem(ctx, mean_sz, sizeof(Tref))); + rstd_dev = std::unique_ptr(new GPUMem(ctx, rstd_sz, sizeof(Tref))); + + in = std::vector(in_sz, static_cast(0)); + weight = std::vector(weight_sz, static_cast(0)); + bias = std::vector(bias_sz, static_cast(0)); + out = std::vector(out_sz, static_cast(0)); + mean = std::vector(mean_sz, static_cast(0)); + rstd = std::vector(rstd_sz, static_cast(0)); + outhost = std::vector(out_sz, static_cast(0)); + meanhost = std::vector(mean_sz, static_cast(0)); + rstdhost = std::vector(rstd_sz, static_cast(0)); + + int status; + + for(int i = 0; i < in_sz; i++) + { + in[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + } + status = in_dev->ToGPU(q, in.data()); + + for(int i = 0; i < weight_sz; i++) + { + weight[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + } + status |= weight_dev->ToGPU(q, weight.data()); + + for(int i = 0; i < bias_sz; i++) + { + bias[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + } + status |= bias_dev->ToGPU(q, bias.data()); + + status |= out_dev->ToGPU(q, out.data()); + status |= mean_dev->ToGPU(q, mean.data()); + status |= rstd_dev->ToGPU(q, rstd.data()); + + if(status != 0) + std::cout << "Error copying data to GPU\n" << std::endl; + + return miopenStatusSuccess; +} + +template +int GroupNormDriver::RunForwardGPU() +{ + float kernel_total_time = 0.0; + float kernel_first_time = 0.0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenGroupNormForward(GetHandle(), + mode, + inputDesc, + in_dev->GetMem(), + weightDesc, + weight_dev->GetMem(), + biasDesc, + bias_dev->GetMem(), + num_groups, + eps, + outputDesc, + out_dev->GetMem(), + meanDesc, + mean_dev->GetMem(), + rstdDesc, + rstd_dev->GetMem()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + printf("Wall-clock Time Forward GroupNorm Elapsed: %f ms\n", t.gettime_ms() / iter); + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + printf("GPU Kernel Time Forward GroupNorm Elapsed: %f ms\n", kernel_average_time); + } + + out_dev->FromGPU(GetStream(), out.data()); + mean_dev->FromGPU(GetStream(), mean.data()); + rstd_dev->FromGPU(GetStream(), rstd.data()); + + return miopenStatusSuccess; +} + +template +int GroupNormDriver::RunForwardCPU() +{ + mloGroupNormForwardRunHost(inputDesc, + in.data(), + weight.data(), + bias.data(), + outhost.data(), + meanhost.data(), + rstdhost.data(), + num_groups, + eps, + mode); + + return miopenStatusSuccess; +} + +template +int GroupNormDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +Tref GroupNormDriver::GetTolerance() +{ + if(data_type == miopenHalf) + { + return 1e-3; + } + else if(data_type == miopenFloat) + { + return 5e-5; + } + else if(data_type == miopenDouble) + { + return 1e-10; + } + else if(data_type == miopenBFloat16) + { + return 5e-3; + } + return 0; +} + +template +int GroupNormDriver::VerifyForward() +{ + RunForwardCPU(); + const Tref tolerance = GetTolerance(); + auto error = miopen::rms_range(outhost, out); + + if(!std::isfinite(error) || error > tolerance) + { + std::cout << "Forward GroupNorm FAILED: " << error << std::endl; + return EC_VerifyFwd; + } + else + { + printf("Forward GroupNorm Verifies on CPU and GPU (err=%f)\n", error); + } + + auto meanerror = miopen::rms_range(meanhost, mean); + if(!std::isfinite(meanerror) || meanerror > tolerance) + { + std::cout << "Forward GroupNorm mean FAILED: " << meanerror << std::endl; + return EC_VerifyFwd; + } + else + { + printf("Forward GroupNorm mean Verifies on CPU and GPU (err=%f)\n", meanerror); + } + + auto rstderror = miopen::rms_range(rstdhost, rstd); + if(!std::isfinite(rstderror) || rstderror > tolerance) + { + std::cout << "Forward GroupNorm rstd FAILED: " << rstderror << std::endl; + return EC_VerifyFwd; + } + else + { + printf("Forward GroupNorm rstd Verifies on CPU and GPU (err=%f)\n", rstderror); + } + + return miopenStatusSuccess; +} + +template +int GroupNormDriver::VerifyBackward() +{ + return miopenStatusSuccess; +} + +#endif // GUARD_MIOPEN_GROUPNORM_DRIVER_HPP diff --git a/driver/layernorm_driver.hpp b/driver/layernorm_driver.hpp index 12e32dcae2..59b19f3029 100644 --- a/driver/layernorm_driver.hpp +++ b/driver/layernorm_driver.hpp @@ -51,7 +51,7 @@ int32_t mloLayerNormForwardRunHost(miopenTensorDescriptor_t inputDesc, Tcheck* rstdhost, float eps, int32_t normalized_dim, - miopenLayerNormMode_t mode) + miopenNormMode_t mode) { auto dims = miopen::deref(inputDesc).GetLengths(); size_t outer_size = 1; @@ -172,7 +172,7 @@ class LayerNormDriver : public Driver float eps; int dim; - miopenLayerNormMode_t mode; + miopenNormMode_t mode; }; template @@ -214,7 +214,7 @@ int LayerNormDriver::GetandSetData() SetTensorNd(rstdDesc, outer_len, data_type); eps = static_cast(inflags.GetValueDouble("eps")); - mode = miopenLayerNormMode_t(inflags.GetValueInt("mode")); + mode = miopenNormMode_t(inflags.GetValueInt("mode")); return 0; } diff --git a/driver/main.cpp b/driver/main.cpp index e6ebbb2604..99264cc989 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -31,6 +31,7 @@ #include "conv_driver.hpp" #include "CBAInferFusion_driver.hpp" #include "driver.hpp" +#include "groupnorm_driver.hpp" #include "gemm_driver.hpp" #include "lrn_driver.hpp" #include "pool_driver.hpp" @@ -178,6 +179,18 @@ int main(int argc, char* argv[]) { drv = new DropoutDriver(); } + else if(base_arg == "groupnorm") + { + drv = new GroupNormDriver(); + } + else if(base_arg == "groupnormfp16") + { + drv = new GroupNormDriver(); + } + else if(base_arg == "groupnormbfp16") + { + drv = new GroupNormDriver(); + } else if(base_arg == "tensorop") { drv = new TensorOpDriver(); diff --git a/driver/mloGroupNormHost.hpp b/driver/mloGroupNormHost.hpp new file mode 100644 index 0000000000..e89f389ec9 --- /dev/null +++ b/driver/mloGroupNormHost.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MLO_GROUPNORMHOST_H_ +#define MLO_GROUPNORMHOST_H_ + +//////////////////////////////////////////////////////////// +// +/////////////////////////////////////////////////////////// + +template +int32_t mloGroupNormForwardRunHost(miopenTensorDescriptor_t inputDesc, + Tgpu* input, + Tgpu* weight, + Tgpu* bias, + Tcheck* outputhost, + Tcheck* meanhost, + Tcheck* rstdhost, + uint64_t num_groups, + float eps, + miopenNormMode_t mode) +{ + auto dims = miopen::deref(inputDesc).GetLengths(); + + size_t numel = miopen::deref(inputDesc).GetElementSize(); + size_t numel_per_channel = numel / dims[0] / dims[1]; + size_t num_channels = dims[1]; + + size_t outer_size = dims[0] * num_groups; + size_t inner_size = numel / outer_size; + + for(size_t o = 0; o < outer_size; o++) + { + Tcheck pmean = 0.0f; + Tcheck pvar = 0.0f; + for(size_t i = 0; i < inner_size; i++) + { + Tcheck tmp = static_cast(input[o * inner_size + i]); + pmean += tmp; + pvar += tmp * tmp; + } + + pmean = pmean / inner_size; + pvar = pvar / inner_size - pmean * pmean; + Tcheck prstd = 1.0f / sqrt(pvar + eps); + + meanhost[o] = pmean; + rstdhost[o] = prstd; + + for(size_t i = 0; i < inner_size; i++) + { + size_t idx = o * inner_size + i; + size_t c = (idx / numel_per_channel) % num_channels; + Tcheck pweight = mode ? static_cast(weight[c]) : 1; + Tcheck pbias = mode ? static_cast(bias[c]) : 0; + + outputhost[idx] = (static_cast(input[idx]) - pmean) * prstd * pweight + pbias; + } + } + + return 0; +} +#endif diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 84bd92cdd2..e1e27f6dbc 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -66,6 +66,7 @@ * @defgroup find2 * @defgroup sum * @defgroup argmax + * @defgroup groupnorm * */ @@ -466,7 +467,7 @@ typedef enum } miopenLRNMode_t; #ifdef MIOPEN_BETA_API /*! @ingroup layernorm - * @enum miopenLayerNormMode_t + * @enum miopenNormMode_t * LayerNorm mode */ typedef enum @@ -474,7 +475,7 @@ typedef enum MIOPEN_ELEMENTWISE_AFFINE = 0, /*!< initialized to ones for weights and zeros for biases */ MIOPEN_WEIGHT_BIAS = 1, /*!< learnable weights and biases of the module of shape normalized_shape */ -} miopenLayerNormMode_t; +} miopenNormMode_t; #endif /*! @ingroup batchnorm * @enum miopenBatchNormMode_t @@ -2538,7 +2539,7 @@ MIOPEN_EXPORT miopenStatus_t miopenDestroyLRNDescriptor(miopenLRNDescriptor_t lr * @return miopenStatus_t */ MIOPEN_EXPORT miopenStatus_t miopenLayerNormForward(miopenHandle_t handle, - miopenLayerNormMode_t mode, + miopenNormMode_t mode, const miopenTensorDescriptor_t xDesc, const void* x, const miopenTensorDescriptor_t weightDesc, @@ -5652,6 +5653,53 @@ MIOPEN_EXPORT miopenStatus_t miopenArgmaxForward(miopenHandle_t handle, #endif +#ifdef MIOPEN_BETA_API +// GroupNorm APIs +/** @addtogroup groupnorm + * + * @{ + */ +/*! @brief Execute a groupnorm forward layer + * + * @param handle MIOpen handle (input) + * @param mode GroupNorm mode (input) + * @param xDesc Tensor descriptor for data input tensor x (input) + * @param x Data tensor x (input) + * @param weightDesc Tensor descriptor for data input tensor weight (input) + * @param weight Data tensor weight (input) + * @param biasDesc Tensor descriptor for data input tensor bias (input) + * @param bias Data tensor bias (input) + * @param num_groups nNmber of groups to separate the channels into (input) + * @param epsilon Value to stablize inverse variance calculation (input) + * @param yDesc Tensor descriptor for output data tensor y (input) + * @param y Data tensor y (output) + * @param meanDesc Tensor descriptor for output data tensor mean (input) + * @param mean Data tensor mean (output) + * @param rstdDesc Tensor descriptor for output data tensor rstd (input) + * @param rstd Data tensor rstd (output) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenGroupNormForward(miopenHandle_t handle, + miopenNormMode_t mode, + const miopenTensorDescriptor_t xDesc, + const void* x, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t biasDesc, + const void* bias, + const uint64_t num_groups, + const float epsilon, + const miopenTensorDescriptor_t yDesc, + void* y, + const miopenTensorDescriptor_t meanDesc, + void* mean, + const miopenTensorDescriptor_t rstdDesc, + void* rstd); + +/** @} */ +// CLOSEOUT groupnorm DOXYGEN GROUP +#endif + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 32ddf80999..14d6994a55 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -119,16 +119,18 @@ set( MIOpen_Source fusion.cpp fusion/problem_description.cpp generic_search.cpp + groupnorm_api.cpp + groupnorm/problem_description.cpp handle_api.cpp invoker_cache.cpp kernel_build_params.cpp kernel_warnings.cpp layernorm_api.cpp + layernorm/problem_description.cpp load_file.cpp lock_file.cpp logger.cpp lrn_api.cpp - norm/problem_description.cpp op_args.cpp operator.cpp performance_config.cpp @@ -243,9 +245,10 @@ set( MIOpen_Source solver/gemm.cpp solver/gemm_bwd.cpp solver/gemm_wrw.cpp - solver/norm/forward_layernorm.cpp - solver/norm/forward_layernorm2d_ck.cpp - solver/norm/forward_layernorm4d_ck.cpp + solver/groupnorm/forward_groupnorm.cpp + solver/layernorm/forward_layernorm.cpp + solver/layernorm/forward_layernorm2d_ck.cpp + solver/layernorm/forward_layernorm4d_ck.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp @@ -432,6 +435,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConvDirUni.cl kernels/MIOpenConvDirBatchNormActiv.cl kernels/MIOpenConvDirGenFwd.cl + kernels/MIOpenGroupNorm.cpp kernels/MIOpenLayerNorm.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl @@ -553,6 +557,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN list(APPEND MIOpen_Source activ.cpp argmax.cpp + groupnorm.cpp kernel_cache.cpp layer_norm.cpp lrn.cpp diff --git a/src/groupnorm.cpp b/src/groupnorm.cpp new file mode 100644 index 0000000000..2cf21102af --- /dev/null +++ b/src/groupnorm.cpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t GroupNormForward(Handle& handle, + const TensorDescriptor& xDesc, + ConstData_t x, + const TensorDescriptor& weightDesc, + ConstData_t weight, + const TensorDescriptor& biasDesc, + ConstData_t bias, + const TensorDescriptor& yDesc, + Data_t y, + const TensorDescriptor& meanDesc, + Data_t mean, + const TensorDescriptor& rstdDesc, + Data_t rstd, + miopenNormMode_t mode, + uint64_t num_groups, + float epsilon) +{ + const auto problem = groupnorm::ProblemDescription{ + mode, xDesc, weightDesc, biasDesc, yDesc, meanDesc, rstdDesc, num_groups, epsilon}; + + const auto invoke_params = [&]() { + auto tmp = groupnorm::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.xDesc = &xDesc; + tmp.x = x; + tmp.weight = weight; + tmp.bias = bias; + tmp.y = y; + tmp.mean = mean; + tmp.rstd = rstd; + tmp.num_groups = num_groups; + tmp.epsilon = epsilon; + tmp.mode = mode; + return tmp; + }(); + + const auto algo = AlgorithmName{"GroupNormForward"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/groupnorm/problem_description.cpp b/src/groupnorm/problem_description.cpp new file mode 100644 index 0000000000..86ba4bc2d9 --- /dev/null +++ b/src/groupnorm/problem_description.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include + +#include + +namespace miopen { + +namespace groupnorm { + +NetworkConfig ProblemDescription::MakeNetworkConfig() const +{ + auto dims = xDesc.GetLengths(); + size_t numel = xDesc.GetElementSize(); + size_t num_batches = dims[0]; + size_t num_channels = dims[1]; + + auto dtype = xDesc.GetType(); + + std::ostringstream ss; + + ss << "dtype" << dtype; + ss << "numel" << numel; + ss << "num_batches" << num_batches; + ss << "num_channels" << num_channels; + ss << "num_groups" << num_groups; + + return NetworkConfig{ss.str()}; +} + +} // namespace groupnorm + +} // namespace miopen diff --git a/src/groupnorm_api.cpp b/src/groupnorm_api.cpp new file mode 100644 index 0000000000..4b6c7b0970 --- /dev/null +++ b/src/groupnorm_api.cpp @@ -0,0 +1,139 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include + +static void LogCmdGroupNorm(const miopenTensorDescriptor_t xDesc, + const miopenNormMode_t mode, + uint64_t num_groups, + bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(xDesc).GetType(); + if(dtype == miopenHalf) + { + ss << "groupnormfp16"; + } + else if(dtype == miopenFloat) + { + ss << "groupnormfp32"; + } + else if(dtype == miopenBFloat16) + { + ss << "groupnormbfp16"; + } + else if(dtype == miopenDouble) + { + ss << "groupnormfp64"; + } + + int32_t size = {0}; + miopenGetTensorDescriptorSize(xDesc, &size); + ss << " -n " << miopen::deref(xDesc).GetLengths()[0] << " -c " + << miopen::deref(xDesc).GetLengths()[1]; + if(size == 5) + { + ss << " -D " << miopen::deref(xDesc).GetLengths()[2] << " -H " + << miopen::deref(xDesc).GetLengths()[3] << " -W " + << miopen::deref(xDesc).GetLengths()[4]; + } + else if(size == 4) + { + ss << " -H " << miopen::deref(xDesc).GetLengths()[2] << " -W " + << miopen::deref(xDesc).GetLengths()[3]; + } + else if(size == 3) + { + ss << " -W " << miopen::deref(xDesc).GetLengths()[2]; + } + + ss << " -g " << num_groups; + ss << " -m " << mode; + ss << " -F " << ((is_fwd) ? "1" : "2"); + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + +extern "C" miopenStatus_t miopenGroupNormForward(miopenHandle_t handle, + miopenNormMode_t mode, + const miopenTensorDescriptor_t xDesc, + const void* x, + const miopenTensorDescriptor_t weightDesc, + const void* weight, + const miopenTensorDescriptor_t biasDesc, + const void* bias, + const uint64_t num_groups, + const float epsilon, + const miopenTensorDescriptor_t yDesc, + void* y, + const miopenTensorDescriptor_t meanDesc, + void* mean, + const miopenTensorDescriptor_t rstdDesc, + void* rstd) +{ + MIOPEN_LOG_FUNCTION(handle, + mode, + xDesc, + x, + weightDesc, + weight, + biasDesc, + bias, + num_groups, + epsilon, + yDesc, + y, + meanDesc, + mean, + rstdDesc, + rstd); + + LogCmdGroupNorm(xDesc, mode, num_groups, true); + return miopen::try_([&] { + miopen::GroupNormForward(miopen::deref(handle), + miopen::deref(xDesc), + DataCast(x), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(biasDesc), + DataCast(bias), + miopen::deref(yDesc), + DataCast(y), + miopen::deref(meanDesc), + DataCast(mean), + miopen::deref(rstdDesc), + DataCast(rstd), + mode, + num_groups, + epsilon); + }); +} diff --git a/src/include/miopen/groupnorm.hpp b/src/include/miopen/groupnorm.hpp new file mode 100644 index 0000000000..837df25013 --- /dev/null +++ b/src/include/miopen/groupnorm.hpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#ifndef MIOPEN_GROUPNORM_HPP_ +#define MIOPEN_GROUPNORM_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +miopenStatus_t GroupNormForward(Handle& handle, + const TensorDescriptor& xDesc, + ConstData_t x, + const TensorDescriptor& weightDesc, + ConstData_t weight, + const TensorDescriptor& biasDesc, + ConstData_t bias, + const TensorDescriptor& yDesc, + Data_t y, + const TensorDescriptor& meanDesc, + Data_t mean, + const TensorDescriptor& rstdDesc, + Data_t rstd, + miopenNormMode_t mode, + uint64_t num_groups, + float epsilon); + +} // namespace miopen +#endif // _MIOPEN_GROUPNORM_HPP_ diff --git a/src/include/miopen/norm/invoke_params.hpp b/src/include/miopen/groupnorm/invoke_params.hpp similarity index 79% rename from src/include/miopen/norm/invoke_params.hpp rename to src/include/miopen/groupnorm/invoke_params.hpp index de6abd8c7a..a02f3a4797 100644 --- a/src/include/miopen/norm/invoke_params.hpp +++ b/src/include/miopen/groupnorm/invoke_params.hpp @@ -30,7 +30,7 @@ #include namespace miopen { -namespace norm { +namespace groupnorm { struct InvokeParams : public miopen::InvokeParams { @@ -38,20 +38,20 @@ struct InvokeParams : public miopen::InvokeParams const TensorDescriptor* xDesc = nullptr; - ConstData_t x = nullptr; - ConstData_t weight = nullptr; - ConstData_t bias = nullptr; - Data_t y = nullptr; - Data_t mean = nullptr; - Data_t rstd = nullptr; - float epsilon = 0; - int32_t normalized_dim = 0; - miopenLayerNormMode_t mode = MIOPEN_ELEMENTWISE_AFFINE; + ConstData_t x = nullptr; + ConstData_t weight = nullptr; + ConstData_t bias = nullptr; + Data_t y = nullptr; + Data_t mean = nullptr; + Data_t rstd = nullptr; + uint64_t num_groups = 0; + float epsilon = 0; + miopenNormMode_t mode = MIOPEN_ELEMENTWISE_AFFINE; std::size_t GetWorkspaceSize() const { return 0; } Data_t GetWorkspace() const { return nullptr; } }; -} // namespace norm +} // namespace groupnorm } // namespace miopen diff --git a/src/include/miopen/groupnorm/problem_description.hpp b/src/include/miopen/groupnorm/problem_description.hpp new file mode 100644 index 0000000000..15c86b4dec --- /dev/null +++ b/src/include/miopen/groupnorm/problem_description.hpp @@ -0,0 +1,142 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace groupnorm { + +struct ProblemDescription : ProblemDescriptionBase +{ + ProblemDescription(miopenNormMode_t mode_, + const TensorDescriptor& xDesc_, + const TensorDescriptor& weightDesc_, + const TensorDescriptor& biasDesc_, + const TensorDescriptor& yDesc_, + const TensorDescriptor& meanDesc_, + const TensorDescriptor& rstdDesc_, + uint64_t num_groups_, + float epsilon_) + : mode(mode_), + xDesc(xDesc_), + weightDesc(weightDesc_), + biasDesc(biasDesc_), + yDesc(yDesc_), + meanDesc(meanDesc_), + rstdDesc(rstdDesc_), + num_groups(num_groups_), + epsilon(epsilon_) + { + if(xDesc.GetLengths() != yDesc.GetLengths()) + { + MIOPEN_THROW(miopenStatusBadParm, + "groupnorm::ProblemDescription: Tensor dimension lengths do not match."); + } + if((num_groups < 1) || (xDesc.GetLengths()[1] % num_groups != 0)) + { + MIOPEN_THROW(miopenStatusBadParm, + "groupnorm::ProblemDescription: The channel size of input tensor should " + "be divisible by num_groups."); + } + if(xDesc.GetLengths().size() < 3) + { + MIOPEN_THROW(miopenStatusBadParm, + "groupnorm::ProblemDescription: The number of dimensions of the input " + "tensor should be at least 3."); + } + } + + miopenNormMode_t GetMode() const { return mode; } + const TensorDescriptor& GetXDesc() const { return xDesc; } + const TensorDescriptor& GetWeightDesc() const { return weightDesc; } + const TensorDescriptor& GetBiasDesc() const { return biasDesc; } + const TensorDescriptor& GetYDesc() const { return yDesc; } + const TensorDescriptor& GetMeanDesc() const { return meanDesc; } + const TensorDescriptor& GetRstdDesc() const { return rstdDesc; } + int32_t GetNumGroups() const { return num_groups; } + float GetEpsilon() const { return epsilon; } + + bool IsValidType() const + { + if(xDesc.GetType() != yDesc.GetType()) + { + return false; + } + if(yDesc.GetType() != weightDesc.GetType()) + { + return false; + } + if(weightDesc.GetType() != biasDesc.GetType()) + { + return false; + } + if(meanDesc.GetType() != rstdDesc.GetType()) + { + return false; + } + return true; + } + + bool IsAllPacked() const + { + if(!(xDesc.IsPacked() && weightDesc.IsPacked() && biasDesc.IsPacked() && yDesc.IsPacked() && + meanDesc.IsPacked() && rstdDesc.IsPacked())) + { + return false; + } + return true; + } + + NetworkConfig MakeNetworkConfig() const override; + +private: + miopenNormMode_t mode; + TensorDescriptor xDesc; + TensorDescriptor weightDesc; + TensorDescriptor biasDesc; + TensorDescriptor yDesc; + TensorDescriptor meanDesc; + TensorDescriptor rstdDesc; + + uint64_t num_groups; + float epsilon; + + NetworkConfig MakeForwardNetworkConfig() const; +}; + +} // namespace groupnorm + +} // namespace miopen diff --git a/src/include/miopen/groupnorm/solvers.hpp b/src/include/miopen/groupnorm/solvers.hpp new file mode 100644 index 0000000000..70ede100d0 --- /dev/null +++ b/src/include/miopen/groupnorm/solvers.hpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace groupnorm { + +using NormalizationSolver = + NonTunableSolverBase; + +struct GroupNormForward final : NormalizationSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::groupnorm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::groupnorm::ProblemDescription& problem) const override; +}; + +} // namespace groupnorm + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/layernorm.hpp b/src/include/miopen/layernorm.hpp index 7780e57cda..3a8bf54a90 100644 --- a/src/include/miopen/layernorm.hpp +++ b/src/include/miopen/layernorm.hpp @@ -46,7 +46,7 @@ miopenStatus_t LayerNormForward(Handle& handle, Data_t mean, const TensorDescriptor& rstdDesc, Data_t rstd, - miopenLayerNormMode_t mode, + miopenNormMode_t mode, float epsilon, int32_t normalized_dim); diff --git a/src/include/miopen/layernorm/invoke_params.hpp b/src/include/miopen/layernorm/invoke_params.hpp new file mode 100644 index 0000000000..b97bac7d08 --- /dev/null +++ b/src/include/miopen/layernorm/invoke_params.hpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +namespace miopen { +namespace layernorm { + +struct InvokeParams : public miopen::InvokeParams +{ + InvokeParams() = default; + + const TensorDescriptor* xDesc = nullptr; + + ConstData_t x = nullptr; + ConstData_t weight = nullptr; + ConstData_t bias = nullptr; + Data_t y = nullptr; + Data_t mean = nullptr; + Data_t rstd = nullptr; + float epsilon = 0; + int32_t normalized_dim = 0; + miopenNormMode_t mode = MIOPEN_ELEMENTWISE_AFFINE; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace layernorm + +} // namespace miopen diff --git a/src/include/miopen/norm/problem_description.hpp b/src/include/miopen/layernorm/problem_description.hpp similarity index 96% rename from src/include/miopen/norm/problem_description.hpp rename to src/include/miopen/layernorm/problem_description.hpp index 8a1758457f..78a631b292 100644 --- a/src/include/miopen/norm/problem_description.hpp +++ b/src/include/miopen/layernorm/problem_description.hpp @@ -35,11 +35,11 @@ namespace miopen { struct NetworkConfig; -namespace norm { +namespace layernorm { struct ProblemDescription : ProblemDescriptionBase { - ProblemDescription(miopenLayerNormMode_t mode_, + ProblemDescription(miopenNormMode_t mode_, const TensorDescriptor& xDesc_, const TensorDescriptor& weightDesc_, const TensorDescriptor& biasDesc_, @@ -60,7 +60,7 @@ struct ProblemDescription : ProblemDescriptionBase { } - miopenLayerNormMode_t GetMode() const { return mode; } + miopenNormMode_t GetMode() const { return mode; } const TensorDescriptor& GetXDesc() const { return xDesc; } const TensorDescriptor& GetWeightDesc() const { return weightDesc; } const TensorDescriptor& GetBiasDesc() const { return biasDesc; } @@ -143,7 +143,7 @@ struct ProblemDescription : ProblemDescriptionBase NetworkConfig MakeNetworkConfig() const override; private: - miopenLayerNormMode_t mode; + miopenNormMode_t mode; TensorDescriptor xDesc; TensorDescriptor weightDesc; TensorDescriptor biasDesc; @@ -157,6 +157,6 @@ struct ProblemDescription : ProblemDescriptionBase NetworkConfig MakeForwardNetworkConfig() const; }; -} // namespace norm +} // namespace layernorm } // namespace miopen diff --git a/src/include/miopen/norm/solvers.hpp b/src/include/miopen/layernorm/solvers.hpp similarity index 75% rename from src/include/miopen/norm/solvers.hpp rename to src/include/miopen/layernorm/solvers.hpp index b29d692566..503bb87fb6 100644 --- a/src/include/miopen/norm/solvers.hpp +++ b/src/include/miopen/layernorm/solvers.hpp @@ -25,7 +25,7 @@ *******************************************************************************/ #pragma once -#include +#include #include #include @@ -33,19 +33,19 @@ namespace miopen { namespace solver { -namespace norm { +namespace layernorm { using NormalizationSolver = - NonTunableSolverBase; + NonTunableSolverBase; struct LayernormForward final : NormalizationSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } bool IsApplicable(const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) const override; + const miopen::layernorm::ProblemDescription& problem) const override; ConvSolution GetSolution(const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) const override; + const miopen::layernorm::ProblemDescription& problem) const override; }; struct Layernorm2DCKForward final : NormalizationSolver @@ -53,9 +53,9 @@ struct Layernorm2DCKForward final : NormalizationSolver const std::string& SolverDbId() const override { return GetSolverDbId(); } bool IsApplicable(const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) const override; + const miopen::layernorm::ProblemDescription& problem) const override; ConvSolution GetSolution(const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) const override; + const miopen::layernorm::ProblemDescription& problem) const override; }; struct Layernorm4DCKForward final : NormalizationSolver @@ -63,12 +63,12 @@ struct Layernorm4DCKForward final : NormalizationSolver const std::string& SolverDbId() const override { return GetSolverDbId(); } bool IsApplicable(const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) const override; + const miopen::layernorm::ProblemDescription& problem) const override; ConvSolution GetSolution(const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) const override; + const miopen::layernorm::ProblemDescription& problem) const override; }; -} // namespace norm +} // namespace layernorm } // namespace solver diff --git a/src/kernels/MIOpenGroupNorm.cpp b/src/kernels/MIOpenGroupNorm.cpp new file mode 100644 index 0000000000..54d70d323b --- /dev/null +++ b/src/kernels/MIOpenGroupNorm.cpp @@ -0,0 +1,121 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" + +extern "C" __global__ void GroupNormFwdContiguous(const FLOAT* __restrict__ x, + FLOAT* __restrict__ y, + const FLOAT* __restrict__ weight, + const FLOAT* __restrict__ bias, + FLOAT_ACCUM* __restrict__ mean, + FLOAT_ACCUM* __restrict__ rstd, + float eps, + uint64_t num_groups, + uint64_t num_channels, + uint64_t numel_per_channel, + bool mode) +{ + /* + * Each group works on a single channel. + * Example) + * x dim = {N, C, L}, normalized shape = {C, L} + * outer_size = N, inner_size = C * L + * + * Example2) + * x dim = {N, C, L}, normalized shape = {L} + * outer_size = N * C, inner_size = L + * + * => gws = {outer_size * LOCAL_SIZE}, lws = {LOCAL_SIZE} + */ + + /* + * Reduction to calculate mean and rstd + */ + + const uint64_t gid = blockIdx.x; + const uint64_t lid = threadIdx.x; + + const size_t inner_size = numel_per_channel * num_channels / num_groups; + + FLOAT_ACCUM pmean = static_cast(0); + FLOAT_ACCUM pvar = static_cast(0); + __shared__ FLOAT_ACCUM ltmp1[LOCAL_SIZE]; + __shared__ FLOAT_ACCUM ltmp2[LOCAL_SIZE]; + + // reduce sum for mean and var + for(uint64_t i = lid; i < inner_size; i += LOCAL_SIZE) + { + size_t x_idx = gid * inner_size + i; + + FLOAT_ACCUM tmp = CVT_FLOAT2ACCUM(x[x_idx]); + pmean += tmp; + pvar += tmp * tmp; + } + + ltmp1[lid] = pmean; + ltmp2[lid] = pvar; + __syncthreads(); + for(uint32_t i = LOCAL_SIZE >> 1; i > 0; i >>= 1) + { + if(lid < i) + { + ltmp1[lid] += ltmp1[lid + i]; + ltmp2[lid] += ltmp2[lid + i]; + } + __syncthreads(); + } + pmean = ltmp1[0] / inner_size; + pvar = ltmp2[0] / inner_size - pmean * pmean; + FLOAT_ACCUM prstd = rsqrt(pvar + FLOAT_ACCUM(eps)); + + if(lid == 0) + { + if(mean) + mean[gid] = pmean; + if(rstd) + rstd[gid] = prstd; + } + + // forward calculation + for(uint64_t i = lid; i < inner_size; i += LOCAL_SIZE) + { + size_t idx = gid * inner_size + i; + + FLOAT_ACCUM pweight; + FLOAT_ACCUM pbias; + + size_t c = mode ? (idx / numel_per_channel) % num_channels : 0; + pweight = mode ? CVT_FLOAT2ACCUM(weight[c]) : CVT_FP32_2ACCUM(1.0f); + pbias = mode ? CVT_FLOAT2ACCUM(bias[c]) : static_cast(0); + + FLOAT_ACCUM val = (CVT_FLOAT2ACCUM(x[idx]) - pmean) * prstd * pweight + pbias; + y[idx] = CVT_ACCUM2FLOAT(val); + } +} diff --git a/src/kernels/float_types.h b/src/kernels/float_types.h index beded11d8d..dc29a66a41 100644 --- a/src/kernels/float_types.h +++ b/src/kernels/float_types.h @@ -203,11 +203,11 @@ #if MIOPEN_USE_BFP16 == 1 #ifdef __HIP_PLATFORM_AMD__ -#define CVT_FLOAT2ACCUM(x) MIOPEN_ERROR_NOT_IMLEMENTED -#define CVT_ACCUM2FLOAT(x) MIOPEN_ERROR_NOT_IMLEMENTED -#define CVT_INTEGRAL2ACCUM(x) MIOPEN_ERROR_NOT_IMLEMENTED -#define CVT_FP32_2FLOAT(x) MIOPEN_ERROR_NOT_IMLEMENTED -#define CVT_FP32_2ACCUM(x) MIOPEN_ERROR_NOT_IMLEMENTED +#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) +#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) +#define CVT_INTEGRAL2ACCUM(x) (static_cast(x)) +#define CVT_FP32_2FLOAT(x) (CVT_ACCUM2FLOAT(x)) +#define CVT_FP32_2ACCUM(x) (x) #else #define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) #define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) diff --git a/src/layer_norm.cpp b/src/layer_norm.cpp index 40787235f7..7d2789973f 100644 --- a/src/layer_norm.cpp +++ b/src/layer_norm.cpp @@ -28,8 +28,8 @@ #include #include #include -#include -#include +#include +#include #include namespace miopen { @@ -47,15 +47,15 @@ miopenStatus_t LayerNormForward(Handle& handle, Data_t mean, const TensorDescriptor& rstdDesc, Data_t rstd, - miopenLayerNormMode_t mode, + miopenNormMode_t mode, float epsilon, int32_t normalized_dim) { - const auto problem = norm::ProblemDescription{ + const auto problem = layernorm::ProblemDescription{ mode, xDesc, weightDesc, biasDesc, yDesc, meanDesc, rstdDesc, epsilon, normalized_dim}; const auto invoke_params = [&]() { - auto tmp = norm::InvokeParams{}; + auto tmp = layernorm::InvokeParams{}; tmp.type = InvokeType::Run; tmp.xDesc = &xDesc; tmp.x = x; @@ -71,9 +71,9 @@ miopenStatus_t LayerNormForward(Handle& handle, }(); const auto algo = AlgorithmName{"LayerNormForward"}; - const auto solvers = solver::SolverContainer{}; + const auto solvers = solver::SolverContainer{}; solvers.ExecutePrimitive(handle, problem, algo, invoke_params); diff --git a/src/norm/problem_description.cpp b/src/layernorm/problem_description.cpp similarity index 95% rename from src/norm/problem_description.cpp rename to src/layernorm/problem_description.cpp index 3e99557187..0d56e98a8b 100644 --- a/src/norm/problem_description.cpp +++ b/src/layernorm/problem_description.cpp @@ -24,14 +24,14 @@ * *******************************************************************************/ -#include +#include #include #include namespace miopen { -namespace norm { +namespace layernorm { NetworkConfig ProblemDescription::MakeNetworkConfig() const { @@ -59,6 +59,6 @@ NetworkConfig ProblemDescription::MakeNetworkConfig() const return NetworkConfig{ss.str()}; } -} // namespace norm +} // namespace layernorm } // namespace miopen diff --git a/src/layernorm_api.cpp b/src/layernorm_api.cpp index 12cb18805e..a364620792 100644 --- a/src/layernorm_api.cpp +++ b/src/layernorm_api.cpp @@ -31,7 +31,7 @@ #include static void -LogCmdLayerNorm(const miopenTensorDescriptor_t xDesc, const miopenLayerNormMode_t mode, bool is_fwd) +LogCmdLayerNorm(const miopenTensorDescriptor_t xDesc, const miopenNormMode_t mode, bool is_fwd) { if(miopen::IsLoggingCmd()) { @@ -77,7 +77,7 @@ LogCmdLayerNorm(const miopenTensorDescriptor_t xDesc, const miopenLayerNormMode_ } extern "C" miopenStatus_t miopenLayerNormForward(miopenHandle_t handle, - miopenLayerNormMode_t mode, + miopenNormMode_t mode, const miopenTensorDescriptor_t xDesc, const void* x, const miopenTensorDescriptor_t weightDesc, diff --git a/src/solver.cpp b/src/solver.cpp index 031d63be7d..e4800fbd2d 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -28,9 +28,10 @@ #include #include -#include #include -#include +#include +#include +#include #include #include @@ -607,9 +608,11 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId()); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId()); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId()); - Register(registry, ++id, Primitive::Normalization, norm::Layernorm2DCKForward{}.SolverDbId()); - Register(registry, ++id, Primitive::Normalization, norm::Layernorm4DCKForward{}.SolverDbId()); - Register(registry, ++id, Primitive::Normalization, norm::LayernormForward{}.SolverDbId()); + Register( + registry, ++id, Primitive::Normalization, layernorm::Layernorm2DCKForward{}.SolverDbId()); + Register( + registry, ++id, Primitive::Normalization, layernorm::Layernorm4DCKForward{}.SolverDbId()); + Register(registry, ++id, Primitive::Normalization, layernorm::LayernormForward{}.SolverDbId()); Register(registry, ++id, Primitive::Reduce, reduce::SumForward{}.SolverDbId()); RegisterWithSolver(registry, ++id, @@ -629,6 +632,7 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) fusion::ConvCKIgemmFwdBiasResAddActivFused{}.SolverDbId(), miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Reduce, reduce::ArgmaxForward{}.SolverDbId()); + Register(registry, ++id, Primitive::Normalization, groupnorm::GroupNormForward{}.SolverDbId()); RegisterWithSolver(registry, ++id, diff --git a/src/solver/groupnorm/forward_groupnorm.cpp b/src/solver/groupnorm/forward_groupnorm.cpp new file mode 100644 index 0000000000..e4018d16ab --- /dev/null +++ b/src/solver/groupnorm/forward_groupnorm.cpp @@ -0,0 +1,148 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 1024 + +namespace miopen { + +namespace solver { + +namespace groupnorm { + +std::size_t sizeof_kernel_FLOAT_ACCUM(const miopen::groupnorm::ProblemDescription& problem) +{ + const auto datatype = problem.GetMeanDesc().GetType(); + return get_data_size(datatype); +} + +std::size_t sizeof_local_memory(const miopen::groupnorm::ProblemDescription& problem) +{ + return LOCAL_SIZE * sizeof_kernel_FLOAT_ACCUM(problem) * 2; +} + +bool GroupNormForward::IsApplicable(const ExecutionContext&, + const miopen::groupnorm::ProblemDescription& problem) const +{ + if(!problem.IsValidType()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!(sizeof_local_memory(problem) <= TargetProperties::GetMaxLocalMemorySize())) + return false; + if(problem.GetXDesc().GetLengths()[0] * problem.GetNumGroups() < 32 || + problem.GetXDesc().GetLengths()[1] / problem.GetNumGroups() >= 64) + return false; + return true; +} + +ConvSolution +GroupNormForward::GetSolution(const ExecutionContext& context, + const miopen::groupnorm::ProblemDescription& problem) const +{ + std::ignore = context; + + auto result = ConvSolution{miopenStatusSuccess}; + + { + auto dtype = problem.GetXDesc().GetType(); + auto dims = problem.GetXDesc().GetLengths(); + + size_t num_groups = problem.GetNumGroups(); + size_t outer_size = dims[0] * num_groups; + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = outer_size * xlocalsize; + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenGroupNorm.cpp"; + kernel.kernel_name = "GroupNormFwdContiguous"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"LOCAL_SIZE", LOCAL_SIZE}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto dims = params.xDesc->GetLengths(); + size_t numel = params.xDesc->GetElementSize(); + size_t numel_per_channel = numel / dims[0] / dims[1]; + size_t num_channels = dims[1]; + + kernel(params.x, + params.y, + params.weight, + params.bias, + params.mean, + params.rstd, + params.epsilon, + params.num_groups, + num_channels, + numel_per_channel, + static_cast(params.mode)); + }; + }; + + return result; +} + +} // namespace groupnorm + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/norm/forward_layernorm.cpp b/src/solver/layernorm/forward_layernorm.cpp similarity index 86% rename from src/solver/norm/forward_layernorm.cpp rename to src/solver/layernorm/forward_layernorm.cpp index c3571309b5..0b625eb2fc 100644 --- a/src/solver/norm/forward_layernorm.cpp +++ b/src/solver/layernorm/forward_layernorm.cpp @@ -24,11 +24,12 @@ * *******************************************************************************/ +#include + +#include #include -#include #include -#include -#include +#include #include #define LOCAL_SIZE 256 @@ -37,15 +38,15 @@ namespace miopen { namespace solver { -namespace norm { +namespace layernorm { -std::size_t sizeof_kernel_FLOAT(const miopen::norm::ProblemDescription& problem) +std::size_t sizeof_kernel_FLOAT(const miopen::layernorm::ProblemDescription& problem) { const auto datatype = problem.GetXDesc().GetType(); return get_data_size(datatype); } -std::size_t sizeof_local_memory(const miopen::norm::ProblemDescription& problem) +std::size_t sizeof_local_memory(const miopen::layernorm::ProblemDescription& problem) { std::size_t rv = 0; rv += LOCAL_SIZE * sizeof_kernel_FLOAT(problem) * 2; @@ -53,7 +54,7 @@ std::size_t sizeof_local_memory(const miopen::norm::ProblemDescription& problem) } bool LayernormForward::IsApplicable(const ExecutionContext&, - const miopen::norm::ProblemDescription& problem) const + const miopen::layernorm::ProblemDescription& problem) const { if(!problem.IsSameType()) return false; @@ -68,9 +69,12 @@ bool LayernormForward::IsApplicable(const ExecutionContext&, return true; } -ConvSolution LayernormForward::GetSolution(const ExecutionContext&, - const miopen::norm::ProblemDescription& problem) const +ConvSolution +LayernormForward::GetSolution(const ExecutionContext& context, + const miopen::layernorm::ProblemDescription& problem) const { + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; { @@ -119,7 +123,7 @@ ConvSolution LayernormForward::GetSolution(const ExecutionContext&, result.invoker_factory = [](const std::vector& kernels) { return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) kernel = handle_.Run(kernels.front()); - decltype(auto) params = raw_params.CastTo(); + decltype(auto) params = raw_params.CastTo(); auto dims = params.xDesc->GetLengths(); size_t inner_size = 1; @@ -144,7 +148,7 @@ ConvSolution LayernormForward::GetSolution(const ExecutionContext&, return result; } -} // namespace norm +} // namespace layernorm } // namespace solver diff --git a/src/solver/norm/forward_layernorm2d_ck.cpp b/src/solver/layernorm/forward_layernorm2d_ck.cpp similarity index 91% rename from src/solver/norm/forward_layernorm2d_ck.cpp rename to src/solver/layernorm/forward_layernorm2d_ck.cpp index 3fb40dc024..5eb909cb76 100644 --- a/src/solver/norm/forward_layernorm2d_ck.cpp +++ b/src/solver/layernorm/forward_layernorm2d_ck.cpp @@ -25,8 +25,8 @@ *******************************************************************************/ #include -#include -#include +#include +#include #if MIOPEN_USE_COMPOSABLEKERNEL #include #include @@ -35,7 +35,7 @@ MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_LAYERNORM2DCKFORWARD_CONV_CK_LN) namespace miopen { namespace solver { -namespace norm { +namespace layernorm { #if MIOPEN_USE_COMPOSABLEKERNEL using F16 = ck::half_t; @@ -68,7 +68,7 @@ using DeviceOpLnFwdPtrs = ck::tensor_operation::device::instance::DeviceOperatio namespace { struct CKArgs { - CKArgs(const miopen::norm::ProblemDescription& problem) + CKArgs(const miopen::layernorm::ProblemDescription& problem) { auto length = problem.GetXDesc().GetLengths(); @@ -140,7 +140,7 @@ struct CKArgs } // namespace template -bool CheckCKApplicability(const miopen::norm::ProblemDescription& problem) +bool CheckCKApplicability(const miopen::layernorm::ProblemDescription& problem) { const auto ln_args = CKArgs{problem}; const auto ln_ptrs = DeviceOpType::GetInstances(); @@ -152,7 +152,7 @@ bool CheckCKApplicability(const miopen::norm::ProblemDescription& problem) template typename LnPtrsType::iterator FindLnPtr(LnPtrsType& ln_ptrs, - const miopen::norm::ProblemDescription& problem) + const miopen::layernorm::ProblemDescription& problem) { const auto ln_args = CKArgs{problem}; return std::find_if(ln_ptrs.begin(), ln_ptrs.end(), [&ln_args](auto& ln_ptrs) { @@ -162,7 +162,7 @@ typename LnPtrsType::iterator FindLnPtr(LnPtrsType& ln_ptrs, template ConvSolution MakeInvokerFactory([[maybe_unused]] const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) + const miopen::layernorm::ProblemDescription& problem) { auto ln_ptr = DeviceOpType::GetInstances(); auto ln_ptr_iter = FindLnPtr(ln_ptr, problem); @@ -203,14 +203,14 @@ ConvSolution MakeInvokerFactory([[maybe_unused]] const ExecutionContext& context } #endif -bool IsRank2Dim1(const miopen::norm::ProblemDescription& problem) +bool IsRank2Dim1(const miopen::layernorm::ProblemDescription& problem) { return (problem.GetXDesc().GetLengths().size() == 2) && (problem.GetNormalizedDim() == 1); } bool Layernorm2DCKForward::IsApplicable( [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const + [[maybe_unused]] const miopen::layernorm::ProblemDescription& problem) const { #if MIOPEN_USE_COMPOSABLEKERNEL if(miopen::IsDisabled(ENV(MIOPEN_DEBUG_LAYERNORM2DCKFORWARD_CONV_CK_LN))) @@ -247,7 +247,7 @@ bool Layernorm2DCKForward::IsApplicable( ConvSolution Layernorm2DCKForward::GetSolution( [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const + [[maybe_unused]] const miopen::layernorm::ProblemDescription& problem) const { #if MIOPEN_USE_COMPOSABLEKERNEL switch(problem.GetXDesc().GetType()) @@ -255,11 +255,11 @@ ConvSolution Layernorm2DCKForward::GetSolution( case miopenHalf: return MakeInvokerFactory, CKArgs, - miopen::norm::InvokeParams>(context, problem); + miopen::layernorm::InvokeParams>(context, problem); case miopenFloat: return MakeInvokerFactory, CKArgs, - miopen::norm::InvokeParams>(context, problem); + miopen::layernorm::InvokeParams>(context, problem); case miopenDouble: case miopenBFloat16: case miopenInt8: @@ -274,6 +274,6 @@ ConvSolution Layernorm2DCKForward::GetSolution( return {}; } -} // namespace norm +} // namespace layernorm } // namespace solver } // namespace miopen diff --git a/src/solver/norm/forward_layernorm4d_ck.cpp b/src/solver/layernorm/forward_layernorm4d_ck.cpp similarity index 91% rename from src/solver/norm/forward_layernorm4d_ck.cpp rename to src/solver/layernorm/forward_layernorm4d_ck.cpp index a862b92928..edd3aeea4e 100644 --- a/src/solver/norm/forward_layernorm4d_ck.cpp +++ b/src/solver/layernorm/forward_layernorm4d_ck.cpp @@ -25,8 +25,8 @@ *******************************************************************************/ #include -#include -#include +#include +#include #if MIOPEN_USE_COMPOSABLEKERNEL #include #include @@ -35,7 +35,7 @@ MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_LAYERNORM4DCKFORWARD_CONV_CK_LN) namespace miopen { namespace solver { -namespace norm { +namespace layernorm { #if MIOPEN_USE_COMPOSABLEKERNEL using F16 = ck::half_t; @@ -68,7 +68,7 @@ using DeviceOpLnFwdPtrs = ck::tensor_operation::device::instance::DeviceOperatio namespace { struct CKArgs { - CKArgs(const miopen::norm::ProblemDescription& problem) + CKArgs(const miopen::layernorm::ProblemDescription& problem) { auto length = problem.GetXDesc().GetLengths(); @@ -148,7 +148,7 @@ struct CKArgs } // namespace template -bool CheckCKApplicability(const miopen::norm::ProblemDescription& problem) +bool CheckCKApplicability(const miopen::layernorm::ProblemDescription& problem) { const auto ln_args = CKArgs{problem}; const auto ln_ptrs = DeviceOpType::GetInstances(); @@ -160,7 +160,7 @@ bool CheckCKApplicability(const miopen::norm::ProblemDescription& problem) template typename LnPtrsType::iterator FindLnPtr(LnPtrsType& ln_ptrs, - const miopen::norm::ProblemDescription& problem) + const miopen::layernorm::ProblemDescription& problem) { const auto ln_args = CKArgs{problem}; return std::find_if(ln_ptrs.begin(), ln_ptrs.end(), [&ln_args](auto& ln_ptrs) { @@ -170,7 +170,7 @@ typename LnPtrsType::iterator FindLnPtr(LnPtrsType& ln_ptrs, template ConvSolution MakeInvokerFactory([[maybe_unused]] const ExecutionContext& context, - const miopen::norm::ProblemDescription& problem) + const miopen::layernorm::ProblemDescription& problem) { auto ln_ptr = DeviceOpType::GetInstances(); auto ln_ptr_iter = FindLnPtr(ln_ptr, problem); @@ -211,14 +211,14 @@ ConvSolution MakeInvokerFactory([[maybe_unused]] const ExecutionContext& context } #endif -bool IsRank4Dim1(const miopen::norm::ProblemDescription& problem) +bool IsRank4Dim1(const miopen::layernorm::ProblemDescription& problem) { return (problem.GetXDesc().GetLengths().size() == 4) && (problem.GetNormalizedDim() == 1); } bool Layernorm4DCKForward::IsApplicable( [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const + [[maybe_unused]] const miopen::layernorm::ProblemDescription& problem) const { #if MIOPEN_USE_COMPOSABLEKERNEL if(miopen::IsDisabled(ENV(MIOPEN_DEBUG_LAYERNORM4DCKFORWARD_CONV_CK_LN))) @@ -255,7 +255,7 @@ bool Layernorm4DCKForward::IsApplicable( ConvSolution Layernorm4DCKForward::GetSolution( [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const + [[maybe_unused]] const miopen::layernorm::ProblemDescription& problem) const { #if MIOPEN_USE_COMPOSABLEKERNEL switch(problem.GetXDesc().GetType()) @@ -263,11 +263,11 @@ ConvSolution Layernorm4DCKForward::GetSolution( case miopenHalf: return MakeInvokerFactory, CKArgs, - miopen::norm::InvokeParams>(context, problem); + miopen::layernorm::InvokeParams>(context, problem); case miopenFloat: return MakeInvokerFactory, CKArgs, - miopen::norm::InvokeParams>(context, problem); + miopen::layernorm::InvokeParams>(context, problem); case miopenDouble: case miopenBFloat16: case miopenInt8: @@ -282,6 +282,6 @@ ConvSolution Layernorm4DCKForward::GetSolution( return {}; } -} // namespace norm +} // namespace layernorm } // namespace solver } // namespace miopen diff --git a/test/cpu_groupnorm.hpp b/test/cpu_groupnorm.hpp new file mode 100644 index 0000000000..6ac92ed4ed --- /dev/null +++ b/test/cpu_groupnorm.hpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_CPU_GROUPNORM_HPP +#define GUARD_CPU_GROUPNORM_HPP + +#include "tensor_holder.hpp" + +template +void cpu_groupnorm_forward(tensor input, + tensor weight, + tensor bias, + tensor& ref_output, + tensor& ref_mean, + tensor& ref_rstd, + uint64_t num_groups, + float eps, + miopenNormMode_t mode) +{ + auto dims = input.desc.GetLengths(); + + size_t numel = input.desc.GetElementSize(); + size_t numel_per_channel = numel / dims[0] / dims[1]; + size_t num_channels = dims[1]; + + size_t outer_size = dims[0] * num_groups; + size_t inner_size = numel / outer_size; + + par_ford(outer_size)([&](int32_t o) { + T mean_v = 0.0f; + T var_v = 0.0f; + + ford(inner_size)([&](int32_t i) { + T tmp = input[o * inner_size + i]; + mean_v += tmp; + var_v += tmp * tmp; + }); + + mean_v = mean_v / inner_size; + var_v = var_v / inner_size - mean_v * mean_v; + T rstd_v = 1.0f / sqrt(var_v + eps); + + ref_mean[o] = mean_v; + ref_rstd[o] = rstd_v; + + ford(inner_size)([&](int32_t i) { + size_t idx = o * inner_size + i; + size_t c = (idx / numel_per_channel) % num_channels; + T weight_v = mode ? weight[c] : 1; + T bias_v = mode ? bias[c] : 0; + + ref_output[idx] = (input[idx] - mean_v) * rstd_v * weight_v + bias_v; + }); + }); +} +#endif diff --git a/test/cpu_layernorm.hpp b/test/cpu_layernorm.hpp index 9f89249a1b..5190fc6d9a 100644 --- a/test/cpu_layernorm.hpp +++ b/test/cpu_layernorm.hpp @@ -37,7 +37,7 @@ void cpu_layernorm_forward(tensor input, tensor& ref_rstd, float eps, int32_t dim, - miopenLayerNormMode_t mode) + miopenNormMode_t mode) { auto dims = input.desc.GetLengths(); size_t outer_size = 1; diff --git a/test/gtest/groupnorm.cpp b/test/gtest/groupnorm.cpp new file mode 100644 index 0000000000..58d0685c93 --- /dev/null +++ b/test/gtest/groupnorm.cpp @@ -0,0 +1,66 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include "groupnorm.hpp" + +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) + +std::string GetFloatArg() +{ + const auto& tmp = miopen::GetStringEnv(ENV(MIOPEN_TEST_FLOAT_ARG)); + if(tmp.empty()) + { + return ""; + } + return tmp; +} + +struct GroupNormTestFloat : GroupNormTest +{ +}; + +TEST_P(GroupNormTestFloat, GroupNormTestFw) +{ + const auto& handle = get_handle(); + + if((miopen::StartsWith(handle.GetDeviceName(), "gfx908") || + miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || + miopen::StartsWith(handle.GetDeviceName(), "gfx94")) && + miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +INSTANTIATE_TEST_SUITE_P(GroupNormTestSet, + GroupNormTestFloat, + testing::ValuesIn(GroupNormTestConfigs())); diff --git a/test/gtest/groupnorm.hpp b/test/gtest/groupnorm.hpp new file mode 100644 index 0000000000..910c1b0118 --- /dev/null +++ b/test/gtest/groupnorm.hpp @@ -0,0 +1,294 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include + +#include "tensor_holder.hpp" +#include "cpu_groupnorm.hpp" +#include "get_handle.hpp" +#include "random.hpp" +#include "../driver/tensor_driver.hpp" +#include "verify.hpp" +#include + +struct GroupNormTestCase +{ + size_t N; + size_t C; + size_t D; + size_t H; + size_t W; + size_t num_groups; + float eps; + miopenNormMode_t mode; + friend std::ostream& operator<<(std::ostream& os, const GroupNormTestCase& tc) + { + return os << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D << " H:" << tc.H + << " W:" << tc.W << " num_groups:" << tc.num_groups << " eps:" << tc.eps + << " mode:" << tc.mode; + } + + std::vector GetInput() + { + if((N != 0) && (C != 0) && (D != 0) && (H != 0) && (W != 0)) + { + return std::vector({N, C, D, H, W}); + } + else if((N != 0) && (C != 0) && (H != 0) && (W != 0)) + { + return std::vector({N, C, H, W}); + } + else if((N != 0) && (C != 0) && (W != 0)) + { + return std::vector({N, C, W}); + } + else + { + std::cout << "Error Input Tensor Lengths\n" << std::endl; + return std::vector({0}); + } + } +}; + +std::vector GroupNormTestConfigs() +{ // n c d h w num_groups eps mode + + return {{32, 1, 32, 32, 32, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 1, 14, 14, 14, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 14, 14, 14, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 12, 12, 12, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 6, 6, 6, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {256, 1, 32, 32, 32, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {256, 32, 14, 14, 14, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {256, 32, 12, 12, 12, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {256, 32, 6, 6, 6, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {512, 1, 32, 32, 32, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {512, 32, 14, 14, 14, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {512, 32, 12, 12, 12, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {512, 32, 6, 6, 6, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 2, 32, 57, 125, 2, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 14, 25, 59, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 6, 10, 27, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 4, 6, 11, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 2, 2, 3, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 32, 28, 62, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 14, 12, 29, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 6, 4, 12, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 32, 4, 2, 2, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {16, 32, 6, 50, 50, 4, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + // {1, 3, 8, 240, 320, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + // {1, 3, 16, 240, 320, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + // {1, 3, 8, 128, 171, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + // {1, 3, 16, 128, 171, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + // {1, 3, 8, 112, 112, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + // {1, 3, 16, 112, 112, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 1, 32, 32, 32, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 1, 14, 14, 14, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 14, 14, 14, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 12, 12, 12, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 6, 6, 6, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {256, 1, 32, 32, 32, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {256, 32, 14, 14, 14, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {256, 32, 12, 12, 12, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {256, 32, 6, 6, 6, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {512, 1, 32, 32, 32, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {512, 32, 14, 14, 14, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {512, 32, 12, 12, 12, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {512, 32, 6, 6, 6, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 2, 32, 57, 125, 2, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 14, 25, 59, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 6, 10, 27, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 4, 6, 11, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 2, 2, 3, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 32, 28, 62, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 14, 12, 29, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 6, 4, 12, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 32, 4, 2, 2, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + {16, 32, 6, 50, 50, 4, 1e-5, MIOPEN_WEIGHT_BIAS}, + // {1, 3, 8, 240, 320, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + // {1, 3, 16, 240, 320, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + // {1, 3, 8, 128, 171, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + // {1, 3, 16, 128, 171, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + // {1, 3, 8, 112, 112, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + // {1, 3, 16, 112, 112, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 4, 0, 4, 256, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {64, 4, 0, 4, 256, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 4, 0, 4, 256, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {64, 4, 0, 4, 256, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 1, 0, 0, 256, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {64, 1, 0, 0, 256, 1, 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 1, 0, 0, 256, 1, 1e-5, MIOPEN_WEIGHT_BIAS}, + {64, 1, 0, 0, 256, 1, 1e-5, MIOPEN_WEIGHT_BIAS}}; +} + +inline int32_t SetTensorLayout(miopen::TensorDescriptor& desc) +{ + const std::vector lens = desc.GetLengths(); + std::vector int32_t_lens(lens.begin(), lens.end()); + + // set the strides for the tensor + return SetTensorNd(&desc, int32_t_lens, desc.GetType()); +} + +template +struct GroupNormTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + groupnorm_config = GetParam(); + auto gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(1e-2, 100); }; + + num_groups = groupnorm_config.num_groups; + eps = groupnorm_config.eps; + mode = groupnorm_config.mode; + + std::vector inout_dim = groupnorm_config.GetInput(); + std::vector weight_bias_dim = {inout_dim[1]}; + std::vector mean_rstd_dim = {inout_dim[0], num_groups}; + + input = tensor{inout_dim}.generate(gen_value); + + if(mode == MIOPEN_ELEMENTWISE_AFFINE) + { + auto gen_one = [&](auto...) { return 1; }; + auto gen_zero = [&](auto...) { return 0; }; + weight = tensor{weight_bias_dim}.generate(gen_one); + bias = tensor{weight_bias_dim}.generate(gen_zero); + } + else + { + weight = tensor{weight_bias_dim}.generate(gen_value); + bias = tensor{weight_bias_dim}.generate(gen_value); + } + output = tensor{inout_dim}; + mean = tensor{mean_rstd_dim}; + rstd = tensor{mean_rstd_dim}; + + SetTensorLayout(weight.desc); + SetTensorLayout(bias.desc); + SetTensorLayout(input.desc); + SetTensorLayout(output.desc); + SetTensorLayout(mean.desc); + SetTensorLayout(rstd.desc); + + std::fill(output.begin(), output.end(), std::numeric_limits::quiet_NaN()); + std::fill(mean.begin(), mean.end(), std::numeric_limits::quiet_NaN()); + std::fill(rstd.begin(), rstd.end(), std::numeric_limits::quiet_NaN()); + + ref_output = tensor{inout_dim}; + ref_mean = tensor{mean_rstd_dim}; + ref_rstd = tensor{mean_rstd_dim}; + + std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits::quiet_NaN()); + std::fill(ref_mean.begin(), ref_mean.end(), std::numeric_limits::quiet_NaN()); + std::fill(ref_rstd.begin(), ref_rstd.end(), std::numeric_limits::quiet_NaN()); + + input_dev = handle.Write(input.data); + weight_dev = handle.Write(weight.data); + bias_dev = handle.Write(bias.data); + output_dev = handle.Write(output.data); + mean_dev = handle.Write(mean.data); + rstd_dev = handle.Write(rstd.data); + } + + void RunTest() + { + auto&& handle = get_handle(); + + cpu_groupnorm_forward( + input, weight, bias, ref_output, ref_mean, ref_rstd, num_groups, eps, mode); + miopenStatus_t status; + + status = miopen::GroupNormForward(handle, + input.desc, + input_dev.get(), + weight.desc, + weight_dev.get(), + bias.desc, + bias_dev.get(), + output.desc, + output_dev.get(), + mean.desc, + mean_dev.get(), + rstd.desc, + rstd_dev.get(), + mode, + num_groups, + eps); + + EXPECT_EQ(status, miopenStatusSuccess); + + output.data = handle.Read(output_dev, output.data.size()); + mean.data = handle.Read(mean_dev, mean.data.size()); + rstd.data = handle.Read(rstd_dev, rstd.data.size()); + } + + void Verify() + { + double threshold = std::numeric_limits::epsilon(); + auto error = miopen::rms_range(ref_output, output); + + EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); + EXPECT_TRUE(error < threshold * 1000) << "Error output beyond tolerance Error:" << error + << ", Thresholdx1000: " << threshold * 1000; + + error = miopen::rms_range(ref_mean, mean); + EXPECT_TRUE(miopen::range_distance(ref_mean) == miopen::range_distance(mean)); + EXPECT_TRUE(error < threshold * 50) << "Error mean beyond tolerance Error:" << error + << ", Thresholdx50: " << threshold * 50; + + error = miopen::rms_range(ref_rstd, rstd); + EXPECT_TRUE(miopen::range_distance(ref_rstd) == miopen::range_distance(rstd)); + EXPECT_TRUE(error < threshold * 2000) << "Error rstd beyond tolerance Error:" << error + << ", Thresholdx2000: " << threshold * 2000; + } + GroupNormTestCase groupnorm_config; + + tensor input; + tensor weight; + tensor bias; + tensor output; + tensor mean; + tensor rstd; + + tensor ref_output; + tensor ref_mean; + tensor ref_rstd; + + miopen::Allocator::ManageDataPtr input_dev; + miopen::Allocator::ManageDataPtr weight_dev; + miopen::Allocator::ManageDataPtr bias_dev; + miopen::Allocator::ManageDataPtr output_dev; + miopen::Allocator::ManageDataPtr mean_dev; + miopen::Allocator::ManageDataPtr rstd_dev; + + size_t num_groups; + float eps; + miopenNormMode_t mode; +}; diff --git a/test/gtest/layernorm.hpp b/test/gtest/layernorm.hpp index af6e396bb4..bd3f1cd85e 100644 --- a/test/gtest/layernorm.hpp +++ b/test/gtest/layernorm.hpp @@ -43,7 +43,7 @@ struct LayerNormTestCase size_t W; size_t nomalized_dim; float eps; - miopenLayerNormMode_t ln_mode; + miopenNormMode_t ln_mode; friend std::ostream& operator<<(std::ostream& os, const LayerNormTestCase& tc) { return os << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D << " H:" << tc.H @@ -306,5 +306,5 @@ struct LayerNormTest : public ::testing::TestWithParam size_t nomalized_dim; float eps; - miopenLayerNormMode_t ln_mode; + miopenNormMode_t ln_mode; };