bg/lwpmiopen 193 : Integrate CK's batch norm backward training into non-tunable MIOpen solver #2385

Merged: 44 commits, Oct 5, 2023
e39aa04
bg/LWPMIOPEN-194 : ck batch norm test pass
bghimireamd Jul 6, 2023
766d060
bg/LWPMIOPEN-194 : removed stale code
bghimireamd Jul 24, 2023
75d8d95
bg/LWPMIOPEN-194 : minor mixes
bghimireamd Jul 24, 2023
bb1af7d
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Jul 25, 2023
c995f79
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Aug 14, 2023
171fa3a
bg/LWPMIOPEN-194 : fix rotate lens to pass NHWC layout is CK
bghimireamd Aug 29, 2023
619fe9e
fix clang format issue
junliume Aug 30, 2023
279664b
bg/LWPMIOPEN-194: inhert from non-tunable solver
bghimireamd Sep 3, 2023
4002995
bg/LWPMIOPEN-194 : add all data types supported by CK
bghimireamd Sep 3, 2023
4278231
bg/LWPMIOPEN-194 : fix merge conflicts
bghimireamd Sep 4, 2023
972c6c8
Merge branch 'bg/LWPMIOPEN-194' of github.com:ROCmSoftwarePlatform/MI…
bghimireamd Sep 4, 2023
bd7f9f6
bg/LWPMIOPEN-194 : minor fixes
bghimireamd Sep 4, 2023
291d752
bg/LWPMIOPEN-194: fix clang format
bghimireamd Sep 5, 2023
f09c16e
bg/LWPMIOPEN-193_bn_back : first commit test working
bghimireamd Sep 7, 2023
e7667c3
bg/LWPMIOPEN-194 : add test for all types
bghimireamd Sep 8, 2023
6afd0f5
bg/LWPMIOPEN-194: fixed merge conflict
bghimireamd Sep 8, 2023
ce4beef
bg/LWPMIOPEN-193_bn_back : add test for all types
bghimireamd Sep 8, 2023
714b636
bg/LWPMIOPEN-193_bn_back : fix bn backward host template
bghimireamd Sep 8, 2023
95efae5
bg/LWPMIOPEN-193_bn_back : fix merge conflict
bghimireamd Sep 11, 2023
e0792ee
bg/LWPMIOPEN-193_bn_back : clang format
bghimireamd Sep 11, 2023
ee86ed6
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Sep 11, 2023
334ed9a
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Sep 11, 2023
be51dbd
Merge branch 'bg/LWPMIOPEN-194' of github.com:ROCmSoftwarePlatform/MI…
bghimireamd Sep 11, 2023
335c212
bg/LWPMIOPEN-193_bn_back : add solver to registry
bghimireamd Sep 11, 2023
8a4177b
bg/LWPMIOPEN-193_bn_back : clean ups
bghimireamd Sep 11, 2023
8172f51
bg/LWPMIOPEN-194 : add static to CheckCKApplicability function
bghimireamd Sep 11, 2023
3ce76ff
bg/LWPMIOPEN-194: fix analyze error
bghimireamd Sep 11, 2023
cdc4cb7
bg/LWPMIOPEN-194: fix compile error
bghimireamd Sep 11, 2023
7da7650
Merge branch 'develop' into bg/LWPMIOPEN-193_bn_back
junliume Sep 12, 2023
8a40210
bg/LWPMIOPEN-194: sync with develop
bghimireamd Sep 12, 2023
8e3a3c6
bg/LWPMIOPEN-193_bn_back: merge with bg/LWPMIOPEN-194
bghimireamd Sep 17, 2023
9da8ef2
bg/LWPMIOPEN-193_bn_back : fix merge conflict
bghimireamd Sep 17, 2023
83bd77f
bg/LWPMIOPEN-193_bn_back: fix review comments
bghimireamd Sep 18, 2023
d6b19ec
bg/LWPMIOPEN-193_bn_back : fix tidy error
bghimireamd Sep 18, 2023
87fb611
bg/LWPMIOPEN-193_bn_back : make solver epsilon and test driver epsilo…
bghimireamd Sep 21, 2023
b77d93b
bg/LWPMIOPEN-193_bn_back : add CK's cpu ref for batch norm backward test
bghimireamd Sep 25, 2023
8929dbb
bg/LWPMIOPEN-193_bn_back: moved bn_spatial_nhwc_test.cpp to gtest
bghimireamd Sep 27, 2023
f348749
bg/LWPMIOPEN-193_bn_back : add new rand
bghimireamd Sep 27, 2023
3615ea5
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Sep 27, 2023
6874015
bg/LWPMIOPEN-193_bn_back: handle all type in switch case of solver
bghimireamd Sep 27, 2023
11f66e9
bg/LWPMIOPEN-193_bn_back: clang format
bghimireamd Sep 27, 2023
c4216f6
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Oct 2, 2023
d492864
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/MIOpen into…
bghimireamd Oct 4, 2023
d5eb31c
bg/LWPMIOPEN-192: Integrate CK's batch norm forward training into non…
bghimireamd Oct 5, 2023
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
@@ -152,6 +152,7 @@ set( MIOpen_Source
solver/activ/bwd_1.cpp
solver/activ/fwd_0.cpp
solver/activ/fwd_1.cpp
+ solver/batchnorm/backward_ck.cpp
solver/batchnorm/backward_per_activation.cpp
solver/batchnorm/backward_per_activation_fused.cpp
solver/batchnorm/backward_spatial_multiple.cpp
@@ -163,6 +164,7 @@ set( MIOpen_Source
solver/batchnorm/forward_per_activation_fused.cpp
solver/batchnorm/forward_spatial_multiple.cpp
solver/batchnorm/forward_spatial_single.cpp
+ solver/batchnorm/forward_training_ck.cpp
solver/conv_asm_1x1u.cpp
solver/conv_asm_1x1u_bias_activ_fused.cpp
solver/conv_asm_1x1u_stride2.cpp
7 changes: 0 additions & 7 deletions src/batch_norm_api.cpp
@@ -243,13 +243,6 @@ miopenBatchNormalizationBackward(miopenHandle_t handle,
const void* savedMean,
const void* savedInvVariance)
{
- // bfloat16 not supported for batchnorm operation
- if(miopen::deref(xDesc).GetType() == miopenBFloat16 ||
- miopen::deref(dyDesc).GetType() == miopenBFloat16 ||
- miopen::deref(dxDesc).GetType() == miopenBFloat16)
- {
- return miopenStatusNotImplemented;
- }

MIOPEN_LOG_FUNCTION(handle,
bn_mode,
20 changes: 20 additions & 0 deletions src/include/miopen/batchnorm/solvers.hpp
@@ -142,6 +142,26 @@ struct BnCKFwdInference final : BatchnormSolver
const miopen::batchnorm::ProblemDescription& problem) const override;
};

+ struct BnCKBwdBackward final : BatchnormSolver
+ {
+ const std::string& SolverDbId() const override { return GetSolverDbId<BnCKBwdBackward>(); }
+
+ bool IsApplicable(const ExecutionContext& context,
+ const miopen::batchnorm::ProblemDescription& problem) const override;
+ ConvSolution GetSolution(const ExecutionContext& context,
+ const miopen::batchnorm::ProblemDescription& problem) const override;
+ };
+
+ struct BnCKFwdTraining final : BatchnormSolver
+ {
+ const std::string& SolverDbId() const override { return GetSolverDbId<BnCKFwdTraining>(); }
+
+ bool IsApplicable(const ExecutionContext& context,
+ const miopen::batchnorm::ProblemDescription& problem) const override;
+ ConvSolution GetSolution(const ExecutionContext& context,
+ const miopen::batchnorm::ProblemDescription& problem) const override;
+ };
+
} // namespace batchnorm

} // namespace solver
65 changes: 55 additions & 10 deletions src/include/miopen/solver/implicitgemm_ck_util.hpp
@@ -41,8 +41,10 @@ typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs,
});
}

- template <typename DeviceOpType, typename CKArgsType>
- std::vector<std::string> FillValidKernelsIDs(const ProblemDescription& problem)
+ template <typename DeviceOpType,
+ typename CKArgsType,
+ typename ProblemDescriptionType = ProblemDescription>
+ std::vector<std::string> FillValidKernelsIDs(const ProblemDescriptionType& problem)
{
const auto args = CKArgsType{problem};
const auto conv_ptrs = DeviceOpType::GetInstances();
@@ -59,29 +61,36 @@ std::vector<std::string> FillValidKernelsIDs(const ProblemDescription& problem)
return valid_kernels;
}

- template <typename DeviceOpType, typename CKArgsType>
- bool IsCKArgsSupported(const ProblemDescription& problem, const std::string& kernel_id)
+ template <typename DeviceOpType,
+ typename CKArgsType,
+ typename ProblemDescriptionType = ProblemDescription>
+ bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id)
{
auto conv_ptrs = DeviceOpType::GetInstances();
auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);

return (ptr_iter != conv_ptrs.end()) && CKArgsType{problem}.IsSupportedBy(*ptr_iter);
}

- template <typename DeviceOpType, typename CKArgsType>
- bool IsCKApplicable(const ProblemDescription& problem)
+ template <typename DeviceOpType,
+ typename CKArgsType,
+ typename ProblemDescriptionType = ProblemDescription>
+ bool IsCKApplicable(const ProblemDescriptionType& problem)
{
const auto args = CKArgsType{problem};
- if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; }))
- return false;
+ // if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; }))
+ // return false;

const auto ptrs = DeviceOpType::GetInstances();
return std::any_of(
ptrs.begin(), ptrs.end(), [&args](auto& ptr) { return args.IsSupportedBy(ptr); });
}

- template <typename DeviceOpType, typename CKArgsType, typename CastType>
- ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::string& kernel_id)
+ template <typename DeviceOpType,
+ typename CKArgsType,
+ typename CastType,
+ typename ProblemDescriptionType = ProblemDescription>
+ ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id)
{
auto conv_ptrs = DeviceOpType::GetInstances();
auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);
@@ -112,5 +121,41 @@ ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::st
return result;
}

+ template <typename DeviceOpType,
+ typename CKArgsType,
+ typename CastType,
+ typename ProblemDescriptionType = ProblemDescription>
+ ConvSolution InitAnyInvokerFactory(const ProblemDescriptionType& problem,
+ const std::string& kernel_id)
+ {
+ auto conv_ptrs = DeviceOpType::GetInstances();
+ auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id);
+
+ if(ptr_iter == conv_ptrs.end())
+ return {miopenStatusInvalidValue};
+
+ ConvSolution result;
+ result.invoker_factory =
+ [ck_args = CKArgsType{problem},
+ sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}](const std::vector<Kernel>&) mutable {
+ return [ck_args = std::move(ck_args), sh_conv_ptr = std::move(sh_conv_ptr)](
+ const Handle& handle, const AnyInvokeParams& primitive_parameters) {
+ const auto& data_ctx = primitive_parameters.CastTo<CastType>();
+ auto argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr, data_ctx);
+ auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer();
+
+ const auto enable_profiling = handle.IsProfilingEnabled();
+ float elapsed_time =
+ invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
+ if(enable_profiling)
+ {
+ handle.ResetKernelTime();
+ handle.AccumKernelTime(elapsed_time);
+ }
+ };
+ };
+ return result;
+ }
+
} // namespace solver
} // namespace miopen
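
Note on the change above: the helpers in this header gain a third template parameter, `ProblemDescriptionType`, that defaults to the convolution `ProblemDescription`. The following is a minimal sketch of that pattern, not MIOpen code. Every name in it (`ConvProblem`, `BnProblem`, `ConvArgs`, `BnArgs`, `DummyOp`, `IsApplicableSketch`) is a hypothetical stand-in chosen for illustration.

```cpp
// Hypothetical stand-ins for the two problem types; not MIOpen's real classes.
struct ConvProblem
{
    int groups = 1;
};
struct BnProblem
{
    bool nhwc = true;
};

// Hypothetical argument wrappers, each constructible from its own problem type,
// in the same way CKArgsType{problem} is built in the diff.
struct ConvArgs
{
    explicit ConvArgs(const ConvProblem& p) : ok(p.groups >= 1) {}
    bool ok;
};
struct BnArgs
{
    explicit BnArgs(const BnProblem& p) : ok(p.nhwc) {}
    bool ok;
};

// Placeholder for a device operation type; unused in this sketch.
struct DummyOp
{
};

// The third parameter defaults to the convolution problem type. Callers that
// name only the first two template arguments keep compiling; the third is
// either deduced from the function argument or falls back to the default.
template <typename DeviceOpType, typename ArgsType, typename ProblemType = ConvProblem>
bool IsApplicableSketch(const ProblemType& problem)
{
    return ArgsType{problem}.ok;
}
```

A convolution-style call such as `IsApplicableSketch<DummyOp, ConvArgs>(ConvProblem{})` is written exactly as before the extra parameter existed, while a batch norm caller can pass `BnProblem` either by deduction or by naming it as the third template argument.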
9 changes: 5 additions & 4 deletions src/ocl/batchnormocl.cpp
@@ -131,7 +131,8 @@ void BatchNormForwardTraining(Handle& handle,
return tmp;
}();

- const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdTrainingSpatialSingle,
+ const auto solvers = solver::SolverContainer<solver::batchnorm::BnCKFwdTraining,
+ solver::batchnorm::BnFwdTrainingSpatialSingle,
solver::batchnorm::BnFwdTrainingSpatialMultiple,
solver::batchnorm::BnFwdTrainingPerActivation>{};

@@ -300,7 +301,7 @@ void BatchNormBackward(Handle& handle,
{
MIOPEN_THROW(miopenStatusBadParm);
}
- if(dxDesc.GetType() != dyDesc.GetType() || dyDesc.GetType() != xDesc.GetType())
+ if(dxDesc.GetType() != dyDesc.GetType())
{
MIOPEN_THROW(miopenStatusBadParm);
}
@@ -338,15 +339,15 @@ void BatchNormBackward(Handle& handle,
tmp.dx = dx;
tmp.bnScale = bnScale;
tmp.resultBnScaleDiff = resultBnScaleDiff;
- tmp.resultBnScaleDiff = resultBnScaleDiff;
tmp.resultBnBiasDiff = resultBnBiasDiff;
tmp.epsilon = epsilon;
tmp.savedMean = savedMean;
tmp.savedInvVariance = savedInvVariance;
return tmp;
}();

- const auto solvers = solver::SolverContainer<solver::batchnorm::BnBwdTrainingSpatialSingle,
+ const auto solvers = solver::SolverContainer<solver::batchnorm::BnCKBwdBackward,
+ solver::batchnorm::BnBwdTrainingSpatialSingle,
solver::batchnorm::BnBwdTrainingSpatialMultiple,
solver::batchnorm::BnBwdTrainingPerActivation>{};
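
Note on the ordering above: the CK solver is placed first in the container's type list. The sketch below shows why that position matters, under the assumption that the container tries solvers in the listed order and runs the first one whose applicability check passes. `Problem`, `SolverSketch`, `PickFirstApplicable`, and `MakeBackwardSolvers` are hypothetical names, and the "CK applies only to NHWC" predicate is an illustrative assumption, not the actual `IsApplicable` logic of `BnCKBwdBackward`.

```cpp
#include <functional>
#include <string>
#include <vector>

// Hypothetical, simplified problem: only the tensor layout matters here.
struct Problem
{
    bool is_nhwc;
};

// Hypothetical solver record: a name plus an applicability predicate.
struct SolverSketch
{
    std::string name;
    std::function<bool(const Problem&)> is_applicable;
};

// Return the name of the first applicable solver, or an empty string if none applies.
std::string PickFirstApplicable(const std::vector<SolverSketch>& solvers, const Problem& p)
{
    for(const auto& s : solvers)
    {
        if(s.is_applicable(p))
            return s.name;
    }
    return {};
}

// Same relative order as the backward container in the diff: CK first,
// then a pre-existing solver. The predicates are assumptions for illustration.
std::vector<SolverSketch> MakeBackwardSolvers()
{
    return {
        {"BnCKBwdBackward", [](const Problem& p) { return p.is_nhwc; }},
        {"BnBwdTrainingSpatialSingle", [](const Problem&) { return true; }},
    };
}
```

Under these assumptions an NHWC problem is routed to the CK solver, and any other problem falls through to the pre-existing solver, so adding the CK entry at the front does not remove coverage for problems it rejects.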

2 changes: 2 additions & 0 deletions src/solver.cpp
@@ -569,6 +569,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
RegisterWithSolver(
registry, ++id, ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM);
Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId());
+ Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId());
+ Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId());

// IMPORTANT: New solvers should be added to the end of the function!
}
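
Note on the "add to the end" rule above: because each registration uses `++id`, a solver's numeric id is its position in the list. The sketch below mimics that counter to show that appending keeps earlier ids stable while inserting earlier shifts them. `AssignIds` is a hypothetical helper, not part of MIOpen, and the reason stable ids matter (presumably because ids are stored or exchanged outside this function) is an assumption.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Assign sequential ids in list order, mimicking the `++id` pattern in the registrar.
std::map<std::string, std::uint64_t> AssignIds(const std::vector<std::string>& names)
{
    std::map<std::string, std::uint64_t> ids;
    std::uint64_t id = 0;
    for(const auto& n : names)
        ids[n] = ++id;
    return ids;
}
```

With this helper, `AssignIds({"BnCKFwdInference", "BnCKBwdBackward", "BnCKFwdTraining"})` gives `BnCKFwdInference` the same id it had in the one-element list, whereas placing a new name before it changes that id.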