Skip to content

Commit

Permalink
bg/LWPMIOPEN-192: Integrate CK's batch norm forward training into non…
Browse files Browse the repository at this point in the history
…-tunable MIOpen solver (#2386)

* bg/LWPMIOPEN-192: add batch norm forward CK kernel

* bg/LWPMIOPEN-192 : analyze cleanup

* fix a typo

* bg/LWPMIOPEN-192: fix review comments

* bg/LWPMIOPEN-192 : fix compile error

* bg/LWPMIOPEN-192 : fix clang tidy

---------

Co-authored-by: Jun Liu <Liu.Jun@amd.com>
  • Loading branch information
bghimireamd and junliume authored Oct 5, 2023
1 parent d492864 commit d5eb31c
Show file tree
Hide file tree
Showing 10 changed files with 539 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ set( MIOpen_Source
solver/batchnorm/forward_per_activation_fused.cpp
solver/batchnorm/forward_spatial_multiple.cpp
solver/batchnorm/forward_spatial_single.cpp
solver/batchnorm/forward_training_ck.cpp
solver/conv_asm_1x1u.cpp
solver/conv_asm_1x1u_bias_activ_fused.cpp
solver/conv_asm_1x1u_stride2.cpp
Expand Down
10 changes: 10 additions & 0 deletions src/include/miopen/batchnorm/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ struct BnCKBwdBackward final : BatchnormSolver
const miopen::batchnorm::ProblemDescription& problem) const override;
};

// Non-tunable solver that dispatches batch norm forward training to a
// Composable Kernel (CK) DeviceBatchNormFwd instance.
struct BnCKFwdTraining final : BatchnormSolver
{
    const std::string& SolverDbId() const override { return GetSolverDbId<BnCKFwdTraining>(); }

    // Applicability is decided in forward_training_ck.cpp: requires a HIP+CK
    // build, NHWC layout, CK-supported hardware, and a data type for which a
    // CK instance accepts the problem.
    bool IsApplicable(const ExecutionContext& context,
                      const miopen::batchnorm::ProblemDescription& problem) const override;
    ConvSolution GetSolution(const ExecutionContext& context,
                             const miopen::batchnorm::ProblemDescription& problem) const override;
};

} // namespace batchnorm

} // namespace solver
Expand Down
3 changes: 2 additions & 1 deletion src/ocl/batchnormocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ void BatchNormForwardTraining(Handle& handle,
return tmp;
}();

const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdTrainingSpatialSingle,
const auto solvers = solver::SolverContainer<solver::batchnorm::BnCKFwdTraining,
solver::batchnorm::BnFwdTrainingSpatialSingle,
solver::batchnorm::BnFwdTrainingSpatialMultiple,
solver::batchnorm::BnFwdTrainingPerActivation>{};

Expand Down
1 change: 1 addition & 0 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
registry, ++id, ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM);
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId());
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId());
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId());

// IMPORTANT: New solvers should be added to the end of the function!
}
Expand Down
239 changes: 239 additions & 0 deletions src/solver/batchnorm/forward_training_ck.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@

/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#include <miopen/batchnorm/solvers.hpp>
#include <miopen/batchnorm/invoke_params.hpp>
#include <miopen/batch_norm.hpp>
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
#include <miopen/solver/ck_utility_common.hpp>
#include <ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp>
#include <miopen/solver/implicitgemm_ck_util.hpp>
#endif
MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING)

