Refactor BnCKFwdInference::GetSolution for NHWC #3120
Conversation
…KSolution to implicitgemm_ck_util.hpp
@@ -45,6 +50,142 @@ struct ProblemDescription;

namespace solver {
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
namespace batchnorm {
This code should be moved back to a header used by batchnorm solvers only. If there is common code between forward and backward batchnorm CKArgs etc., create a separate header; otherwise, move it back to the cpp files.
Structure the CKArgs class similar to how convolution's CKArgs classes are structured, where there's a method called MakeArgumentPointer.
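For reference, the shape being suggested looks roughly like the sketch below. The member names and the argument list are assumptions pieced together from the snippets quoted later in this conversation, not the final code.

    // Sketch of the suggested structure: the args class owns the tensor
    // lengths/strides derived from the problem and exposes MakeArgumentPointer,
    // so call sites never touch the individual fields.
    struct CKArgsBNormFwd
    {
        explicit CKArgsBNormFwd(const miopen::batchnorm::ProblemDescription& problem)
        {
            // fill xyLengths / xyStrides (and the scale/bias/mean/var strides) from `problem`
        }

        template <typename DeviceOpPtr, typename InvokeParams>
        auto MakeArgumentPointer(const DeviceOpPtr& bn_ptr, const InvokeParams& params) const
        {
            // Only the arguments visible in the diff are shown; the stride
            // arrays passed by the real call are elided here.
            return bn_ptr->MakeArgumentPointer(
                xyLengths,
                /* ...input/output stride arrays... */
                {params.x, params.estimatedMean, params.estimatedVariance, params.bnScale, params.bnBias},
                {params.y},
                Normalize{params.epsilon});
        }

        std::vector<int> xyLengths; // index type simplified for the sketch
        std::vector<int> xyStrides;
    };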
@amberhassaan, I have reverted the changes in implicitgemm_ck_util.hpp. I have also refactored BnCKFwdInference::GetSolution.
src/solver/batchnorm/backward_ck.cpp
src/solver/conv/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp
{
    const auto& args = CKArgsBNormFwd{problem};
    ConvSolution result;
    result.invoker_factory = [bn_problem](const std::vector<Kernel>& kernels) {
Do not capture problem_description by value, as it's a rather large object. Create variables above line 132 for what you need from problem_description and capture those variables by value.
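A minimal sketch of that pattern; the accessor names are illustrative placeholders, not the actual interface of the problem description:

    // Copy the few cheap values the invoker actually needs up front...
    const auto data_type = bn_problem.GetXDesc().GetType();   // illustrative accessors
    const auto lengths   = bn_problem.GetXDesc().GetLengths();

    // ...and capture those by value instead of the large problem object itself.
    result.invoker_factory = [data_type, lengths](const std::vector<Kernel>& /*kernels*/) {
        return [data_type, lengths](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
            // use data_type / lengths here rather than bn_problem
        };
    };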
bn_problem,
[&](auto data_type_val) {
    using T = decltype(data_type_val);
    if constexpr(std::is_same_v<T, F16>)
This if/else-if logic can be simplified as follows:

    using AccT = std::conditional_t<std::is_same_v<T, F64>,
                                    T,    // pick this if true
                                    F32>; // pick this if false
    InvokerFactoryNHWC<T, T, AccT, T, T, AccT>(bn_problem);

In fact, InvokerFactory could also be changed to take just two template type parameters, T and AccT.
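If that second change were made, the dispatch could shrink to roughly the following, assuming the remaining CK type parameters can all be derived from T and AccT inside the factory:

    // Sketch: AccT resolves to F64 only for double problems and F32 otherwise;
    // the other type parameters are assumed to be derivable from T and AccT.
    using AccT = std::conditional_t<std::is_same_v<T, F64>, T, F32>;
    InvokerFactoryNHWC<T, AccT>(bn_problem);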
MeanVarDataType>(bn_problem)](
    const std::vector<Kernel>& kernels) {
    std::ignore = kernels;
    return [&](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
@atamazov @DrizztDoUrden can we guarantee that the invoker_factory context exists longer than its generated lambda?

I'm asking about this case:

    result.invoker_factory = [=](...)
    {
        return [&](...){...};
    };

versus that one:

    result.invoker_factory = [=](...)
    {
        return [=](...){...};
    };

If we can guarantee that, we can always capture by reference in the second lambda, save some memory, and avoid extra copy operations. If we can't, we must always capture by value.
No, it almost always exists for a shorter period. Capturing by reference here is an error; both lambdas should always capture by copy or move, never by reference.
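The hazard is the usual lambda-lifetime one: the factory lambda runs, returns the inner invoker, and is then destroyed together with its by-value captures, so any references the inner lambda holds into them dangle. A self-contained illustration of the pattern (not MIOpen code):

    #include <functional>
    #include <iostream>
    #include <vector>

    // Returns an "invoker" the way invoker_factory does: the factory lambda's
    // captures live only as long as the factory closure object itself.
    std::function<int()> make_invoker_by_ref()
    {
        std::vector<int> data{1, 2, 3};             // stands in for bn_problem/args
        auto factory = [data]() {                   // data copied into the factory
            return [&]() { return data.front(); };  // BUG: references factory's copy
        };
        return factory();                           // factory (and its copy) die on return
    }

    std::function<int()> make_invoker_by_copy()
    {
        std::vector<int> data{1, 2, 3};
        auto factory = [data]() {
            return [data]() { return data.front(); }; // OK: the invoker owns its own copy
        };
        return factory();
    }

    int main()
    {
        std::cout << make_invoker_by_copy()() << '\n'; // prints 1
        // Calling make_invoker_by_ref()() would read through a dangling reference:
        // undefined behavior, which is why both lambdas must capture by copy or move.
        return 0;
    }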
const std::vector<Kernel>& kernels) {
    std::ignore = kernels;
Suggested change:

    - const std::vector<Kernel>& kernels) {
    -     std::ignore = kernels;
    + const std::vector<Kernel>& /* kernels */) {
@CAHEK7, thanks for the comments. Code updated. Could you mark this as resolved?
Some minor refactoring is needed.
BiasDataType,
MeanVarDataType>(bn_problem)](
    const std::vector<Kernel>& /*kernels*/) mutable {
    return [=, args = std::move(args)](const Handle& handle,
Don't use = by itself; rather, name each variable captured by value.
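In other words, spell the captures out, which is essentially what the later revision quoted further down does:

    // Instead of "[=, args = std::move(args)]", list what the invoker needs so
    // no implicit (and possibly expensive) copies hide behind the "=".
    return [kernel_index, args = std::move(args)](
               const Handle& handle, const AnyInvokeParams& primitive_parameters) {
        // the body only sees kernel_index and args
    };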
@amberhassaan done
{params.x, params.estimatedMean, params.estimatedVariance, params.bnScale, params.bnBias},
{params.y},
Normalize{params.epsilon});
auto argument_ptr = bn_ptr->MakeArgumentPointer(args.xyLengths,
Add a method to BnArgs called MakeArgPtr; you can then pick the fields from inside that class. Better design IMO for composability and hiding details.
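With such a method in place, the call site quoted above would reduce to something like the following (MakeArgPtr and its parameter list are hypothetical names taken from this suggestion):

    // Hypothetical call site: the lengths/strides stay inside the args class,
    // and the invoker only forwards the CK device op and the invoke params.
    auto argument_ptr = args.MakeArgPtr(bn_ptr, params);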
@amberhassaan done
const std::vector<Kernel>& /*kernels*/) mutable {
return [args = std::move(args), kernel_index = kernel_index](
           const Handle& handle, const AnyInvokeParams& primitive_parameters) {
    using DeviceOp = ck::tensor_operation::device::DeviceElementwise<
Sorry, I missed this earlier. Please move lines 161-171, i.e. the computation of bn_ptr, outside the invoker lambda and capture bn_ptr by move semantics. See the InitInvokerFactory for the convolution solvers.
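A rough sketch of the requested shape; the actual construction of bn_ptr (lines 161-171) is not quoted here, so MakeBnPtr below is a hypothetical stand-in for it:

    // Build the CK device-op pointer once, outside the invoker, then move it
    // into the capture so the invoker owns it without re-creating it per launch.
    auto bn_ptr = MakeBnPtr(); // hypothetical stand-in for lines 161-171

    return [args = std::move(args), bn_ptr = std::move(bn_ptr), kernel_index](
               const Handle& handle, const AnyInvokeParams& primitive_parameters) {
        // use bn_ptr and args here; both were moved in, not rebuilt per call
    };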
@amberhassaan, changes have been checked in. Thanks.
This looks OK now. @junliume: ready to merge when it passes CI.
Rename RunCKSolution to InitInvokerFactoryBnCKFwdInferenceNHWC to differentiate it from the upcoming new API InitInvokerFactoryBnCKFwdInferenceNCHW
Move common code to implicitgemm_ck_util.hpp