diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 71289c8b42..abc0679a8a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -152,6 +152,7 @@ set( MIOpen_Source solver/activ/bwd_1.cpp solver/activ/fwd_0.cpp solver/activ/fwd_1.cpp + solver/batchnorm/backward_ck.cpp solver/batchnorm/backward_per_activation.cpp solver/batchnorm/backward_per_activation_fused.cpp solver/batchnorm/backward_spatial_multiple.cpp @@ -163,6 +164,7 @@ set( MIOpen_Source solver/batchnorm/forward_per_activation_fused.cpp solver/batchnorm/forward_spatial_multiple.cpp solver/batchnorm/forward_spatial_single.cpp + solver/batchnorm/forward_training_ck.cpp solver/conv_asm_1x1u.cpp solver/conv_asm_1x1u_bias_activ_fused.cpp solver/conv_asm_1x1u_stride2.cpp diff --git a/src/batch_norm_api.cpp b/src/batch_norm_api.cpp index 03db138945..69454b185a 100644 --- a/src/batch_norm_api.cpp +++ b/src/batch_norm_api.cpp @@ -243,13 +243,6 @@ miopenBatchNormalizationBackward(miopenHandle_t handle, const void* savedMean, const void* savedInvVariance) { - // bfloat16 not supported for batchnorm operation - if(miopen::deref(xDesc).GetType() == miopenBFloat16 || - miopen::deref(dyDesc).GetType() == miopenBFloat16 || - miopen::deref(dxDesc).GetType() == miopenBFloat16) - { - return miopenStatusNotImplemented; - } MIOPEN_LOG_FUNCTION(handle, bn_mode, diff --git a/src/include/miopen/batchnorm/solvers.hpp b/src/include/miopen/batchnorm/solvers.hpp index c7d050abeb..70d64bb204 100644 --- a/src/include/miopen/batchnorm/solvers.hpp +++ b/src/include/miopen/batchnorm/solvers.hpp @@ -142,6 +142,26 @@ struct BnCKFwdInference final : BatchnormSolver const miopen::batchnorm::ProblemDescription& problem) const override; }; +struct BnCKBwdBackward final : BatchnormSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const 
ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; +}; + +struct BnCKFwdTraining final : BatchnormSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; +}; + } // namespace batchnorm } // namespace solver diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp index 8656bdbabc..318d970170 100644 --- a/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -41,8 +41,10 @@ typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs, }); } -template -std::vector FillValidKernelsIDs(const ProblemDescription& problem) +template +std::vector FillValidKernelsIDs(const ProblemDescriptionType& problem) { const auto args = CKArgsType{problem}; const auto conv_ptrs = DeviceOpType::GetInstances(); @@ -59,8 +61,10 @@ std::vector FillValidKernelsIDs(const ProblemDescription& problem) return valid_kernels; } -template -bool IsCKArgsSupported(const ProblemDescription& problem, const std::string& kernel_id) +template +bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); @@ -68,20 +72,25 @@ bool IsCKArgsSupported(const ProblemDescription& problem, const std::string& ker return (ptr_iter != conv_ptrs.end()) && CKArgsType{problem}.IsSupportedBy(*ptr_iter); } -template -bool IsCKApplicable(const ProblemDescription& problem) +template +bool IsCKApplicable(const ProblemDescriptionType& problem) { const auto args = CKArgsType{problem}; - if(!std::all_of(args.strides.begin(), 
args.strides.end(), [](auto x) { return x == 1; })) - return false; + // if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; })) + // return false; const auto ptrs = DeviceOpType::GetInstances(); return std::any_of( ptrs.begin(), ptrs.end(), [&args](auto& ptr) { return args.IsSupportedBy(ptr); }); } -template -ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::string& kernel_id) +template +ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); @@ -112,5 +121,41 @@ ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::st return result; } +template +ConvSolution InitAnyInvokerFactory(const ProblemDescriptionType& problem, + const std::string& kernel_id) +{ + auto conv_ptrs = DeviceOpType::GetInstances(); + auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); + + if(ptr_iter == conv_ptrs.end()) + return {miopenStatusInvalidValue}; + + ConvSolution result; + result.invoker_factory = + [ck_args = CKArgsType{problem}, + sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}](const std::vector&) mutable { + return [ck_args = std::move(ck_args), sh_conv_ptr = std::move(sh_conv_ptr)]( + const Handle& handle, const AnyInvokeParams& primitive_parameters) { + const auto& data_ctx = primitive_parameters.CastTo(); + auto argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr, data_ctx); + auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer(); + + const auto enable_profiling = handle.IsProfilingEnabled(); + float elapsed_time = + invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling}); + if(enable_profiling) + { + handle.ResetKernelTime(); + handle.AccumKernelTime(elapsed_time); + } + }; + }; + return result; +} + } // namespace solver } // namespace miopen diff --git a/src/ocl/batchnormocl.cpp b/src/ocl/batchnormocl.cpp index 
6c8a079a2a..6147a827b8 100644 --- a/src/ocl/batchnormocl.cpp +++ b/src/ocl/batchnormocl.cpp @@ -131,7 +131,8 @@ void BatchNormForwardTraining(Handle& handle, return tmp; }(); - const auto solvers = solver::SolverContainer{}; @@ -300,7 +301,7 @@ void BatchNormBackward(Handle& handle, { MIOPEN_THROW(miopenStatusBadParm); } - if(dxDesc.GetType() != dyDesc.GetType() || dyDesc.GetType() != xDesc.GetType()) + if(dxDesc.GetType() != dyDesc.GetType()) { MIOPEN_THROW(miopenStatusBadParm); } @@ -338,7 +339,6 @@ void BatchNormBackward(Handle& handle, tmp.dx = dx; tmp.bnScale = bnScale; tmp.resultBnScaleDiff = resultBnScaleDiff; - tmp.resultBnScaleDiff = resultBnScaleDiff; tmp.resultBnBiasDiff = resultBnBiasDiff; tmp.epsilon = epsilon; tmp.savedMean = savedMean; @@ -346,7 +346,8 @@ void BatchNormBackward(Handle& handle, return tmp; }(); - const auto solvers = solver::SolverContainer{}; diff --git a/src/solver.cpp b/src/solver.cpp index d83935e646..4cd680dd9c 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -569,6 +569,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) RegisterWithSolver( registry, ++id, ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId()); + Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/batchnorm/backward_ck.cpp b/src/solver/batchnorm/backward_ck.cpp new file mode 100644 index 0000000000..fba8724990 --- /dev/null +++ b/src/solver/batchnorm/backward_ck.cpp @@ -0,0 +1,251 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +#include +#include +#include +#endif +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_CK_BN_BACK) + +namespace miopen { +namespace solver { +namespace batchnorm { +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using index_t = int32_t; + +constexpr index_t Rank = 4; +constexpr index_t NumBatchNormReduceDim = 3; + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; +using BF16 = ushort; + +template +using DeviceOpBNBwdPtrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchNormBwd>; + +struct CKArgsBNormBwd +{ + CKArgsBNormBwd(const miopen::batchnorm::ProblemDescription& problem) + { + std::copy(problem.GetXDesc().GetLengths().begin(), + problem.GetXDesc().GetLengths().end(), + lens.begin()); + + std::copy(problem.GetXDesc().GetStrides().begin(), + problem.GetXDesc().GetStrides().end(), + strides.begin()); + arrScaleBiasMeanVarLengths[0] = lens[1]; // get channel + arrScaleBiasMeanVarStrides[0] = 1; + + // prep for CK + std::sort(strides.begin(), strides.end(), std::greater<>()); + std::rotate(lens.begin() + 1, lens.begin() + 2, lens.end()); + } + + CKArgsBNormBwd(const CKArgsBNormBwd&) = default; + CKArgsBNormBwd(CKArgsBNormBwd&&) = default; + CKArgsBNormBwd& operator=(const CKArgsBNormBwd&) = default; + + template + auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const + { + return invoker_ptr->MakeArgumentPointer(lens, + strides, + strides, + strides, + reduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + data_ctx.x, + data_ctx.dy, + data_ctx.bnScale, + data_ctx.savedMean, + data_ctx.savedInvVariance, + epsilon, + 
                                            PassThrough{}, + data_ctx.dx, + data_ctx.resultBnScaleDiff, + data_ctx.resultBnBiasDiff); + } + + template + bool IsSupportedBy(const ConvPtr& invoker_ptr) const + { + auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::BwdInvokeParams{}); + return invoker_ptr->IsSupportedArgument(arg_ptr.get()); + } + + std::array lens; // inOutLengths + std::array strides; // inOutStrides + std::vector invariantDims; + + std::array arrScaleBiasMeanVarLengths; + std::array arrScaleBiasMeanVarStrides; + + double epsilon = 1e-5; + std::array reduceDims{0, 1, 2}; +}; + +template +static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem) +{ + return IsCKApplicable, + CKArgsBNormBwd>(problem); +} + +#endif + +bool BnCKBwdBackward::IsApplicable(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& bn_problem) const +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = ctx; + std::ignore = bn_problem; + return false; +#else + if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_BACK{})) + return false; + if(!bn_problem.IsLayoutNHWC()) + return false; + if(!ck_utility::is_ck_supported_hardware(ctx.GetStream())) + return false; + if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType()) + return false; + + switch(bn_problem.GetXDesc().GetType()) + { + case miopenFloat: return CheckCKApplicability(bn_problem); + case miopenDouble: return CheckCKApplicability(bn_problem); + case miopenHalf: return CheckCKApplicability(bn_problem); + case miopenBFloat16: + return CheckCKApplicability(bn_problem); + case miopenInt32: + case miopenInt8: + case miopenInt8x4: + case miopenBFloat8: + case miopenFloat8: + default: MIOPEN_THROW("Unsupported datatype"); + } + return false; +#endif +} + +template +ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem) +{ + const auto& valid_kernel_ids = FillValidKernelsIDs, + CKArgsBNormBwd>(bn_problem); + 
