Skip to content

Commit

Permalink
[MLIR] Imlement mlir binary backend - step 3: wrw path (#892)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryyin authored Apr 28, 2021
1 parent 95b1c76 commit b004695
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ set( MIOpen_Source
solver/conv_hip_implicit_gemm_mlir_cpp_wrw.cpp
solver/conv_hip_implicit_gemm_mlir_bin_fwd.cpp
solver/conv_hip_implicit_gemm_mlir_bin_bwd.cpp
solver/conv_hip_implicit_gemm_mlir_bin_wrw.cpp
solver/conv_hip_implicit_gemm_wrw_v4r4.cpp
solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp
solver/conv_hip_implicit_gemm_xdlops_common.cpp
Expand Down
47 changes: 38 additions & 9 deletions src/conv/invokers/mlir_impl_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <miopen/memref.hpp>

#include <miopen/conv/data_invoke_params.hpp>
#include <miopen/conv/wrw_invoke_params.hpp>
#include <miopen/algorithm.hpp>
#include <miopen/handle.hpp>
#include <miopen/tensor_ops.hpp>
Expand Down Expand Up @@ -111,22 +112,22 @@ MlirConvArgs makeMlirConvArgs(const std::vector<size_t>& in_dims,
return {filter, input, output};
}

void setMlirConvArgsPtr(const ConvDataTensors& tensors, MlirConvArgs& args)
void setMlirConvArgsPtr(ConstData_t in, ConstData_t out, ConstData_t w, MlirConvArgs& args)
{
void* filter = nullptr;
void* input = nullptr;
void* output = nullptr;
#if MIOPEN_BACKEND_OPENCL
clGetMemObjectInfo(tensors.w, CL_MEM_HOST_PTR, sizeof(filter), &filter, nullptr);
clGetMemObjectInfo(tensors.in, CL_MEM_HOST_PTR, sizeof(input), &input, nullptr);
clGetMemObjectInfo(tensors.out, CL_MEM_HOST_PTR, sizeof(output), &output, nullptr);
clGetMemObjectInfo(w, CL_MEM_HOST_PTR, sizeof(filter), &filter, nullptr);
clGetMemObjectInfo(in, CL_MEM_HOST_PTR, sizeof(input), &input, nullptr);
clGetMemObjectInfo(out, CL_MEM_HOST_PTR, sizeof(output), &output, nullptr);
#elif MIOPEN_BACKEND_HIP
// NOLINTNEXTLINE (cppcoreguidelines-pro-type-const-cast)
filter = const_cast<void*>(tensors.w);
filter = const_cast<void*>(w);
// NOLINTNEXTLINE (cppcoreguidelines-pro-type-const-cast)
input = const_cast<void*>(tensors.in);
input = const_cast<void*>(in);
// NOLINTNEXTLINE (cppcoreguidelines-pro-type-const-cast)
output = const_cast<void*>(tensors.out);
output = const_cast<void*>(out);
#endif

if((filter == nullptr) || (input == nullptr) || (output == nullptr))
Expand Down Expand Up @@ -165,7 +166,7 @@ InvokerFactory MakeMlirFwdInvokerFactory(const ConvolutionContext& ctx)
primitive_parameters.CastTo<conv::DataInvokeParams>();
const auto& tensors = forward_invoke_params.tensors;

setMlirConvArgsPtr(tensors, args);
setMlirConvArgsPtr(tensors.in, tensors.out, tensors.w, args);
handle.Run(kernels[0])(args);
};
};
Expand Down Expand Up @@ -214,7 +215,7 @@ InvokerFactory MakeMlirBwdInvokerFactory(const ConvolutionContext& ctx)
if(handle.IsProfilingEnabled())
elapsed += handle.GetKernelTime();

setMlirConvArgsPtr(tensors, args);
setMlirConvArgsPtr(tensors.in, tensors.out, tensors.w, args);
handle.Run(kernels[0])(args);
if(handle.IsProfilingEnabled())
{
Expand All @@ -226,5 +227,33 @@ InvokerFactory MakeMlirBwdInvokerFactory(const ConvolutionContext& ctx)
};
}

