Skip to content

Commit

Permalink
bwd swap bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
amberhassaan committed Dec 16, 2023
1 parent 641cf69 commit 58e9f09
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ struct CKArgs
Y = ProblemInterpreter::GetFilterHeightY(problem);
X = ProblemInterpreter::GetFilterWidthX(problem);

// On a backward pass, out is in and in is out and this is silly
std::swap(K1, C1);
std::swap(K, C);
std::swap(Hi, Ho);
std::swap(Wi, Wo);

input = {Hi, Wi};
output = {Ho, Wo};
filter = {Y, X};
Expand All @@ -101,11 +107,11 @@ struct CKArgs
~CKArgs() = default;

template <typename ConvPtr>
auto MakeArgPtr(const ConvPtr& conv_ptr, Data_t out, ConstData_t w, ConstData_t in) const
auto MakeArgPtr(const ConvPtr& conv_ptr, Data_t in, ConstData_t w, ConstData_t out) const
{
return conv_ptr->MakeArgumentPointer(out,
return conv_ptr->MakeArgumentPointer(in,
w,
in,
out,
N,
K,
C,
Expand All @@ -124,6 +130,7 @@ struct CKArgs
template <typename ConvPtr>
auto MakeArgPtr(const ConvPtr& conv_ptr, const ConvDataTensors& tensors) const
{
// in is out and out is in
return MakeArgPtr(conv_ptr, tensors.out, tensors.w, tensors.in);
}

Expand Down

0 comments on commit 58e9f09

Please sign in to comment.