From 58e9f097cacbe1ea786279e819e1c30c262e5ae1 Mon Sep 17 00:00:00 2001 From: "M. Amber Hassaan" Date: Sat, 16 Dec 2023 19:12:30 +0000 Subject: [PATCH] bwd swap bug fix --- .../conv_hip_implicit_gemm_bwd_data_xdlops.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index bd3eeebd8b..7b805e8df3 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -81,6 +81,12 @@ struct CKArgs Y = ProblemInterpreter::GetFilterHeightY(problem); X = ProblemInterpreter::GetFilterWidthX(problem); + // On a backward pass, out is in and in is out and this is silly + std::swap(K1, C1); + std::swap(K, C); + std::swap(Hi, Ho); + std::swap(Wi, Wo); + input = {Hi, Wi}; output = {Ho, Wo}; filter = {Y, X}; @@ -101,11 +107,11 @@ struct CKArgs ~CKArgs() = default; template - auto MakeArgPtr(const ConvPtr& conv_ptr, Data_t out, ConstData_t w, ConstData_t in) const + auto MakeArgPtr(const ConvPtr& conv_ptr, Data_t in, ConstData_t w, ConstData_t out) const { - return conv_ptr->MakeArgumentPointer(out, + return conv_ptr->MakeArgumentPointer(in, w, - in, + out, N, K, C, @@ -124,6 +130,7 @@ struct CKArgs template auto MakeArgPtr(const ConvPtr& conv_ptr, const ConvDataTensors& tensors) const { + // in is out and out is in return MakeArgPtr(conv_ptr, tensors.out, tensors.w, tensors.in); }