InvokerFactory MakeMlirWrWInvokerFactory(const ConvolutionContext& ctx)
{
assert((ctx.direction.IsBackwardWrW()));

std::vector<size_t> in_dims, in_strides;
std::vector<size_t> weights_dims, weights_strides;
std::vector<size_t> out_dims, out_strides;
permuteDimStridesAllDir(ctx.conv_problem,
in_dims,
in_strides,
weights_dims,
weights_strides,
out_dims,
out_strides);
MlirConvArgs args =
makeMlirConvArgs(in_dims, in_strides, weights_dims, weights_strides, out_dims, out_strides);

return [=](const std::vector<Kernel>& kernels) mutable {
return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable {
const auto& wrw_invoke_params = primitive_parameters.CastTo<conv::WrWInvokeParams>();
const auto& tensors = wrw_invoke_params.tensors;

setMlirConvArgsPtr(tensors.x, tensors.dy, tensors.dw, args);
handle.Run(kernels[0])(args);
};
};
}

} // namespace conv
} // namespace miopen
1 change: 1 addition & 0 deletions src/include/miopen/conv/invokers/mlir_impl_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace conv {

InvokerFactory MakeMlirFwdInvokerFactory(const ConvolutionContext& ctx);
InvokerFactory MakeMlirBwdInvokerFactory(const ConvolutionContext& ctx);
InvokerFactory MakeMlirWrWInvokerFactory(const ConvolutionContext& ctx);

} // namespace conv
} // namespace miopen
6 changes: 6 additions & 0 deletions src/include/miopen/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,12 @@ struct ConvHipImplicitGemmMlirCppWrW : SolverBase<ConvolutionContext>
ConvSolution GetSolution(const ConvolutionContext& ctx) const;
};

struct ConvHipImplicitGemmMlirBinWrW : SolverBase<ConvolutionContext>
{
bool IsApplicable(const ConvolutionContext& ctx) const;
ConvSolution GetSolution(const ConvolutionContext& ctx) const;
};

struct PerformanceImplicitGemmXdlops : Serializable<PerformanceImplicitGemmXdlops>
{
int BPerBlock; // 2^n[8..16]
Expand Down
1 change: 1 addition & 0 deletions src/mlo_dir_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ static auto GetImplicitGemmWrWSolvers()
miopen::solver::ConvHipImplicitGemmV4R4WrW,
miopen::solver::ConvAsmImplicitGemmV4R1DynamicWrw,
miopen::solver::ConvHipImplicitGemmMlirCppWrW,
miopen::solver::ConvHipImplicitGemmMlirBinWrW,
miopen::solver::ConvAsmImplicitGemmGTCDynamicWrwXdlops>{};
}

Expand Down
2 changes: 2 additions & 0 deletions src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
registry, ++id, ConvHipImplicitGemmMlirBinFwd{}, miopenConvolutionAlgoImplicitGEMM);
RegisterWithSolver(
registry, ++id, ConvHipImplicitGemmMlirBinBwd{}, miopenConvolutionAlgoImplicitGEMM);
RegisterWithSolver(
registry, ++id, ConvHipImplicitGemmMlirBinWrW{}, miopenConvolutionAlgoImplicitGEMM);
// IMPORTANT: New solvers should be added to the end of the function!
}

Expand Down
160 changes: 160 additions & 0 deletions src/solver/conv_hip_implicit_gemm_mlir_bin_wrw.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include <miopen/mlir_build.hpp>
#include <miopen/conv/invokers/mlir_impl_gemm.hpp>
#include <miopen/conv/wrw_invoke_params.hpp>
#include <miopen/config.h>
#include <miopen/env.hpp>
#include <miopen/solver.hpp>
#include <miopen/solver/implicitgemm_util.hpp>

MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_HIP_IMPLICIT_GEMM_MLIR_BIN_WRW)