assert(!valid_kernel_ids.empty()); + const auto& kernel_id = valid_kernel_ids[0]; + return InitAnyInvokerFactory, + CKArgsBNormBwd, + miopen::batchnorm::BwdInvokeParams>(bn_problem, kernel_id); +} + +ConvSolution BnCKBwdBackward::GetSolution( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(bn_problem.GetXDesc().GetType()) + { + + case miopenFloat: return MakeAnyInvokerFactory(bn_problem); + case miopenDouble: return MakeAnyInvokerFactory(bn_problem); + case miopenHalf: return MakeAnyInvokerFactory(bn_problem); + case miopenBFloat16: + return MakeAnyInvokerFactory(bn_problem); + case miopenInt8: + case miopenInt32: + case miopenInt8x4: + case miopenBFloat8: + case miopenFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, "BnCKBwdBackward operation not for this data type"); + } +#endif + return {}; +} + +} // namespace batchnorm +} // namespace solver +} // namespace miopen diff --git a/src/solver/batchnorm/forward_training_ck.cpp b/src/solver/batchnorm/forward_training_ck.cpp new file mode 100644 index 0000000000..a65cec14a9 --- /dev/null +++ b/src/solver/batchnorm/forward_training_ck.cpp @@ -0,0 +1,239 @@ + +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +#include +#include +#include +#endif +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING) + +namespace miopen { +namespace solver { +namespace batchnorm { +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + +using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; +using index_t = int32_t; + +constexpr index_t Rank = 4; +constexpr index_t NumBatchNormReduceDim = 3; + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; +using BF16 = ushort; + +template +using DeviceOpBNFwdTrainingPtrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchNormFwd>; + +struct CKArgsBNormFwdTraining +{ + CKArgsBNormFwdTraining(const miopen::batchnorm::ProblemDescription& problem) + { + std::copy(problem.GetXDesc().GetLengths().begin(), + problem.GetXDesc().GetLengths().end(), + xyLengths.begin()); + + std::copy(problem.GetXDesc().GetStrides().begin(), + problem.GetXDesc().GetStrides().end(), + xyStrides.begin()); + arrScaleBiasMeanVarLengths[0] = xyLengths[1]; // get channel + arrScaleBiasMeanVarStrides[0] = 1; + + // prep for CK + std::sort(xyStrides.begin(), xyStrides.end(), std::greater<>()); + std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end()); + } + + CKArgsBNormFwdTraining(const CKArgsBNormFwdTraining&) = default; + CKArgsBNormFwdTraining(CKArgsBNormFwdTraining&&) = default; + CKArgsBNormFwdTraining& operator=(const CKArgsBNormFwdTraining&) = default; + + template + auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const + { + return invoker_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + xyStrides, + reduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + data_ctx.x, + 
                                            data_ctx.bnScale, + data_ctx.bnBias, + data_ctx.epsilon, + PassThroughOp{}, + data_ctx.y, + data_ctx.resultSaveMean, + data_ctx.resultSaveInvVariance, + data_ctx.expAvgFactor, + data_ctx.resultRunningMean, + data_ctx.resultRunningVariance); + } + + template + bool IsSupportedBy(const ConvPtr& invoker_ptr) const + { + auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::InvokeParams{}); + return invoker_ptr->IsSupportedArgument(arg_ptr.get()); + } + + std::array xyLengths; + std::array xyStrides; + std::vector invariantDims; + + std::array arrScaleBiasMeanVarLengths; + std::array arrScaleBiasMeanVarStrides; + + std::array reduceDims{0, 1, 2}; +}; + +template +static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem) +{ + return IsCKApplicable, + CKArgsBNormFwdTraining>(problem); +} +#endif + +bool BnCKFwdTraining::IsApplicable(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& bn_problem) const +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = context; + std::ignore = bn_problem; + return false; +#else + if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING{})) + return false; + if(!bn_problem.IsLayoutNHWC()) + return false; + if(!ck_utility::is_ck_supported_hardware(context.GetStream())) + return false; + + switch(bn_problem.GetXDesc().GetType()) + { + case miopenHalf: return CheckCKApplicability(bn_problem); + case miopenFloat: return CheckCKApplicability(bn_problem); + case miopenDouble: return CheckCKApplicability(bn_problem); + case miopenBFloat16: return CheckCKApplicability(bn_problem); + case miopenInt32: + case miopenInt8: + case miopenInt8x4: + case miopenBFloat8: + case miopenFloat8: + default: MIOPEN_THROW("BnCKFwdTraining operation does not support this data type"); + } + return false; +#endif +} + +template +ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem) +{ + const auto& valid_kernel_ids = FillValidKernelsIDs, + 