namespace miopen {
namespace solver {
namespace batchnorm {
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL

using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
using index_t = int32_t;

constexpr index_t Rank = 4;
constexpr index_t NumBatchNormReduceDim = 3;

using F16 = ck::half_t;
using F32 = float;
using F64 = double;
using BF16 = ushort;

template <typename XDataType,
typename YDataType,
typename AccDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
using DeviceOpBNFwdTrainingPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType,
PassThroughOp,
Rank,
NumBatchNormReduceDim>>;

/// Adapts a MIOpen batchnorm ProblemDescription to the argument lists expected
/// by CK's DeviceBatchNormFwd kernels (rank-4 tensors, 3 reduced dimensions).
struct CKArgsBNormFwdTraining
{
    CKArgsBNormFwdTraining(const miopen::batchnorm::ProblemDescription& problem)
    {
        std::copy(problem.GetXDesc().GetLengths().begin(),
                  problem.GetXDesc().GetLengths().end(),
                  xyLengths.begin());

        std::copy(problem.GetXDesc().GetStrides().begin(),
                  problem.GetXDesc().GetStrides().end(),
                  xyStrides.begin());
        arrScaleBiasMeanVarLengths[0] = xyLengths[1]; // get channel
        arrScaleBiasMeanVarStrides[0] = 1;

        // prep for CK: order strides descending and move the channel length to
        // the back (N,C,H,W -> N,H,W,C) — presumably to match the NHWC-only
        // applicability check; confirm against CK's expected dimension order.
        std::sort(xyStrides.begin(), xyStrides.end(), std::greater<>());
        std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end());
    }

    // Rule of five: the original declared copy ctor, move ctor and copy
    // assignment but omitted move assignment; all four are defaulted here.
    CKArgsBNormFwdTraining(const CKArgsBNormFwdTraining&) = default;
    CKArgsBNormFwdTraining(CKArgsBNormFwdTraining&&)      = default;
    CKArgsBNormFwdTraining& operator=(const CKArgsBNormFwdTraining&) = default;
    CKArgsBNormFwdTraining& operator=(CKArgsBNormFwdTraining&&) = default;

    /// Builds the CK argument pointer for a given device-op instance; x/y share
    /// lengths and strides, and scale/bias/mean/var share one stride array.
    template <typename InvokerPtr, typename InvokerParams>
    auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const
    {
        return invoker_ptr->MakeArgumentPointer(xyLengths,
                                                xyStrides,
                                                xyStrides,
                                                reduceDims,
                                                arrScaleBiasMeanVarLengths,
                                                arrScaleBiasMeanVarStrides,
                                                arrScaleBiasMeanVarStrides,
                                                arrScaleBiasMeanVarStrides,
                                                data_ctx.x,
                                                data_ctx.bnScale,
                                                data_ctx.bnBias,
                                                data_ctx.epsilon,
                                                PassThroughOp{},
                                                data_ctx.y,
                                                data_ctx.resultSaveMean,
                                                data_ctx.resultSaveInvVariance,
                                                data_ctx.expAvgFactor,
                                                data_ctx.resultRunningMean,
                                                data_ctx.resultRunningVariance);
    }

    /// True when the instance accepts this problem's shapes/strides. Uses a
    /// default-constructed InvokeParams, so only geometry is validated here.
    template <typename ConvPtr>
    bool IsSupportedBy(const ConvPtr& invoker_ptr) const
    {
        auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::InvokeParams{});
        return invoker_ptr->IsSupportedArgument(arg_ptr.get());
    }

    std::array<ck::index_t, Rank> xyLengths; // x/y tensor lengths, channel-last
    std::array<ck::index_t, Rank> xyStrides; // x/y strides, descending

    // Per-channel parameter tensors are 1-D (Rank - NumBatchNormReduceDim == 1).
    std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarLengths;
    std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;

    // Reduce over N, H, W (dims 0..2 after the channel rotation above).
    std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
};

// Reports whether any registered CK DeviceBatchNormFwd instance accepts the
// given problem for this particular combination of element types.
template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType>
static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem)
{
    // Name the instance-factory type once instead of inlining the full
    // template argument list into the call expression.
    using DeviceOps = DeviceOpBNFwdTrainingPtrs<XDataType,
                                                YDataType,
                                                AccDataType,
                                                ScaleDataType,
                                                BiasDataType,
                                                MeanVarDataType>;
    return IsCKApplicable<DeviceOps, CKArgsBNormFwdTraining>(problem);
}
#endif