namespace miopen {
namespace solver {

namespace {
#if MIOPEN_USE_MLIR
std::tuple<int, int, int> CalculateGemmSize(const ConvolutionContext& ctx)
{
const size_t n = ConvolutionContextInterpreter::GetBatchN(ctx);
const size_t c = ConvolutionContextInterpreter::GetInputChannelC(ctx);
const size_t k = ConvolutionContextInterpreter::GetOutputChannelK(ctx);
const size_t ho = ConvolutionContextInterpreter::GetOutputHeightHo(ctx);
const size_t wo = ConvolutionContextInterpreter::GetOutputWidthWo(ctx);
const size_t y = ConvolutionContextInterpreter::GetFilterHeightY(ctx);
const size_t x = ConvolutionContextInterpreter::GetFilterWidthX(ctx);

const auto gemm_m = k;
const auto gemm_n =
c * y * x * (ctx.Is3d() ? ConvolutionContextInterpreter::GetFilterDepthZ(ctx) : 1);
const auto gemm_k =
n * ho * wo * (ctx.Is3d() ? ConvolutionContextInterpreter::GetOutputDepthDo(ctx) : 1);

return std::make_tuple(gemm_m, gemm_n, gemm_k);
}
#endif
} // Anonymous namespace

bool ConvHipImplicitGemmMlirBinWrW::IsApplicable(const ConvolutionContext& ctx) const
{
#if MIOPEN_USE_MLIR
if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_HIP_IMPLICIT_GEMM_MLIR_BIN_WRW{}))
return false;
// Future: MLIR will support non-default layouts.
if(!ctx.IsLayoutDefault())
return false;
// Future: MLIR will support 3d convolution
if(!ctx.Is2d())
return false;
if(!IsComposableKernelSupportedHardware(ctx))
return false;
if(!ctx.direction.IsBackwardWrW())
return false;
if(!ctx.IsFp32())
return false;
if(ctx.group_counts != 1)
return false;

int gemm_m = 0;
int gemm_n = 0;
int gemm_k = 0;

std::tie(gemm_m, gemm_n, gemm_k) = CalculateGemmSize(ctx);
return gemm_m % 32 == 0 && gemm_n % 32 == 0 && gemm_k % 4 == 0;
#else
std::ignore = ctx;
return false;
#endif
}

