Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward data GEMM solvers and invokers #800

Merged
merged 11 commits into from
Apr 2, 2021
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ set( MIOpen_Source
conv_algo_name.cpp
conv/problem_description.cpp
solver/gemm.cpp
solver/gemm_bwd.cpp
dropout.cpp
dropout_api.cpp
readonlyramdb.cpp
Expand Down
101 changes: 18 additions & 83 deletions src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,21 +286,17 @@ ConvolutionDescriptor::BackwardGetValidWorkSpaceSizeGemm(const TensorDescriptor&
#if MIOPEN_USE_GEMM
if(!miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
{
const auto wei_spatial =
boost::adaptors::slice(wDesc.GetLengths(), 2, 2 + GetSpatialDimension());

if(GetSpatialDimension() == 2 &&
miopen::all_of(wei_spatial, [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvPads(), [](auto v) { return v == 0; }) &&
miopen::all_of(GetConvStrides(), [](auto v) { return v == 2; }))
return BackwardDataGetWorkSpaceSizeGEMMTranspose(dyDesc, dxDesc);

if(miopen::all_of(wei_spatial, [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvPads(), [](auto v) { return v == 0; }) &&
miopen::all_of(GetConvStrides(), [](auto v) { return v == 1; }))
return 0;
const auto ctx =
ConvolutionContext{dxDesc, wDesc, dyDesc, *this, conv::Direction::BackwardData};
decltype(auto) gemm_ws_sz_pairs = AllGemmWorkspaceSize(ctx);

return BackwardDataGetWorkSpaceSizeGEMM(wDesc, dyDesc);
if(!gemm_ws_sz_pairs.empty())
{
decltype(auto) gemm_ws_szs =
gemm_ws_sz_pairs |
boost::adaptors::transformed([](const auto& p) { return p.second; });
return *std::max_element(gemm_ws_szs.begin(), gemm_ws_szs.end());
}
}
return 0;
#else
Expand Down Expand Up @@ -484,29 +480,19 @@ ConvolutionDescriptor::BackwardDataGetWorkSpaceSize(Handle& handle,
std::max({direct_workspace, implicit_gemm_workspace, workspace_size_winograd});
if(!miopen::IsDisabled(MIOPEN_DEBUG_CONV_GEMM{}))
{
workspace_size_gemm = BackwardDataGetWorkSpaceSizeGEMM(wDesc, dyDesc);
if(workspace_size_gemm > MAX_MEM_ALLOC_SZ(handle))
workspace_size_gemm = 0;

const auto wei_spatial =
boost::adaptors::slice(wDesc.GetLengths(), 2, 2 + GetSpatialDimension());
decltype(auto) gemm_ws_sz_pairs = AllGemmWorkspaceSize(ctx);

if(miopen::all_of(wei_spatial, [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvPads(), [](auto v) { return v == 0; }) &&
miopen::all_of(GetConvStrides(), [](auto v) { return v == 2; }))
if(!gemm_ws_sz_pairs.empty())
{
size_t gemm_trans = BackwardDataGetWorkSpaceSizeGEMMTranspose(dyDesc, dxDesc);
if(gemm_trans > MAX_MEM_ALLOC_SZ(handle))
gemm_trans = 0;
tmp_max_workspace = std::max(gemm_trans, tmp_max_workspace);
MIOPEN_LOG_I2(tmp_max_workspace);
return tmp_max_workspace;
decltype(auto) gemm_ws_szs =
gemm_ws_sz_pairs |
boost::adaptors::transformed([](const auto& p) { return p.second; });
workspace_size_gemm = *std::max_element(gemm_ws_szs.begin(), gemm_ws_szs.end());
}

if(miopen::any_of(GetConvDilations(), [](auto v) { return v > 1; }))
{
tmp_max_workspace = std::max(workspace_size_gemm, tmp_max_workspace);
MIOPEN_LOG_I2(tmp_max_workspace);
return tmp_max_workspace;
return std::max({workspace_size_gemm, tmp_max_workspace});
}
}
#endif
Expand All @@ -522,57 +508,6 @@ ConvolutionDescriptor::BackwardDataGetWorkSpaceSize(Handle& handle,
return workspace_size;
}

std::size_t
ConvolutionDescriptor::BackwardDataGetWorkSpaceSizeGEMM(const TensorDescriptor& wDesc,
const TensorDescriptor& dyDesc) const
{
const std::size_t spatial_dim = GetSpatialDimension();

auto wei_spatial = boost::adaptors::slice(wDesc.GetLengths(), 2, 2 + spatial_dim);
auto out_spatial = boost::adaptors::slice(dyDesc.GetLengths(), 2, 2 + spatial_dim);

const std::size_t wei_c = wDesc.GetLengths()[1];

std::size_t gemm_size = wei_c * std::accumulate(wei_spatial.begin(),
wei_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
std::accumulate(out_spatial.begin(),
out_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
GetTypeSize(dyDesc.GetType()) * group_count;

// No workspace is needed for 1x1_stride=1 convolutions
if(miopen::all_of(wei_spatial, [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvStrides(), [](auto v) { return v == 1; }) &&
miopen::all_of(GetConvPads(), [](auto v) { return v == 0; }))
{
return 0;
}

return gemm_size;
}

std::size_t ConvolutionDescriptor::BackwardDataGetWorkSpaceSizeGEMMTranspose(
const TensorDescriptor& dyDesc, const TensorDescriptor& dxDesc) const
{
std::size_t in_n, in_c;
std::tie(in_n, in_c) = miopen::tie_pick<0, 1>{}(dxDesc.GetLengths());

auto out_spatial = boost::adaptors::slice(dyDesc.GetLengths(), 2, 2 + GetSpatialDimension());

const std::size_t dx_t_size = in_n * in_c * std::accumulate(out_spatial.begin(),
out_spatial.end(),
std::size_t(1),
std::multiplies<std::size_t>()) *
GetTypeSize(dxDesc.GetType());

const std::size_t dy_t_size = dyDesc.GetElementSize() * GetTypeSize(dyDesc.GetType());

return dx_t_size + dy_t_size;
}

std::size_t
ConvolutionDescriptor::BackwardWeightsGetWorkSpaceSizeGEMM(const TensorDescriptor& dyDesc,
const TensorDescriptor& dwDesc) const
Expand Down
2 changes: 1 addition & 1 deletion src/find_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ bool CheckInvokerSupport(const std::string& algo)
algo == "miopenConvolutionBwdDataAlgoImplicitGEMM" ||
algo == "miopenConvolutionBwdWeightsAlgoImplicitGEMM" ||
algo == "miopenConvolutionFwdAlgoFFT" || algo == "miopenConvolutionBwdDataAlgoFFT" ||
algo == "miopenConvolutionFwdAlgoGEMM";
algo == "miopenConvolutionFwdAlgoGEMM" || algo == "miopenConvolutionBwdDataAlgoGEMM";
}

template <class TDb>
Expand Down
6 changes: 0 additions & 6 deletions src/include/miopen/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,6 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor
std::size_t workSpaceSize,
solver::Id solver_id) const;

std::size_t BackwardDataGetWorkSpaceSizeGEMM(const TensorDescriptor& wDesc,
const TensorDescriptor& dyDesc) const;

std::size_t BackwardDataGetWorkSpaceSizeGEMMTranspose(const TensorDescriptor& dyDesc,
const TensorDescriptor& dxDesc) const;

std::size_t BackwardDataGetWorkSpaceSize(Handle& handle,
const TensorDescriptor& wDesc,
const TensorDescriptor& dyDesc,
Expand Down
74 changes: 74 additions & 0 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,80 @@ struct GemmFwdRest : GemmFwdBase
ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;
};

struct GemmBwdBase : SolverBase<ConvolutionContext>
{
bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;
bool IsDynamic() const { return true; }
float GetWti(const ConvolutionContext& ctx) const { return GetWti(ctx, ctx.conv_problem); }
float GetWti(const ExecutionContext& context, const conv::ProblemDescription& problem) const;
};

struct GemmBwd1x1_stride2 : GemmBwdBase
{
size_t GetWorkspaceSize(const ConvolutionContext& ctx) const
{
return GetWorkspaceSize(ctx, ctx.conv_problem);
}

bool IsApplicable(const ConvolutionContext& ctx) const
{
return IsApplicable(ctx, ctx.conv_problem);
}

ConvSolution GetSolution(const ConvolutionContext& ctx) const
{
return GetSolution(ctx, ctx.conv_problem);
}

size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const;
bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;
ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;
};

struct GemmBwd1x1_stride1 : GemmBwdBase
{
size_t GetWorkspaceSize(const ConvolutionContext& ctx) const
{
return GetWorkspaceSize(ctx, ctx.conv_problem);
}

bool IsApplicable(const ConvolutionContext& ctx) const
{
return IsApplicable(ctx, ctx.conv_problem);
}

ConvSolution GetSolution(const ConvolutionContext& ctx) const
{
return GetSolution(ctx, ctx.conv_problem);
}

size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const;
bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;
ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;
};

struct GemmBwdRest : GemmBwdBase
{
size_t GetWorkspaceSize(const ConvolutionContext& ctx) const
{
return GetWorkspaceSize(ctx, ctx.conv_problem);
}

bool IsApplicable(const ConvolutionContext& ctx) const
{
return IsApplicable(ctx, ctx.conv_problem);
}

ConvSolution GetSolution(const ConvolutionContext& ctx) const
{
return GetSolution(ctx, ctx.conv_problem);
}

size_t GetWorkspaceSize(const ExecutionContext&, const conv::ProblemDescription&) const;
bool IsApplicable(const ExecutionContext&, const conv::ProblemDescription&) const;
ConvSolution GetSolution(const ExecutionContext&, const conv::ProblemDescription&) const;
};

struct AnySolver;

} // namespace solver
Expand Down
Loading