bool BnCKFwdTraining::IsApplicable(const ExecutionContext& context,
                                   const miopen::batchnorm::ProblemDescription& bn_problem) const
{
#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
    std::ignore = context;
    // Fix: previously referenced 'fdesc_problem', which is not a parameter of
    // this function, breaking compilation of non-HIP / non-CK builds.
    std::ignore = bn_problem;
    return false;
#else
    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING{}))
        return false;
    // CK batchnorm forward-training instances are NHWC-only.
    if(!bn_problem.IsLayoutNHWC())
        return false;
    if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
        return false;

    // Accumulation is done in F32 for the reduced-precision types and in F64
    // for double problems.
    switch(bn_problem.GetXDesc().GetType())
    {
    case miopenHalf: return CheckCKApplicability<F16, F16, F32, F16, F16, F32>(bn_problem);
    case miopenFloat: return CheckCKApplicability<F32, F32, F32, F32, F32, F32>(bn_problem);
    case miopenDouble: return CheckCKApplicability<F64, F64, F64, F64, F64, F64>(bn_problem);
    case miopenBFloat16: return CheckCKApplicability<BF16, BF16, F32, BF16, BF16, F32>(bn_problem);
    case miopenInt32:
    case miopenInt8:
    case miopenInt8x4:
    case miopenBFloat8:
    case miopenFloat8:
    default: MIOPEN_THROW("BnCKFwdTraining operation does not support this data type");
    }
    return false;
#endif
}

// Builds the ConvSolution for the first CK kernel instance that accepts the
// problem; only called after IsApplicable has confirmed at least one exists.
template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType>
ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem)
{
    using DeviceOps = DeviceOpBNFwdTrainingPtrs<XDataType,
                                                YDataType,
                                                AccDataType,
                                                ScaleDataType,
                                                BiasDataType,
                                                MeanVarDataType>;

    const auto valid_kernel_ids = FillValidKernelsIDs<DeviceOps, CKArgsBNormFwdTraining>(bn_problem);
    assert(!valid_kernel_ids.empty());

    // Non-tunable solver: simply take the first applicable instance.
    return InitAnyInvokerFactory<DeviceOps, CKArgsBNormFwdTraining, miopen::batchnorm::InvokeParams>(
        bn_problem, valid_kernel_ids.front());
}

ConvSolution BnCKFwdTraining::GetSolution(
    [[maybe_unused]] const ExecutionContext& context,
    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    // Dispatch on the x-tensor element type, mirroring the type combinations
    // accepted by IsApplicable.
    switch(bn_problem.GetXDesc().GetType())
    {
    case miopenHalf: return MakeAnyInvokerFactory<F16, F16, F32, F16, F16, F32>(bn_problem);
    case miopenFloat: return MakeAnyInvokerFactory<F32, F32, F32, F32, F32, F32>(bn_problem);
    case miopenDouble: return MakeAnyInvokerFactory<F64, F64, F64, F64, F64, F64>(bn_problem);
    case miopenBFloat16: return MakeAnyInvokerFactory<BF16, BF16, F32, BF16, BF16, F32>(bn_problem);
    case miopenInt8:
    case miopenInt32:
    case miopenInt8x4:
    case miopenBFloat8:
    case miopenFloat8:
    default:
        MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdTraining operation not for this data type");
    }
#endif
    // Non-CK builds (and unreachable fall-through) return an empty solution.
    return {};
}

} // namespace batchnorm
} // namespace solver
} // namespace miopen
10 changes: 5 additions & 5 deletions test/fusionHost.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,17 @@ void batchNormPerActivHostInference(const tensor<T>& input,
});
}

template <class T, class U>
template <class T, class U, class V = U>
void batchNormSpatialHostFwdTrain(const tensor<T>& input,
tensor<T>& out,
const tensor<U>& scale,
const tensor<U>& bias,
double epsilon,
double expAvgFactor,
tensor<U>& saveMean,
tensor<U>& saveInvVar,
tensor<U>& runMean,
tensor<U>& runVar)
tensor<V>& saveMean,
tensor<V>& saveInvVar,
tensor<V>& runMean,
tensor<V>& runVar)
{

int height, width, n_batch, channels;
Expand Down
Loading

0 comments on commit d5eb31c

Please sign in to comment.