ConvSolution ConvHipImplicitGemmMlirBinWrW::GetSolution(const ConvolutionContext& ctx) const
{
#if MIOPEN_USE_MLIR
ConvSolution result;
KernelInfo construction_parameters;

std::string version = "_v4r4";
std::string direction = "_wrw";
std::string operation = "conv2d_bwd_weight";

construction_parameters.kernel_name = "mlir_gen_igemm_conv2d" + version + direction;
construction_parameters.kernel_file = construction_parameters.kernel_name + ".mlir";

// Arguments for mlir-miopen-driver.
// clang-format off
using CI = ConvolutionContextInterpreter;
construction_parameters.comp_options =
std::string(" --operation ") + operation +
std::string(" --num_cu ") + std::to_string(ctx.GetStream().GetMaxComputeUnits()) +
std::string(" --arch ") + ctx.GetStream().GetDeviceName() +
std::string(" --fil_layout ") + CI::GetFilterLayout(ctx) +
std::string(" --fil_type ") + "fp32" +
std::string(" --in_layout ") + CI::GetInputLayout(ctx) +
std::string(" --in_type ") + "fp32" +
std::string(" --out_layout ") + CI::GetOutputLayout(ctx) +
std::string(" --out_type ") + "fp32" +
std::string(" --batchsize ") + std::to_string(CI::GetBatchN(ctx)) +
std::string(" --in_channels ") + std::to_string(CI::GetInputChannelC(ctx)) +
std::string(" --out_channels ") + std::to_string(CI::GetOutputChannelK(ctx)) +
std::string(" --in_h ") + std::to_string(CI::GetInputHeightHi(ctx)) +
std::string(" --in_w ") + std::to_string(CI::GetInputWidthWi(ctx)) +
std::string(" --out_h ") + std::to_string(CI::GetOutputHeightHo(ctx)) +
std::string(" --out_w ") + std::to_string(CI::GetOutputWidthWo(ctx)) +
std::string(" --fil_h ") + std::to_string(CI::GetFilterHeightY(ctx)) +
std::string(" --fil_w ") + std::to_string(CI::GetFilterWidthX(ctx)) +
std::string(" --dilation_h ") + std::to_string(CI::GetAdjustedConvolutionDilationH(ctx)) +
std::string(" --dilation_w ") + std::to_string(CI::GetAdjustedConvolutionDilationW(ctx)) +
std::string(" --conv_stride_h ") + std::to_string(CI::GetAdjustedConvolutionStrideH(ctx)) +
std::string(" --conv_stride_w ") + std::to_string(CI::GetAdjustedConvolutionStrideW(ctx)) +
std::string(" --padding_h ") + std::to_string(CI::GetInputLeftPadH(ctx)) +
std::string(" --padding_w ") + std::to_string(CI::GetInputLeftPadW(ctx)) +
std::string(" --kernel_name ") + construction_parameters.kernel_name;
// clang-format on

size_t local_size = 0;
size_t global_size = 0;
MiirGenLaunchParams(construction_parameters.comp_options, local_size, global_size);

construction_parameters.l_wk.push_back(local_size);
construction_parameters.l_wk.push_back(1);
construction_parameters.l_wk.push_back(1);

construction_parameters.g_wk.push_back(global_size);
construction_parameters.g_wk.push_back(1);
construction_parameters.g_wk.push_back(1);

result.invoker_factory = conv::MakeMlirWrWInvokerFactory(ctx);
result.construction_params.push_back(construction_parameters);
return result;
#else
std::ignore = ctx;
return {};
#endif
}

} // namespace solver
} // namespace miopen
12 changes: 8 additions & 4 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ if(MIOPEN_TEST_MLIR)
set(IMPLICITGEMM_MLIR_ENV_F_CPP ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmMlirCppFwd)
set(IMPLICITGEMM_MLIR_ENV_F_BIN ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmMlirBinFwd)
set(IMPLICITGEMM_MLIR_ENV_B ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmMlirCppBwd)
set(IMPLICITGEMM_MLIR_ENV_W ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmMlirCppWrW)
set(IMPLICITGEMM_MLIR_ENV_W_CPP ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmMlirCppWrW)
set(IMPLICITGEMM_MLIR_ENV_W_BIN ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvHipImplicitGemmMlirBinWrW)

set(IMPLICITGEMM_MLIR_ARGS_F ${IMPLICITGEMM_ARGS} --verbose --disable-backward-data --disable-backward-weights)
set(IMPLICITGEMM_MLIR_ARGS_B ${IMPLICITGEMM_ARGS} --verbose --disable-forward --disable-backward-weights)
Expand All @@ -334,9 +335,12 @@ if(MIOPEN_TEST_MLIR)
COMMAND ${IMPLICITGEMM_MLIR_ENV_B} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_B} --input 64 256 56 56 --weights 256 256 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_B} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_B} --input 64 128 58 58 --weights 128 128 3 3 --pads_strides_dilations 0 0 1 1 1 1

COMMAND ${IMPLICITGEMM_MLIR_ENV_W} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 1024 14 14 --weights 1024 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 256 56 56 --weights 256 256 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 128 58 58 --weights 128 128 3 3 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W_CPP} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 1024 14 14 --weights 1024 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W_CPP} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 256 56 56 --weights 256 256 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W_CPP} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 128 58 58 --weights 128 128 3 3 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W_BIN} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 1024 14 14 --weights 1024 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W_BIN} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 256 56 56 --weights 256 256 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${IMPLICITGEMM_MLIR_ENV_W_BIN} $<TARGET_FILE:test_conv2d> ${IMPLICITGEMM_MLIR_ARGS_W} --input 64 128 58 58 --weights 128 128 3 3 --pads_strides_dilations 0 0 1 1 1 1
)
endif()

Expand Down

0 comments on commit b004695

Please sign in to comment.