
Refactor BnCKFwdInference::GetSolution for NHWC #3120

Merged: 16 commits merged into develop from sl/batchnorm_nhwc on Aug 5, 2024
Conversation

xinlipn (Contributor) commented Jul 17, 2024

  1. Rename RunCKSolution to InitInvokerFactoryBnCKFwdInferenceNHWC to differentiate it from the upcoming new API InitInvokerFactoryBnCKFwdInferenceNCHW

  2. Move common code to implicitgemm_ck_util.hpp

@@ -45,6 +50,142 @@ struct ProblemDescription;

namespace solver {
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
namespace batchnorm {
Contributor
This code should be moved back to a header used by batchnorm solvers only. If there is common code between forward and backward batchnorm CKArgs etc., then create a separate header; otherwise, move it back to the cpp files.

Structure the CKArgs class similarly to how convolution's CKArgs classes are structured, where there's a method called MakeArgumentPointer.
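A minimal sketch of the suggested shape. The type names below (FakeDeviceOp, FakeArgument, Problem) are hypothetical stand-ins; the real CKArgs classes live in MIOpen's convolution solvers and wrap composable_kernel device ops.

```cpp
#include <array>
#include <memory>

// Hypothetical stand-ins for the CK device op and the problem description.
struct FakeArgument {};
struct FakeDeviceOp {
    std::unique_ptr<FakeArgument> MakeArgumentPointer() const {
        return std::make_unique<FakeArgument>();
    }
};
struct Problem {
    std::array<int, 4> lens{1, 2, 3, 4};
};

// CKArgs owns everything derived from the problem description and exposes
// a single MakeArgumentPointer, mirroring the convolution solvers' shape.
struct CKArgsBNormFwd {
    explicit CKArgsBNormFwd(const Problem& p) : xyLengths(p.lens) {}

    template <typename DeviceOp>
    auto MakeArgumentPointer(const DeviceOp& op) const {
        // Callers never enumerate the fields; this class picks them.
        return op.MakeArgumentPointer();
    }

    std::array<int, 4> xyLengths;
};
```

The point of the design is that every call site touching the device op goes through one method, so field layout changes stay local to the class.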

Contributor Author

@amberhassaan, I have reverted the changes in implicitgemm_ck_util.hpp and also refactored BnCKFwdInference::GetSolution.

src/solver/batchnorm/backward_ck.cpp
src/solver/conv/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp

@xinlipn xinlipn self-assigned this Jul 18, 2024
@xinlipn xinlipn marked this pull request as draft July 18, 2024 22:17
{
const auto& args = CKArgsBNormFwd{problem};
ConvSolution result;
result.invoker_factory = [bn_problem](const std::vector<Kernel>& kernels) {
Contributor

Do not capture problem_description by value, as it's a rather large object. Create variables above line 132 for what you need from problem_description and capture those variables by value.
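A sketch of what the reviewer is asking for, with a hypothetical, much-simplified ProblemDescription (the real miopen::batchnorm::ProblemDescription is far larger): extract only the needed scalars before building the lambda, then capture those instead of the whole object.

```cpp
#include <functional>
#include <vector>

// Hypothetical, much-simplified problem description.
struct ProblemDescription {
    std::vector<int> xyLengths; // large payload we do not want to copy
    float epsilon;
};

std::function<float()> MakeInvokerFactory(const ProblemDescription& problem)
{
    // Pull out only what the invoker needs...
    const float epsilon = problem.epsilon;
    const int rank = static_cast<int>(problem.xyLengths.size());
    // ...and capture those small values, not the whole problem object.
    return [epsilon, rank] { return epsilon * static_cast<float>(rank); };
}
```

The returned lambda then carries two small scalars rather than a copy of the full problem description.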

bn_problem,
[&](auto data_type_val) {
using T = decltype(data_type_val);
if constexpr(std::is_same_v<T, F16>)
Contributor

This if/else-if logic can be simplified as follows:

using AccTy = std::conditional_t<std::is_same_v<T, F64>,
                                 T,    // pick this if true
                                 F32>; // pick this if false
InvokerFactoryNHWC<T, T, AccTy, T, T, AccTy>(bn_problem);

In fact, InvokerFactory could also be changed to take just two template type parameters, T and AccTy.
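The suggested type selection can be checked in isolation. The F16/F32/F64 tag structs below are hypothetical stand-ins for MIOpen's actual type aliases:

```cpp
#include <type_traits>

// Hypothetical tag types standing in for MIOpen's F16/F32/F64 aliases.
struct F16 {};
struct F32 {};
struct F64 {};

// Accumulate in F64 only when the data type is F64; otherwise in F32.
template <typename T>
using AccTy = std::conditional_t<std::is_same_v<T, F64>,
                                 T,    // pick this if true
                                 F32>; // pick this if false

static_assert(std::is_same_v<AccTy<F16>, F32>);
static_assert(std::is_same_v<AccTy<F32>, F32>);
static_assert(std::is_same_v<AccTy<F64>, F64>);
```

With the alias in place, the runtime if/else-if chain collapses to a single templated call site.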

@xinlipn xinlipn marked this pull request as ready for review July 25, 2024 15:13
MeanVarDataType>(bn_problem)](
const std::vector<Kernel>& kernels) {
std::ignore = kernels;
return [&](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
Contributor

@atamazov @DrizztDoUrden can we guarantee that the invoker_factory context exists longer than its generated lambda?
I'm asking about this case:

result.invoker_factory = [=](...)
{
    return [&](...){...};
};

versus that one:

result.invoker_factory = [=](...)
{
    return [=](...){...};
};

If we can guarantee that, we can always capture by reference in the second lambda, save some memory, and avoid extra copy operations. If we can't, we must always capture by value.

@DrizztDoUrden (Contributor) commented Jul 26, 2024

No, it almost always exists for a shorter period. Capturing by reference here is an error; both lambdas should always capture by copy or move, never by reference.
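A minimal sketch of why the inner lambda must own its captures: the factory's locals are destroyed when the factory returns, so a by-reference capture in the returned lambda would dangle. The function and variable names here are illustrative, not from the PR.

```cpp
#include <functional>

// The factory's locals die when it returns, so the returned lambda must
// own copies (or moved-in values) of everything it uses.
std::function<int()> MakeInvoker()
{
    int kernel_index = 7; // local to the factory
    // Safe: the returned lambda carries its own copy of kernel_index.
    return [kernel_index] { return kernel_index; };
    // Unsafe alternative (do NOT do this): capturing [&kernel_index]
    // would dangle the moment this factory returns.
}
```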

Comment on lines 139 to 140
const std::vector<Kernel>& kernels) {
std::ignore = kernels;
Contributor

Suggested change:

-    const std::vector<Kernel>& kernels) {
-    std::ignore = kernels;
+    const std::vector<Kernel>& /* kernels */) {

@xinlipn (Contributor, Author) commented Jul 26, 2024

@CAHEK7 , thanks for the comments. Code updated. Could you mark this as resolved?

@amberhassaan (Contributor) left a comment

Some minor refactoring is needed.

BiasDataType,
MeanVarDataType>(bn_problem)](
const std::vector<Kernel>& /*kernels*/) mutable {
return [=, args = std::move(args)](const Handle& handle,
Contributor

Don't use = by itself; rather, name each variable captured by value.
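A sketch of the named-capture style being requested; kernel_index and solver_name are illustrative variables, not from the PR.

```cpp
#include <string>
#include <utility>

int RunInvoker()
{
    int kernel_index = 2;
    std::string solver_name = "bnorm_fwd";
    // Name every capture explicitly instead of a blanket [=]; moves are
    // spelled out too, so the capture list documents what the lambda owns.
    auto invoker = [kernel_index, solver_name = std::move(solver_name)] {
        return kernel_index + static_cast<int>(solver_name.size());
    };
    return invoker();
}
```

An explicit capture list also prevents silently copying large objects that a [=] would sweep in by accident.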

Contributor Author

{params.x, params.estimatedMean, params.estimatedVariance, params.bnScale, params.bnBias},
{params.y},
Normalize{params.epsilon});
auto argument_ptr = bn_ptr->MakeArgumentPointer(args.xyLengths,
Contributor

Add a method to BnArgs called MakeArgPtr; you can then pick the fields from inside that class. Better design IMO for composability and hiding details.
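A sketch of the suggested MakeArgPtr method, using hypothetical stand-in types (FakeDeviceOp, InvokeParams) in place of the CK elementwise op and MIOpen's invoke params:

```cpp
#include <memory>

// Hypothetical stand-ins for the CK elementwise op and the invoke params.
struct FakeArgument {
    double epsilon;
};
struct FakeDeviceOp {
    std::unique_ptr<FakeArgument> MakeArgumentPointer(double eps) const {
        return std::make_unique<FakeArgument>(FakeArgument{eps});
    }
};
struct InvokeParams {
    double epsilon;
};

// BnArgs keeps its lengths/strides/fields private to one place and exposes
// a single MakeArgPtr, so call sites never enumerate individual members.
class BnArgs {
public:
    auto MakeArgPtr(const FakeDeviceOp& op, const InvokeParams& params) const {
        return op.MakeArgumentPointer(params.epsilon);
    }
};
```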

Contributor Author

const std::vector<Kernel>& /*kernels*/) mutable {
return [args = std::move(args), kernel_index = kernel_index](
const Handle& handle, const AnyInvokeParams& primitive_parameters) {
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<
Contributor

Sorry, I missed this earlier. Please move lines 161-171, i.e. the computation of bn_ptr, outside the lambda and capture bn_ptr with move semantics. See InitInvokerFactory for the convolution solvers.
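A sketch of the hoist-and-move pattern being requested; FakeBnOp is a hypothetical stand-in for the CK device-op pointer. A shared_ptr is used here because std::function requires a copyable callable:

```cpp
#include <functional>
#include <memory>

// Hypothetical stand-in for the CK device-op pointer.
struct FakeBnOp {
    int id;
};

std::function<int()> MakeFactory()
{
    // Build the op once, outside the invoker lambda...
    auto bn_ptr = std::make_shared<FakeBnOp>(FakeBnOp{42});
    // ...then hand ownership over with an init-capture move, so the op
    // is not reconstructed on every invocation.
    return [bn_ptr = std::move(bn_ptr)] { return bn_ptr->id; };
}
```

Hoisting the construction means the expensive device-op setup runs once at factory-creation time instead of on every invoke.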

Contributor Author

@amberhassaan , changes have been checked in. Thanks

@amberhassaan (Contributor) left a comment

This looks OK now. @junliume : ready to merge when it passes CI.

@junliume junliume merged commit cfabfbb into develop Aug 5, 2024
140 of 141 checks passed
@junliume junliume deleted the sl/batchnorm_nhwc branch August 5, 2024 05:46
7 participants