CKArgsBNormFwdTraining>(bn_problem); + assert(!valid_kernel_ids.empty()); + const auto& kernel_id = valid_kernel_ids[0]; + return InitAnyInvokerFactory, + CKArgsBNormFwdTraining, + miopen::batchnorm::InvokeParams>(bn_problem, kernel_id); +} + +ConvSolution BnCKFwdTraining::GetSolution( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(bn_problem.GetXDesc().GetType()) + { + + case miopenFloat: return MakeAnyInvokerFactory(bn_problem); + case miopenDouble: return MakeAnyInvokerFactory(bn_problem); + case miopenHalf: return MakeAnyInvokerFactory(bn_problem); + case miopenBFloat16: return MakeAnyInvokerFactory(bn_problem); + case miopenInt8: + case miopenInt32: + case miopenInt8x4: + case miopenBFloat8: + case miopenFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdTraining operation not for this data type"); + } +#endif + return {}; +} + +} // namespace batchnorm +} // namespace solver +} // namespace miopen diff --git a/test/bn_spatial_nhwc_test.cpp b/test/bn_spatial_nhwc_test.cpp deleted file mode 100644 index abca57e7ce..0000000000 --- a/test/bn_spatial_nhwc_test.cpp +++ /dev/null @@ -1,749 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ - -#include "driver.hpp" -#include "get_handle.hpp" -#include "tensor_holder.hpp" -#include "test.hpp" -#include "verify.hpp" -#include "random.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define MIO_BN_TEST_EXPAVGFACTOR 0.1 -#define MIO_BN_TEST_EPSILON 1e-5 -#define MIO_BN_USE_MIX_PREC 1 -#if MIO_BN_USE_MIX_PREC == 1 -#define PREC_TYPE float -#else -#define PREC_TYPE T -#endif - -template -struct verify_forward_train_bn_spatial -{ - const tensor input; - const tensor scale; - const tensor shift; - - std::tuple, tensor, tensor, tensor, tensor> cpu() const - { - double epsilon = MIO_BN_TEST_EPSILON; - double expAvgFactor = MIO_BN_TEST_EXPAVGFACTOR; - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(input.desc.GetLengths()); - - std::size_t rs_n_batch, rs_channels, rs_height, rs_width; - auto derivedBnDesc = - miopen::TensorDescriptor(input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(rs_n_batch, rs_height, rs_width, rs_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - tensor runMean; - tensor runVar; - if(input.desc.GetType() == miopenFloat) - { - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - } - else - { - prng::reset_seed(); - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - - const U Data_scale = static_cast(0.001); - for(std::size_t i = 0; i < runMean.desc.GetElementSize(); i++) - { - runMean[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - runVar[i] = 
prng::gen_descreet_unsigned(Data_scale, 100); - } - } - auto saveMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - auto saveInvVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - auto out = input; - std::fill(out.begin(), out.end(), 0); - - const auto nhw = double(height * width * n_batch); - par_for(channels, 1, [&](int cidx) { - double elemStd = 0.; - double variance_accum = 0.; - double mean_accum = 0.; - double invVar = 0.; - double newRunMean = 0.; - double adjust = 0.; - - std::vector variance_accum_arr(height, 0.0); - std::vector mean_accum_arr(height, 0.0); - std::vector dshift_accum_arr(height, 0.0); - std::vector dscale_accum_arr(height, 0.0); - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - mean_accum_arr[row] += input(bidx, cidx, row, column); - } - } - } - for(std::size_t i = 0; i < height; i++) - mean_accum += mean_accum_arr[i]; - - mean_accum /= nhw; - - elemStd = 0.; - variance_accum = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - out(bidx, cidx, row, column) = elemStd = - input(bidx, cidx, row, column) - mean_accum; - variance_accum_arr[row] += elemStd * elemStd; - } - } - } - for(std::size_t i = 0; i < height; i++) - variance_accum += variance_accum_arr[i]; - - variance_accum /= nhw; - invVar = 1.0 / sqrt(variance_accum + epsilon); - - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - out(bidx, cidx, row, column) = - scale(0, 0, 0, cidx) * (invVar * out(bidx, cidx, row, column)) + - shift(0, 0, 0, cidx); - } - } - } - - saveMean(0, 0, 0, cidx) = mean_accum; - saveInvVar(0, 0, 0, cidx) = invVar; - - newRunMean = runMean(0, 0, 0, cidx) * (1 - expAvgFactor); - 
runMean(0, 0, 0, cidx) = mean_accum * expAvgFactor + newRunMean; - adjust = (n_batch * height * width == 1) ? variance_accum - : (nhw / (nhw - 1)) * variance_accum; - runVar(0, 0, 0, cidx) = - (1 - expAvgFactor) * runVar(0, 0, 0, cidx) + expAvgFactor * adjust; - }); - - return std::make_tuple(out, runMean, runVar, saveMean, saveInvVar); - } - - std::tuple, tensor, tensor, tensor, tensor> gpu() const - { - auto&& handle = get_handle(); - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(input.desc.GetLengths()); - - auto out = input; - std::fill(out.begin(), out.end(), 0); - - std::size_t rs_n_batch, rs_channels, rs_height, rs_width; - auto derivedBnDesc = - miopen::TensorDescriptor(input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(rs_n_batch, rs_height, rs_width, rs_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - tensor runMean; - tensor runVar; - if(input.desc.GetType() == miopenFloat) - { - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - } - else - { - prng::reset_seed(); - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - - const U Data_scale = static_cast(0.001); - for(std::size_t i = 0; i < runMean.desc.GetElementSize(); i++) - { - runMean[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - runVar[i] = prng::gen_descreet_unsigned(Data_scale, 100); - } - } - - auto saveMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - auto saveInvVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - - auto in_dev = handle.Write(input.data); - auto scale_dev = handle.Write(scale.data); - auto shift_dev = handle.Write(shift.data); - - auto runMean_dev = 
handle.Write(runMean.data); - auto runVar_dev = handle.Write(runVar.data); - auto saveMean_dev = handle.Create(channels); - auto saveInvVar_dev = handle.Create(channels); - auto out_dev = handle.Create(n_batch * channels * height * width); - - double epsilon = MIO_BN_TEST_EPSILON; - double expAvgFactor = MIO_BN_TEST_EXPAVGFACTOR; - - float alpha = 1.0; - float beta = 0.0; - - miopen::BatchNormForwardTraining(handle, - miopenBNSpatial, - &alpha, - &beta, - input.desc, - in_dev.get(), - out.desc, - out_dev.get(), - scale.desc, - scale_dev.get(), - shift_dev.get(), - expAvgFactor, - runMean_dev.get(), - runVar_dev.get(), - epsilon, - saveMean_dev.get(), - saveInvVar_dev.get()); - - saveMean.data = handle.Read(saveMean_dev, saveMean.data.size()); - saveInvVar.data = handle.Read(saveInvVar_dev, saveInvVar.data.size()); - runMean.data = handle.Read(runMean_dev, runMean.data.size()); - runVar.data = handle.Read(runVar_dev, runVar.data.size()); - out.data = handle.Read(out_dev, out.data.size()); - - return std::make_tuple(out, runMean, runVar, saveMean, saveInvVar); - } - - void fail(int badtensor) const - { - std::cout << "Forward Train Spatial Batch Normalization: " << std::endl; - std::cout << "Input tensor: " << input.desc.ToString() << std::endl; - - switch(badtensor) - { - case(0): std::cout << "Output tensor output failed verification." << std::endl; break; - case(1): std::cout << "Running Mean output tensor failed verification." << std::endl; break; - case(2): - std::cout << "Running Variance output tensor failed verification." << std::endl; - break; - case(3): std::cout << "Saved Mean tensor failed verification." << std::endl; break; - case(4): std::cout << "Saved Variance tensor failed verification." 
<< std::endl; break; - default: break; - } - } -}; - -template -struct verify_backward_bn_spatial_recalc -{ - const tensor x_input; - const tensor dy_input; - const tensor scale; - - std::tuple, tensor, tensor> cpu() const - { - double epsilon = MIO_BN_TEST_EPSILON; - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - const auto nhw = double(height * width * n_batch); - - par_for(channels, 1, [&](int cidx) { - double elemStd = 0.; - unsigned int xhat_index; - double mean = 0.; - double invVar = 0.; - double dyelem = 0.; - double variance = 0.; - - std::vector xhat(height * width * n_batch, 0.0); - std::vector variance_accum_arr(height, 0.0); - std::vector mean_accum_arr(height, 0.0); - std::vector dshift_accum_arr(height, 0.0); - std::vector dscale_accum_arr(height, 0.0); - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - mean_accum_arr[row] += x_input(bidx, cidx, row, column); - } - } - } - for(std::size_t i = 0; i < height; i++) - mean += mean_accum_arr[i]; - - mean /= nhw; - - elemStd = 0.; - variance = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - 
for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - elemStd = x_input(bidx, cidx, row, column) - mean; - variance_accum_arr[row] += elemStd * elemStd; - } - } - } - for(std::size_t i = 0; i < height; i++) - variance += variance_accum_arr[i]; - - variance /= nhw; - invVar = 1. / double(sqrt(variance + epsilon)); - - dscale(0, cidx, 0, 0) = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - elemStd = x_input(bidx, cidx, row, column) - mean; - xhat[xhat_index] = elemStd * invVar; - dyelem = dy_input(bidx, cidx, row, column); - dshift_accum_arr[row] += dyelem; - dscale_accum_arr[row] += xhat[xhat_index] * dyelem; - } - } - } - for(std::size_t i = 0; i < height; i++) - { - dshift(0, cidx, 0, 0) += dshift_accum_arr[i]; - dscale(0, cidx, 0, 0) += dscale_accum_arr[i]; - } - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - - double tmp1 = - nhw * dy_input(bidx, cidx, row, column) - dshift(0, cidx, 0, 0); - double tmp2 = -xhat[xhat_index] * dscale(0, cidx, 0, 0); - double tmp3 = (scale(0, 0, 0, cidx) * invVar) / nhw; - dx_out(bidx, cidx, row, column) = tmp3 * (tmp2 + tmp1); - } - } - } - }); - - return std::make_tuple(dx_out, dscale, dshift); - } - - std::tuple, tensor, tensor> gpu() const - { - auto&& handle = get_handle(); - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - 
std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - float alpha = 1.0; - float beta = 0.0; - - auto xin_dev = handle.Write(x_input.data); - auto dyin_dev = handle.Write(dy_input.data); - auto scale_dev = handle.Write(scale.data); - auto dscale_dev = handle.Write(dscale.data); - auto dshift_dev = handle.Write(dshift.data); - auto dx_out_dev = handle.Write(dx_out.data); - - double epsilon = MIO_BN_TEST_EPSILON; - - miopen::BatchNormBackward(handle, - miopenBNSpatial, - &alpha, - &beta, - &alpha, - &beta, - x_input.desc, - xin_dev.get(), - dy_input.desc, - dyin_dev.get(), - dx_out.desc, - dx_out_dev.get(), - scale.desc, - scale_dev.get(), - dscale_dev.get(), - dshift_dev.get(), - epsilon, - nullptr, - nullptr); - - dx_out.data = handle.Read(dx_out_dev, dx_out.data.size()); - dscale.data = handle.Read(dscale_dev, dscale.data.size()); - dshift.data = handle.Read(dshift_dev, dshift.data.size()); - - return std::make_tuple(dx_out, dscale, dshift); - } - - void fail(int badtensor) const - { - std::cout << "Backward Batch Spatial Normalization Recalc Mean and Variance: " << std::endl; - std::cout << "X Input tensor: " << x_input.desc.ToString() << std::endl; - std::cout << "Delta Y Input tensor: " << dy_input.desc.ToString() << std::endl; - switch(badtensor) - { - case(0): - std::cout << "Delta X output tensor output failed verification." << std::endl; - break; - case(1): std::cout << "Delta scale output tensor failed verification." << std::endl; break; - case(2): std::cout << "Delta shift output tensor failed verification." 
<< std::endl; break; - default: break; - } - } -}; - -template -struct verify_backward_bn_spatial_use_saved -{ - const tensor x_input; - const tensor dy_input; - const tensor scale; - const tensor savedMean; - const tensor savedInvVar; - std::tuple, tensor, tensor> cpu() const - { - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - const auto nhw = double(height * width * n_batch); - - par_for(channels, 1, [&](int cidx) { - double elemStd = 0.; - unsigned int xhat_index; - double mean = savedMean(0, 0, 0, cidx); - double invVar = savedInvVar(0, 0, 0, cidx); - double dyelem = 0.; - - std::vector xhat(n_batch * height * width, 0.0); - std::vector dshift_accum_arr(height, 0.0); - std::vector dscale_accum_arr(height, 0.0); - dscale(0, cidx, 0, 0) = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - elemStd = x_input(bidx, cidx, row, column) - mean; - xhat[xhat_index] = elemStd * invVar; - dyelem = dy_input(bidx, cidx, row, column); - dshift_accum_arr[row] += dyelem; - dscale_accum_arr[row] += xhat[xhat_index] * dyelem; - } - } - } - for(std::size_t i = 0; i < height; 
i++) - { - dshift(0, cidx, 0, 0) += dshift_accum_arr[i]; - dscale(0, cidx, 0, 0) += dscale_accum_arr[i]; - } - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - - double tmp1 = - nhw * dy_input(bidx, cidx, row, column) - dshift(0, cidx, 0, 0); - double tmp2 = -xhat[xhat_index] * dscale(0, cidx, 0, 0); - double tmp3 = (scale(0, 0, 0, cidx) * invVar) / nhw; - dx_out(bidx, cidx, row, column) = tmp3 * (tmp2 + tmp1); - } - } - } - }); - - return std::make_tuple(dx_out, dscale, dshift); - } - - std::tuple, tensor, tensor> gpu() const - { - auto&& handle = get_handle(); - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - float alpha = 1.0; - float beta = 0.0; - - auto xin_dev = handle.Write(x_input.data); - auto dyin_dev = handle.Write(dy_input.data); - auto scale_dev = handle.Write(scale.data); - auto dscale_dev = handle.Write(dscale.data); - auto dshift_dev = handle.Write(dshift.data); - auto dx_out_dev = handle.Write(dx_out.data); - auto savedMean_dev = handle.Write(savedMean.data); - auto savedInvVar_dev = handle.Write(savedInvVar.data); - - double epsilon = 
MIO_BN_TEST_EPSILON; - - miopen::BatchNormBackward(handle, - miopenBNSpatial, - &alpha, - &beta, - &alpha, - &beta, - x_input.desc, - xin_dev.get(), - dy_input.desc, - dyin_dev.get(), - dx_out.desc, - dx_out_dev.get(), - scale.desc, - scale_dev.get(), - dscale_dev.get(), - dshift_dev.get(), - epsilon, - savedMean_dev.get(), - savedInvVar_dev.get()); - - dx_out.data = handle.Read(dx_out_dev, dx_out.data.size()); - dscale.data = handle.Read(dscale_dev, dscale.data.size()); - dshift.data = handle.Read(dshift_dev, dshift.data.size()); - - return std::make_tuple(dx_out, dscale, dshift); - } - - void fail(int badtensor) const - { - std::cout << "Backward Batch Spatial Normalization Use Saved Mean and Variance: " - << std::endl; - std::cout << "X Input tensor: " << x_input.desc.ToString() << std::endl; - std::cout << "Delta Y Input tensor: " << dy_input.desc.ToString() << std::endl; - switch(badtensor) - { - case(0): - std::cout << "Delta X output tensor output failed verification." << std::endl; - break; - case(1): std::cout << "Delta scale output tensor failed verification." << std::endl; break; - case(2): std::cout << "Delta shift output tensor failed verification." << std::endl; break; - default: break; - } - } -}; - -template -struct batch_norm_spatial_nhwc_driver : test_driver -{ - tensor input; - tensor scale; - tensor shift; - batch_norm_spatial_nhwc_driver() - { - this->batch_factor = 4; - add(input, - "input", - get_bn_spatial_input_tensor( - tensor_elem_gen_integer{miopen_type{} == miopenHalf ? 
5 : 17})); - } - - void run() - { - std::size_t n, c, h, w; - std::tie(n, c, h, w) = miopen::tien<4>(input.desc.GetLengths()); - - std::size_t ssn, ssc, ssh, ssw; - auto derivedBnDesc = miopen::TensorDescriptor(input.desc.GetType(), - std::vector{1, 1, 1, c}, - std::vector{c, c, c, 1}); - std::tie(ssn, ssh, ssw, ssc) = miopen::tien<4>(derivedBnDesc.GetLengths()); - - std::vector new_len = input.desc.GetLengths(); - std::vector new_str; - miopen::tensor_layout_to_strides(new_len, "NCHW", "NHWC", new_str); - input.desc = miopen::TensorDescriptor(miopen_type{}, new_len, new_str); - - if(input.desc.GetType() == miopenFloat) - { - scale = tensor{ssn, ssh, ssw, ssc}.generate(tensor_elem_gen_integer{17}); - shift = tensor{ssn, ssh, ssw, ssc}.generate(tensor_elem_gen_integer{17}); - } - else - { - scale = tensor{ssn, ssh, ssw, ssc}; - shift = tensor{ssn, ssh, ssw, ssc}; - - const PREC_TYPE Data_scale = static_cast(1e-4); - for(std::size_t i = 0; i < scale.desc.GetElementSize(); i++) - { - scale[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - shift[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - } - for(std::size_t i = 0; i < input.desc.GetElementSize(); i++) - { - input[i] = prng::gen_descreet_uniform_sign(static_cast(1e-5), 100); - } - } - - auto outpair = verify(verify_forward_train_bn_spatial{input, scale, shift}); - - auto dy_input = std::get<0>(outpair.second); - for(std::size_t bidx = 0; bidx < n; bidx++) - { - for(std::size_t cidx = 0; cidx < c; cidx++) - { - for(std::size_t row = 0; row < h; row++) - { - for(std::size_t column = 0; column < w; column++) - { - dy_input(bidx, cidx, row, column) *= 0.1; - } - } - } - } - this->tolerance = 80 * input.desc.GetElementSize(); - verify(verify_backward_bn_spatial_recalc{input, dy_input, scale}); - - auto savedMean = std::get<3>(outpair.second); - auto savedInvVar = std::get<4>(outpair.second); - verify(verify_backward_bn_spatial_use_saved{ - input, dy_input, scale, savedMean, savedInvVar}); - } -}; - -int 
main(int argc, const char* argv[]) -{ - test_drive(argc, argv); - return 0; -} diff --git a/test/fusionHost.hpp b/test/fusionHost.hpp index cffefea0e2..5374abd1fa 100644 --- a/test/fusionHost.hpp +++ b/test/fusionHost.hpp @@ -36,7 +36,6 @@ #include #include #include -// #include "driver.hpp" #include "get_handle.hpp" #include "tensor_holder.hpp" #include "verify.hpp" @@ -203,17 +202,17 @@ void batchNormPerActivHostInference(const tensor& input, }); } -template +template void batchNormSpatialHostFwdTrain(const tensor& input, tensor& out, const tensor& scale, const tensor& bias, double epsilon, double expAvgFactor, - tensor& saveMean, - tensor& saveInvVar, - tensor& runMean, - tensor& runVar) + tensor& saveMean, + tensor& saveInvVar, + tensor& runMean, + tensor& runVar) { int height, width, n_batch, channels; @@ -279,15 +278,15 @@ void batchNormSpatialHostFwdTrain(const tensor& input, }); } -template -void batchNormSpatialHostBwdTrain(const tensor& x_input, - const tensor& dy_input, - tensor& dx_out, - const tensor& scale, - tensor& dscale, - tensor& dbias, - const tensor& savedMean, - const tensor& savedInvVar) +template +void batchNormSpatialHostBwdTrain(const tensor& x_input, + const tensor& dy_input, + tensor& dx_out, + const tensor& scale, + tensor& dscale, + tensor& dbias, + const tensor& savedMean, + const tensor& savedInvVar) { int height, width, n_batch, channels; @@ -335,7 +334,7 @@ void batchNormSpatialHostBwdTrain(const tensor& x_input, double tmp1 = nhw * dy_input(bidx, cidx, row, column) - dbias(0, cidx, 0, 0); double tmp2 = -xhat[xhat_index] * dscale(0, cidx, 0, 0); double tmp3 = (scale(0, cidx, 0, 0) * invVar) / nhw; - dx_out(bidx, cidx, row, column) = static_cast(tmp3 * (tmp2 + tmp1)); + dx_out(bidx, cidx, row, column) = static_cast(tmp3 * (tmp2 + tmp1)); } // end for(n_batchs) } // for (column) } // for (row) diff --git a/test/gtest/bn.hpp b/test/gtest/bn.hpp index 0b763da411..22f8391fe6 100644 --- a/test/gtest/bn.hpp +++ b/test/gtest/bn.hpp @@ 
-84,3 +84,174 @@ struct BNInferTest : public ::testing::TestWithParam +struct BNBwdTest : public ::testing::TestWithParam> +{ +protected: + void SetUp() override + { + test_skipped = false; + std::tie(bn_config, tensor_layout) = GetParam(); + bn_bwd_test_data.SetUpImpl(bn_config, tensor_layout); + + auto&& handle = get_handle(); + miopenBatchNormalizationBackward(&handle, + bn_config.mode, + &bn_bwd_test_data.alphaDataDiff, + &bn_bwd_test_data.betaDataDiff, + &bn_bwd_test_data.alphaParamDiff, + &bn_bwd_test_data.betaParamDiff, + &bn_bwd_test_data.input.desc, + bn_bwd_test_data.in_dev.get(), + &bn_bwd_test_data.dy.desc, + bn_bwd_test_data.dy_dev.get(), + &bn_bwd_test_data.output.desc, + bn_bwd_test_data.out_dev.get(), + &bn_bwd_test_data.bnScale.desc, + bn_bwd_test_data.bnScale_dev.get(), + bn_bwd_test_data.dScale_dev.get(), + bn_bwd_test_data.dBias_dev.get(), + bn_bwd_test_data.epsilon, + bn_bwd_test_data.savedMean_dev.get(), + bn_bwd_test_data.savedInvVar_dev.get()); + + std::fill(bn_bwd_test_data.output.begin(), + bn_bwd_test_data.output.end(), + std::numeric_limits::quiet_NaN()); + } + + void TearDown() override + { + if(test_skipped) + return; + auto&& handle = get_handle(); + bn_bwd_test_data.output.data = + handle.Read(bn_bwd_test_data.out_dev, bn_bwd_test_data.output.data.size()); + bn_bwd_test_data.dScale.data = handle.Read(bn_bwd_test_data.dScale_dev, + bn_bwd_test_data.dScale.data.size()); + bn_bwd_test_data.dBias.data = + handle.Read(bn_bwd_test_data.dBias_dev, bn_bwd_test_data.dBias.data.size()); + + test::ComputeCPUBNBwd(bn_bwd_test_data); + + // using tolerance = 1e-4 since this the tolerance CK uses + test::CompareTensor(bn_bwd_test_data.output, bn_bwd_test_data.ref_out, 1e-4); + test::CompareTensor(bn_bwd_test_data.dScale, bn_bwd_test_data.dScale_ref, 1e-4); + test::CompareTensor(bn_bwd_test_data.dBias, bn_bwd_test_data.dBias_ref, 1e-4); + } + + BNTestCase bn_config; + bool test_skipped = false; + BNBwdTestData + bn_bwd_test_data; + 
miopenTensorLayout_t tensor_layout; +}; + +template +struct BNFwdTrainTest + : public ::testing::TestWithParam> +{ +protected: + void SetUp() override + { + test_skipped = false; + std::tie(bn_config, tensor_layout) = GetParam(); + bn_fwd_train_test_data.SetUpImpl(bn_config, tensor_layout); + + auto&& handle = get_handle(); + miopenBatchNormalizationForwardTraining(&handle, + bn_config.mode, + &bn_fwd_train_test_data.alpha, + &bn_fwd_train_test_data.beta, + &bn_fwd_train_test_data.input.desc, + bn_fwd_train_test_data.in_dev.get(), + &bn_fwd_train_test_data.output.desc, + bn_fwd_train_test_data.out_dev.get(), + &bn_fwd_train_test_data.scale.desc, + bn_fwd_train_test_data.scale_dev.get(), + bn_fwd_train_test_data.shift_dev.get(), + bn_fwd_train_test_data.averageFactor, + bn_fwd_train_test_data.runMean_dev.get(), + bn_fwd_train_test_data.runVariance_dev.get(), + bn_fwd_train_test_data.epsilon, + bn_fwd_train_test_data.saveMean_dev.get(), + bn_fwd_train_test_data.saveVariance_dev.get()); + + std::fill(bn_fwd_train_test_data.output.begin(), + bn_fwd_train_test_data.output.end(), + std::numeric_limits::quiet_NaN()); + std::fill(bn_fwd_train_test_data.saveMean_ref.begin(), + bn_fwd_train_test_data.saveMean_ref.end(), + std::numeric_limits::quiet_NaN()); + std::fill(bn_fwd_train_test_data.saveVariance_ref.begin(), + bn_fwd_train_test_data.saveVariance_ref.end(), + std::numeric_limits::quiet_NaN()); + } + + void TearDown() override + { + if(test_skipped) + return; + auto&& handle = get_handle(); + bn_fwd_train_test_data.output.data = handle.Read( + bn_fwd_train_test_data.out_dev, bn_fwd_train_test_data.output.data.size()); + + bn_fwd_train_test_data.saveMean.data = handle.Read( + bn_fwd_train_test_data.saveMean_dev, bn_fwd_train_test_data.saveMean.data.size()); + bn_fwd_train_test_data.saveVariance.data = + handle.Read(bn_fwd_train_test_data.saveVariance_dev, + bn_fwd_train_test_data.saveVariance_ref.data.size()); + bn_fwd_train_test_data.runMean.data = handle.Read( + 
bn_fwd_train_test_data.runMean_dev, bn_fwd_train_test_data.runMean_ref.data.size()); + bn_fwd_train_test_data.runVariance.data = + handle.Read(bn_fwd_train_test_data.runVariance_dev, + bn_fwd_train_test_data.runVariance_ref.data.size()); + test::ComputeCPUBNFwdTrain(bn_fwd_train_test_data); + + // 4e-3 is tolerance used by CK kernel. + test::CompareTensor( + bn_fwd_train_test_data.output, bn_fwd_train_test_data.ref_out, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.saveMean, bn_fwd_train_test_data.saveMean_ref, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.saveVariance, bn_fwd_train_test_data.saveVariance_ref, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.runMean, bn_fwd_train_test_data.runMean_ref, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.runVariance, bn_fwd_train_test_data.runVariance_ref, 4e-3); + } + + BNTestCase bn_config; + bool test_skipped = false; + BNFwdTrainTestData + bn_fwd_train_test_data; + miopenTensorLayout_t tensor_layout; +}; diff --git a/test/gtest/bn_bwd.cpp b/test/gtest/bn_bwd.cpp new file mode 100644 index 0000000000..722b42e872 --- /dev/null +++ b/test/gtest/bn_bwd.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "bn.hpp" + +struct BNBwdTestTestHalf + : BNBwdTest +{ +}; + +struct BNBwdTestFloat : BNBwdTest +{ +}; + +struct BNBwdTestBFloat16 : BNBwdTest +{ +}; + +struct BNBwdTestDouble : BNBwdTest +{ +}; + +TEST_P(BNBwdTestTestHalf, BnBwdCKHalf) {} + +TEST_P(BNBwdTestFloat, BnBwdCKFloat) {} + +// Currently disabled since miopen::batchnorm::MakeForwardTrainingNetworkConfig +// only supports half and float +TEST_P(BNBwdTestBFloat16, DISABLED_BnBwdCKBFloat16) {} +TEST_P(BNBwdTestDouble, DISABLED_BnBwdCKDouble) {} + +INSTANTIATE_TEST_SUITE_P(BNBwdTestTestHalfNHWCSuite, + BNBwdTestTestHalf, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNBwdTestFloatNHWCSuite, + BNBwdTestFloat, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNBwdTestBFloat16NHWCSuite, + BNBwdTestBFloat16, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNBwdTestDoubleNHWCSuite, + BNBwdTestDouble, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); diff --git a/test/gtest/bn_fwd_train.cpp b/test/gtest/bn_fwd_train.cpp new file mode 100644 index 0000000000..4a4dd4c728 --- /dev/null +++ b/test/gtest/bn_fwd_train.cpp @@ -0,0 +1,73 @@ 
+/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include "bn.hpp" + +struct BNFwdTrainTestHalf + : BNFwdTrainTest +{ +}; + +struct BNFwdTrainTestFloat : BNFwdTrainTest +{ +}; + +struct BNFwdTrainTestDouble : BNFwdTrainTest +{ +}; + +struct BNFwdTrainTestBFloat16 : BNFwdTrainTest +{ +}; + +TEST_P(BNFwdTrainTestHalf, BnFwdTrainCKHalf) {} + +TEST_P(BNFwdTrainTestFloat, BnFwdTrainCKFloat) {} + +// Currently disabled since miopen::batchnorm::MakeForwardTrainingNetworkConfig +// only supports half and float +TEST_P(BNFwdTrainTestDouble, DISABLED_BnFwdTrainCKDouble) {} +TEST_P(BNFwdTrainTestBFloat16, DISABLED_BnFwdTrainCKBFloat16) {} + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestHalfNHWCSuite, + BNFwdTrainTestHalf, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestFloatNHWCSuite, + BNFwdTrainTestFloat, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestFloatNHWCSuite, + BNFwdTrainTestDouble, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestFloatNHWCSuite, + BNFwdTrainTestBFloat16, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); diff --git a/test/gtest/bn_infer.cpp b/test/gtest/bn_infer.cpp index 6598ef7169..0dceaa1ba5 100644 --- a/test/gtest/bn_infer.cpp +++ b/test/gtest/bn_infer.cpp @@ -43,14 +43,14 @@ struct BNInferTestBFloat16 : BNInferTest +#include "random.hpp" #include #include @@ -60,7 +59,8 @@ std::vector Network1() { // pyt_mlperf_resnet50v1.5 return { - {16, 8, 128, 256, miopenBNSpatial, miopen::batchnorm::Direction::ForwardInference, 1, 0}, + {192, 1, 8, 8, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0}, + {16, 8, 128, 256, miopenBNSpatial, miopen::batchnorm::Direction::ForwardTraining, 1, 0}, {16, 8, 128, 256, miopenBNSpatial, 
miopen::batchnorm::Direction::ForwardInference, 1, 0}, {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 0, 1}, {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::ForwardTraining, 1, 1}, @@ -125,7 +125,7 @@ struct BNTestData { input = tensor{miopen_type{}, tensor_layout, bn_config.GetInput()}; output = tensor{miopen_type{}, tensor_layout, bn_config.GetInput()}; - ref_out = output; + ref_out = tensor{miopen_type{}, tensor_layout, bn_config.GetInput()}; } void InitTensorsWithRandValue() @@ -226,3 +226,218 @@ struct BNInferTestData : public BNTestData estVariance_dev = handle.Write(estVariance.data); } }; + +template +struct BNBwdTestData : public BNTestData +{ + void SetUpImpl(const TConfig& config, miopenTensorLayout_t t_layout) + { + BNTestData::SetUpImpl(config, t_layout); + CreateTensors(); + InitTensorsWithRandValue(); + WriteToGPU(); + } + + tensor bnScale; + + tensor savedMean; + tensor savedInvVar; + + tensor dy; + tensor dScale; + tensor dBias; + tensor dScale_ref; + tensor dBias_ref; + + miopen::Allocator::ManageDataPtr bnScale_dev; + miopen::Allocator::ManageDataPtr savedMean_dev; + miopen::Allocator::ManageDataPtr savedInvVar_dev; + + miopen::Allocator::ManageDataPtr dy_dev; + miopen::Allocator::ManageDataPtr dScale_dev; + miopen::Allocator::ManageDataPtr dBias_dev; + miopen::Allocator::ManageDataPtr dScale_ref_dev; + miopen::Allocator::ManageDataPtr dBias_ref_dev; + double epsilon = std::numeric_limits::epsilon(); + + float alphaDataDiff = static_cast(1), betaDataDiff = static_cast(0); + float alphaParamDiff = static_cast(1), betaParamDiff = static_cast(0); + +private: + void CreateTensors() + { + dy = tensor{miopen_type{}, + BNTestData::tensor_layout, + BNTestData::bn_config.GetInput()}; + + auto derivedBnDesc = miopen::TensorDescriptor{}; + miopen::DeriveBNTensorDescriptor(derivedBnDesc, + BNTestData::input.desc, + BNTestData::bn_mode); + bnScale = tensor{miopen_type{}, + BNTestData::tensor_layout, + 
derivedBnDesc.GetLengths()}; + savedMean = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + savedInvVar = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + dScale = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + dBias = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + dScale_ref = dScale; + dBias_ref = dBias; + } + + void InitTensorsWithRandValue() + { + auto gen_value = [](auto...) { + return prng::gen_descreet_uniform_sign(static_cast(1e-2), 100); + }; + dy.generate(gen_value); + bnScale.generate(gen_value); + savedMean.generate(gen_value); + + auto gen_var = [](auto...) { + return static_cast(1e-2) * + static_cast(prng::gen_0_to_B(100) + 1); + }; + savedInvVar.generate(gen_var); + + std::fill(dScale.begin(), dScale.end(), 0.); + std::fill(dBias.begin(), dBias.end(), 0.); + + std::fill(dScale_ref.begin(), dScale_ref.end(), 0.); + std::fill(dBias_ref.begin(), dBias_ref.end(), 0.); + } + void WriteToGPU() + { + auto&& handle = get_handle(); + + bnScale_dev = handle.Write(bnScale.data); + savedMean_dev = handle.Write(savedMean.data); + savedInvVar_dev = handle.Write(savedInvVar.data); + dy_dev = handle.Write(dy.data); + + dScale_dev = handle.Write(dScale.data); + dBias_dev = handle.Write(dBias.data); + } +}; + +template +struct BNFwdTrainTestData : public BNTestData +{ + void SetUpImpl(const TConfig& config, miopenTensorLayout_t t_layout) + { + BNTestData::SetUpImpl(config, t_layout); + CreateTensors(); + InitTensorsWithRandValue(); + WriteToGPU(); + } + + tensor scale; + tensor shift; + tensor saveMean; + tensor saveVariance; + tensor runMean; + tensor runVariance; + + tensor saveMean_ref; + tensor saveVariance_ref; + tensor runMean_ref; + tensor runVariance_ref; + + miopen::Allocator::ManageDataPtr scale_dev; + miopen::Allocator::ManageDataPtr shift_dev; // bias + miopen::Allocator::ManageDataPtr saveMean_dev; + 
miopen::Allocator::ManageDataPtr saveVariance_dev; + miopen::Allocator::ManageDataPtr runMean_dev; + miopen::Allocator::ManageDataPtr runVariance_dev; + double epsilon = 1.0e-5; + double averageFactor = 0.1; + float alpha = static_cast(1.0f); + float beta = static_cast(0); + const float activ_alpha = static_cast(0.5f); + const float activ_beta = static_cast(0.5f); + const float activ_gamma = static_cast(0.5f); + +private: + void CreateTensors() + { + auto derivedBnDesc = miopen::TensorDescriptor{}; + miopen::DeriveBNTensorDescriptor(derivedBnDesc, + BNTestData::input.desc, + BNTestData::bn_mode); + scale = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + shift = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + saveMean = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + saveVariance = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + runMean = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + runVariance = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + } + + void InitTensorsWithRandValue() + { + auto gen_value = [](auto...) { + return prng::gen_descreet_uniform_sign(static_cast(1e-2), 100); + }; + scale.generate(gen_value); + shift.generate(gen_value); + + auto gen_var = [](auto...) 
{ + return static_cast(1e-2) * + static_cast(prng::gen_0_to_B(100) + 1); + }; + runMean.generate(gen_var); + runVariance.generate(gen_var); + + saveMean_ref = saveMean; + saveVariance_ref = saveVariance; + runMean_ref = runMean; + runVariance_ref = runVariance; + } + void WriteToGPU() + { + auto&& handle = get_handle(); + scale_dev = handle.Write(scale.data); + shift_dev = handle.Write(shift.data); + saveMean_dev = handle.Write(saveMean.data); + saveVariance_dev = handle.Write(saveVariance.data); + runMean_dev = handle.Write(runMean.data); + runVariance_dev = handle.Write(runVariance.data); + } +}; diff --git a/test/gtest/test_operations.hpp b/test/gtest/test_operations.hpp index d1528fe2bb..da41212302 100644 --- a/test/gtest/test_operations.hpp +++ b/test/gtest/test_operations.hpp @@ -38,6 +38,41 @@ void ComputeCPUBNInference(DLModule& dl_module) dl_module.estVariance); } +template +void ComputeCPUBNBwd(DLModule& dl_module) +{ + batchNormSpatialHostBwdTrain(dl_module.input, + dl_module.dy, + dl_module.ref_out, + dl_module.bnScale, + dl_module.dScale_ref, + dl_module.dBias_ref, + dl_module.savedMean, + dl_module.savedInvVar); +} + +template +void ComputeCPUBNFwdTrain(DLModule& dl_module) +{ + batchNormSpatialHostFwdTrain(dl_module.input, + dl_module.ref_out, + dl_module.scale, + dl_module.shift, + dl_module.epsilon, + dl_module.averageFactor, + dl_module.saveMean_ref, + dl_module.saveVariance_ref, + dl_module.runMean_ref, + dl_module.runVariance_ref); +} + template void CompareTensor(const tensor& output, const tensor& ref_out,