Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CK API] 2d backward data convolution composable kernel integration #1874

Merged
merged 15 commits into from
Jan 3, 2023
Merged
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ set( MIOpen_Source
solver/conv_direct_naive_conv_bwd.cpp
solver/conv_direct_naive_conv_fwd.cpp
solver/conv_direct_naive_conv_wrw.cpp
solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp
solver/conv_hip_implicit_gemm_bwd_v4r1.cpp
Expand Down
112 changes: 112 additions & 0 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5740,6 +5740,118 @@ struct ConvHipImplicitGemmFwdXdlops final
const PerformanceConfigHipImplicitGemmFwdXdlops& config) const;
};

struct PerformanceConfigHipImplicitGemmBwdXdlops
: PerfConfigBase<PerformanceConfigHipImplicitGemmBwdXdlops>
{
int index;
std::string kernel_id;
int total_size;
PerformanceConfigHipImplicitGemmBwdXdlops(int idx, std::string kernl_id)
: index(idx), kernel_id(kernl_id), total_size(-1)
{
}
PerformanceConfigHipImplicitGemmBwdXdlops() : PerformanceConfigHipImplicitGemmBwdXdlops(0, "")
{
}
PerformanceConfigHipImplicitGemmBwdXdlops(bool)
: PerformanceConfigHipImplicitGemmBwdXdlops(0, "")
{
}
void HeuristicInit(const ProblemDescription&);
bool SetNextValue(const ConvolutionContext& ctx) { return SetNextValue(ctx.problem); }
bool SetNextValue(const ProblemDescription&);
bool IsValidValue() const;
bool IsValid(const ConvolutionContext& ctx) const { return IsValid(ctx.problem); }
bool IsValid(const ProblemDescription&) const;
template <typename Self, typename F>
static void Visit(Self&& s, F f)
{
f(s.kernel_id, "kernel_id");
}
bool operator==(const PerformanceConfigHipImplicitGemmBwdXdlops& other) const;

private:
template <typename DataType>
void Init(const ProblemDescription&);
template <typename DataType>
bool CheckIsSupportCKArgs(const ProblemDescription&) const;
};

struct ConvHipImplicitGemmBwdXdlops final
: ConvTunableSolver<PerformanceConfigHipImplicitGemmBwdXdlops>
{
// To suppress -Woverloaded-virtual
using ConvTunableSolver::GetDefaultPerformanceConfig;
using ConvTunableSolver::GetSolution;
using ConvTunableSolver::IsApplicable;
using ConvTunableSolver::IsValidPerformanceConfig;
using ConvTunableSolver::Search;

const std::string& SolverDbId() const override
{
return GetSolverDbId<ConvHipImplicitGemmBwdXdlops>();
}

PerformanceConfigHipImplicitGemmBwdXdlops
GetDefaultPerformanceConfig(const ConvolutionContext& ctx) const override
{
return GetDefaultPerformanceConfig(ctx.problem);
}
bool
IsValidPerformanceConfig(const ConvolutionContext& ctx,
const PerformanceConfigHipImplicitGemmBwdXdlops& config) const override
{
return IsValidPerformanceConfig(ctx.problem, config);
}
PerformanceConfigHipImplicitGemmBwdXdlops
Search(const ConvolutionContext& ctx, const AnyInvokeParams& invoke_ctx) const override
{
return Search(ctx, ctx.problem, invoke_ctx);
}
bool IsApplicable(const ConvolutionContext& ctx) const override
{
return IsApplicable(ctx, ctx.problem);
}
bool IsDynamic() const override { return true; }
ConvSolution GetSolution(const ConvolutionContext& ctx,
const PerformanceConfigHipImplicitGemmBwdXdlops& config) const override
{
return GetSolution(ctx, ctx.problem, config);
}
// Magic Number Alert:
// Naive convolutions have GetWti() that return very small value (0.01f).
// This allows MIOpen to use Naive Solvers if no other applicable Solvers
// have known WTIs. Right now this means that in case of find-db miss,
// the library will try to use Winograd or GEMM (whatever is faster according
// to their GetWti's), but if both are not applicable, the library will
// use Naive Solver
// Since we would like to us CK before naive, and use it instead (because
// we do expect that CK is faster than Naive), therefore we use a
// value bigger than 0.01f, e.g. 0.02f.
float GetWti(const ConvolutionContext&) const override { return 0.02f; };
iq136boy marked this conversation as resolved.
Show resolved Hide resolved

private:
bool IsApplicable(const ConvolutionContext&, const ProblemDescription&) const;
PerformanceConfigHipImplicitGemmBwdXdlops
GetDefaultPerformanceConfig(const ProblemDescription&) const;
bool IsValidPerformanceConfig(const ProblemDescription&,
const PerformanceConfigHipImplicitGemmBwdXdlops&) const;
PerformanceConfigHipImplicitGemmBwdXdlops Search(const ConvolutionContext&,
const ProblemDescription&,
const AnyInvokeParams& invoke_ctx) const;
ConvSolution GetSolution(const ConvolutionContext&,
const ProblemDescription&,
const PerformanceConfigHipImplicitGemmBwdXdlops&) const;

template <typename DataType>
bool CheckCKApplicability(const ProblemDescription&) const;
template <typename DataType>
void RunCKSolution(const Handle& handle,
const AnyInvokeParams& primitive_parameters,
const ProblemDescription& problem,
const PerformanceConfigHipImplicitGemmBwdXdlops& config) const;
};

struct AnySolver;

} // namespace solver
Expand Down
1 change: 1 addition & 0 deletions src/mlo_dir_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ static auto GetImplicitGemmSolvers()
miopen::solver::ConvCkIgemmFwdV6r1DlopsNchw,
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
miopen::solver::ConvHipImplicitGemmFwdXdlops,
miopen::solver::ConvHipImplicitGemmBwdXdlops,
#endif // MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
miopen::solver::ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC>{};
}
Expand Down
2 changes: 2 additions & 0 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
miopenConvolutionAlgoImplicitGEMM);
RegisterWithSolver(
registry, ++id, ConvHipImplicitGemmFwdXdlops{}, miopenConvolutionAlgoImplicitGEMM);
RegisterWithSolver(
registry, ++id, ConvHipImplicitGemmBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM);

// IMPORTANT: New solvers should be added to the end of the function!
}
Expand Down
Loading