Skip to content

Commit

Permalink
TensorDescriptor refactoring & enhancements (#3213)
Browse files Browse the repository at this point in the history
  • Loading branch information
averinevg authored Sep 5, 2024
1 parent 7e161f5 commit 21569c0
Show file tree
Hide file tree
Showing 28 changed files with 1,813 additions and 433 deletions.
8 changes: 4 additions & 4 deletions driver/conv_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,13 +932,13 @@ int ConvDriver<Tgpu, Tref>::GetandSetData()
SetConvDescriptorFromCmdLineArgs();

std::vector<int> out_len = GetOutputTensorLengths();
if(miopen::deref(inputTensor).GetLayout_t() == miopenTensorNCHWc4 ||
miopen::deref(inputTensor).GetLayout_t() == miopenTensorNCHWc8)
if(miopen::deref(inputTensor).GetLayoutEnum() == miopenTensorNCHWc4 ||
miopen::deref(inputTensor).GetLayoutEnum() == miopenTensorNCHWc8)
{
out_len[1] *= miopen::deref(inputTensor).GetVectorLength();
}
if(miopen::deref(inputTensor).GetLayout_t() == miopenTensorCHWNc4 ||
miopen::deref(inputTensor).GetLayout_t() == miopenTensorCHWNc8)
if(miopen::deref(inputTensor).GetLayoutEnum() == miopenTensorCHWNc4 ||
miopen::deref(inputTensor).GetLayoutEnum() == miopenTensorCHWNc8)
{
out_len[0] *= miopen::deref(inputTensor).GetVectorLength();
}
Expand Down
34 changes: 11 additions & 23 deletions src/conv/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,29 +119,6 @@ std::string ProblemDescription::GetAlphaBetaCaseStr() const
}
}

/// Tries to detect one canonical layout (NCHW, NHWC, CHWN or NCDHW) that is
/// consistent with the input, output and weights tensor descriptors, and
/// overwrites the three stored layout strings with it. If no single supported
/// layout fits all three descriptors, the stored layouts are left unchanged.
void ProblemDescription::HeuristicUpdateLayouts()
{
// Canonical dimension labels for this rank, e.g. "NCHW" for 4-D tensors.
const std::string labels = tensor_layout_get_default(in_layout.size());

static const std::vector<std::string> supported_layouts = {"NCHW", "NHWC", "CHWN", "NCDHW"};
for(const std::string& layout : supported_layouts)
{
// Skip layouts that don't match the tensor rank
if(layout.size() != labels.size())
continue;

// Accept the first candidate that all three descriptors could be stored in.
if(in.IsPossibleLayout(labels, layout) && out.IsPossibleLayout(labels, layout) &&
weights.IsPossibleLayout(labels, layout))
{
in_layout = layout;
weights_layout = layout;
out_layout = layout;
return;
}
}
// If we did not find a consistent layout, leave them as-is
}

void ProblemDescription::MakeNetworkConfig(std::string& conf_key) const
{
std::ostringstream ss;
Expand Down Expand Up @@ -294,5 +271,16 @@ void ProblemDescription::SetupFloats(ExecutionContext& ctx) const
<< "x" << GetDataTypeName(GetOutDataType()));
}

/// Returns the layout string (e.g. "NCHW") derived from the descriptor itself.
std::string ProblemDescription::ComputeLayout(const TensorDescriptor& td) const
{
return td.GetLayout_str();
}

/// Layout of the input (x) tensor.
std::string ProblemDescription::ComputeInLayout() const { return ComputeLayout(in); }

/// Layout of the output (y) tensor.
std::string ProblemDescription::ComputeOutLayout() const { return ComputeLayout(out); }

/// Layout of the weights (w) tensor.
std::string ProblemDescription::ComputeWeightsLayout() const { return ComputeLayout(weights); }

} // namespace conv
} // namespace miopen
3 changes: 1 addition & 2 deletions src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,7 @@ TensorDescriptor ConvolutionDescriptor::GetForwardOutputTensor(const TensorDescr
miopenDataType_t yType) const
{
// output layout same as input
const std::string default_layout = tensor_layout_get_default(xDesc.GetNumDims());
const std::string in_layout = xDesc.GetLayout(default_layout);
const std::string in_layout = xDesc.GetLayout_str();
return GetForwardOutputTensorWithLayout(xDesc, wDesc, in_layout, yType);
}

Expand Down
12 changes: 6 additions & 6 deletions src/driver_arguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ std::string ConvArgsForMIOpenDriver(const miopen::TensorDescriptor& xDesc,
<< " -v " << convDesc.GetConvStrides()[1] //
<< " -l " << convDesc.GetConvDilations()[0] //
<< " -j " << convDesc.GetConvDilations()[1];
std::string x_layout = xDesc.GetLayout("NCHW");
std::string w_layout = wDesc.GetLayout("NCHW");
std::string y_layout = yDesc.GetLayout("NCHW");
std::string x_layout = xDesc.GetLayout_str();
std::string w_layout = wDesc.GetLayout_str();
std::string y_layout = yDesc.GetLayout_str();
if(x_layout != "NCHW")
{
ss << " --in_layout " << x_layout;
Expand Down Expand Up @@ -182,9 +182,9 @@ std::string ConvArgsForMIOpenDriver(const miopen::TensorDescriptor& xDesc,
<< " -l " << convDesc.GetConvDilations()[1] //
<< " -j " << convDesc.GetConvDilations()[2] //
<< " --spatial_dim 3";
std::string x_layout = xDesc.GetLayout("NCDHW");
std::string w_layout = wDesc.GetLayout("NCDHW");
std::string y_layout = yDesc.GetLayout("NCDHW");
std::string x_layout = xDesc.GetLayout_str();
std::string w_layout = wDesc.GetLayout_str();
std::string y_layout = yDesc.GetLayout_str();
if(x_layout != "NCDHW")
{
ss << " --in_layout " << x_layout;
Expand Down
12 changes: 1 addition & 11 deletions src/include/miopen/batchnorm/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,7 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
NetworkConfig MakeForwardInferenceNetworkConfig() const;
NetworkConfig MakeBackwardNetworkConfig() const;

std::string ComputeLayout(const TensorDescriptor& td) const
{
if(spatial_dim == 2)
{
return td.GetLayout("NCHW");
}
else
{
return td.GetLayout("NCDHW");
}
}
// Layout string derived from the descriptor itself (e.g. "NCHW", "NCDHW").
std::string ComputeLayout(const TensorDescriptor& td) const { return td.GetLayout_str(); }
// Layout of the input (x) tensor.
std::string ComputeInLayout() const { return ComputeLayout(xDesc); }
// Layout of the output (y / dy) tensor.
std::string ComputeOutLayout() const { return ComputeLayout(yOrDyDesc); }
// Layout of the input-gradient (dx) tensor.
std::string ComputeDinLayout() const { return ComputeLayout(dxDesc); }
Expand Down
46 changes: 6 additions & 40 deletions src/include/miopen/conv/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase
beta(beta_),
alpha_beta_case(ClassifyAlphaBeta(alpha, beta))
{
HeuristicUpdateLayouts();
}

// Conv descriptor getters
Expand Down Expand Up @@ -224,14 +223,14 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase
std::size_t GetWeightsDepth() const { return GetD5(GetSpatialDims(), weights.GetLengths()); }
std::size_t GetWeightsHeight() const
{
if(weights.GetLayout_str() == "CHWNc")
if(weights_layout == "CHWNc")
return GetHofCHWN(weights.GetLengths());
else
return GetH5(GetSpatialDims(), weights.GetLengths());
}
std::size_t GetWeightsWidth() const
{
if(weights.GetLayout_str() == "CHWNc")
if(weights_layout == "CHWNc")
return GetWofCHWN(weights.GetLengths());
else
return GetW5(GetSpatialDims(), weights.GetLengths());
Expand Down Expand Up @@ -369,8 +368,6 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase
out.AllLengthsFitIntoInt();
}

void HeuristicUpdateLayouts();

void MakeNetworkConfig(std::string& conf_key) const;

NetworkConfig MakeNetworkConfig() const override
Expand Down Expand Up @@ -443,41 +440,10 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase
void SetupFloats(ExecutionContext& ctx) const;

private:
std::string ComputeInLayout() const
{
if(GetSpatialDims() == 2)
{
return in.GetLayout(in.GetLayout_str());
}
else
{
return in.GetLayout("NCDHW");
}
}

std::string ComputeOutLayout() const
{
if(GetSpatialDims() == 2)
{
return out.GetLayout(out.GetLayout_str());
}
else
{
return out.GetLayout("NCDHW");
}
}

std::string ComputeWeightsLayout() const
{
if(GetSpatialDims() == 2)
{
return weights.GetLayout(weights.GetLayout_str());
}
else
{
return weights.GetLayout("NCDHW");
}
}
std::string ComputeLayout(const TensorDescriptor& td) const;
std::string ComputeInLayout() const;
std::string ComputeOutLayout() const;
std::string ComputeWeightsLayout() const;

TensorDescriptor in;
TensorDescriptor weights;
Expand Down
35 changes: 5 additions & 30 deletions src/include/miopen/graphapi/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class Tensor : public TensorDescriptor
int64_t mId = 0;
bool mVirtual = false;

// Deprecated
using TensorDescriptor::GetLayout_t;

public:
Tensor() noexcept = default;
Tensor(const Tensor&) = default;
Expand All @@ -60,50 +63,22 @@ class Tensor : public TensorDescriptor
const std::vector<std::size_t>& strides,
int64_t id,
bool isVirtual)
: TensorDescriptor(dataType, getLayout(strides), dimensions, strides),
mId(id),
mVirtual(isVirtual)
: TensorDescriptor(dataType, dimensions, strides), mId(id), mVirtual(isVirtual)
{
}
Tensor(miopenDataType_t dataType,
std::vector<std::size_t>&& dimensions,
std::vector<std::size_t>&& strides,
int64_t id,
bool isVirtual) noexcept
: TensorDescriptor(dataType, getLayout(strides), std::move(dimensions), std::move(strides)),
: TensorDescriptor(dataType, std::move(dimensions), std::move(strides)),
mId(id),
mVirtual(isVirtual)
{
}

int64_t getId() const noexcept { return mId; }
bool isVirtual() const noexcept { return mVirtual; }

private:
// Heuristically deduces a tensor layout from the strides.
//
// For 4-D/5-D tensors: if the channel stride (strides[1]) is the smallest
// stride (or tied for smallest), the tensor is assumed to be channels-last
// (NHWC/NDHWC); otherwise channels-first (NCHW/NCDHW). Tensors with fewer
// than 4 dimensions fall back to the default layout.
//
// NOTE(review): ranks above 5 also take the 5-D branch here — presumably
// such descriptors never reach this helper; confirm with callers.
static miopenTensorLayout_t getLayout(const std::vector<std::size_t>& strides)
{
    if(strides.size() >= 4)
    {
        // Keep the channel stride as std::size_t: narrowing it to int could
        // overflow and corrupt the comparison below for very large tensors.
        const std::size_t stride_c = strides[1];

        // If channels have the smallest stride, or are tied for smallest stride, then we are
        // assuming NHWC format. Otherwise, assume NCHW format.
        if(std::all_of(strides.cbegin(), strides.cend(), [stride_c](std::size_t x) {
               return x >= stride_c;
           }))
        {
            return strides.size() == 4 ? miopenTensorLayout_t::miopenTensorNHWC
                                       : miopenTensorLayout_t::miopenTensorNDHWC;
        }

        return strides.size() == 4 ? miopenTensorLayout_t::miopenTensorNCHW
                                   : miopenTensorLayout_t::miopenTensorNCDHW;
    }

    return GetDefaultLayout();
}
};

class MIOPEN_INTERNALS_EXPORT TensorBuilder
Expand Down
90 changes: 34 additions & 56 deletions src/include/miopen/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,13 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor
unsigned GetNumDims() const;

miopenDataType_t GetType() const;
// clang-format off
[[deprecated("Use GetLayoutEnum() instead")]]
miopenTensorLayout_t GetLayout_t() const;
static std::string GetLayoutStr(miopenTensorLayout_t layout);
std::string GetLayout_str() const;
// clang-format on
const std::optional<miopenTensorLayout_t>& GetLayoutEnum() const;
static std::string LayoutEnumToStr(miopenTensorLayout_t layout);
const std::string& GetLayout_str() const;

std::size_t GetVectorLength() const;
std::optional<miopenDataType_t> GetCastType() const;
Expand Down Expand Up @@ -240,81 +244,40 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor

std::string ToString() const;

bool IsPossibleLayout(const std::string& labels, const std::string& layout) const;
// For vectorized layouts storage_layout must be without the ending 'c'
// \todo make private
bool IsPossibleLayout(const std::string& storage_layout, const std::string& layout) const;
// Layout could be NCHW, NHWC, NCDHW, NDHWC, NCHWc, ...
bool IsPossibleLayout4D5D(const std::string& layout) const;

// Computes the order of the dimensions by descending stride (ties broken by
// descending length, then by original position thanks to the stable sort).
// The result is the permutation that maps "memory order" back to dimension
// indices, e.g. NHWC-ordered strides over NCHW lengths yield {0, 2, 3, 1}.
static inline std::vector<int64_t> find_permutation(const std::vector<std::size_t>& lens,
                                                    const std::vector<std::size_t>& strides)
{
    std::vector<std::int64_t> perm(lens.size());
    std::iota(perm.begin(), perm.end(), 0);
    // Sort key: (stride, length), compared in descending order.
    const auto rank = [&](std::int64_t i) { return std::make_tuple(strides[i], lens[i]); };
    std::stable_sort(perm.begin(), perm.end(), [&](std::int64_t lhs, std::int64_t rhs) {
        return rank(lhs) > rank(rhs);
    });
    return perm;
}
static std::vector<int64_t> find_permutation(const std::vector<std::size_t>& lens,
const std::vector<std::size_t>& strides);

/// Returns the actual memory layout of the tensor as a permutation of the
/// supplied dimension labels, ordered from largest to smallest stride
/// (e.g. labels "NCHW" with NHWC-ordered strides yields "NHWC").
/// A trailing 'c' in the labels marks a vectorized layout ("NCHWc"): the 'c'
/// is not a real dimension, so it is stripped before permuting and appended
/// back afterwards.
/// Throws if the number of (non-'c') labels does not match the tensor rank.
std::string GetLayout(std::string labels) const
{
    // labels.back() on an empty string would be UB, so check emptiness first.
    const bool vectorized = !labels.empty() && labels.back() == 'c';
    // For vectorized layouts the trailing 'c' is a marker, not a dimension.
    const std::string base_labels = vectorized ? labels.substr(0, labels.size() - 1) : labels;

    if(base_labels.size() != strides.size())
    {
        MIOPEN_THROW("Invalid labels size. Layout labels size must be equivalent to stride size");
    }

    // Copy construct the result string from the labels. This allocates the space at one go
    // and is faster than calling push_back in transform.
    auto result = base_labels;
    const auto p = find_permutation(lens, strides);
    std::transform(p.begin(), p.end(), result.begin(), [&](auto i) { return base_labels[i]; });
    return vectorized ? result + 'c' : result;
}
// storage_layout must be NCHW or NCHWc for NCHWc, CHWN or CHWNc for CHWNc, NCHW for other 4D
// layouts, NCDHW for 5D layouts
std::string GetLayout(std::string storage_layout) const;

friend MIOPEN_INTERNALS_EXPORT std::ostream& operator<<(std::ostream& stream,
const TensorDescriptor& t);

friend void to_json(nlohmann::json& j, const TensorDescriptor& descriptor);
friend void from_json(const nlohmann::json& j, TensorDescriptor& descriptor);

protected:
static miopenTensorLayout_t GetDefaultLayout() { return miopenTensorNCHW; };

private:
TensorDescriptor(miopenDataType_t t,
miopenTensorLayout_t layout_in,
const std::optional<miopenTensorLayout_t>& layout_in,
const std::vector<std::size_t>& lens_in,
const std::vector<std::size_t>& strides_in,
bool use_strides);

TensorDescriptor(miopenDataType_t t,
miopenTensorLayout_t layout_in,
const std::optional<miopenTensorLayout_t>& layout_in,
std::vector<std::size_t>&& lens_in,
std::vector<std::size_t>&& strides_in,
bool use_strides);

void CheckArgsAndInit(bool use_strides);

void SetStrideNd(const std::string& layout);
void LensReorder(const std::string& layout);

void CalculateStrides();
void CalculateVectorLength();

std::vector<std::size_t> lens;
std::vector<std::size_t> strides;

Expand All @@ -323,7 +286,22 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor

miopenDataType_t type = miopenFloat;
std::optional<miopenDataType_t> cast_type;
miopenTensorLayout_t tensorLayout = GetDefaultLayout();
std::optional<miopenTensorLayout_t> tensorLayout;

// For GetLayoutEnum()
mutable std::optional<miopenTensorLayout_t> cached_layout_enum;
mutable bool cached_layout_enum_calculated = false;

// For GetLayout_str()
mutable std::string cached_layout_str;

// For GetLayout
mutable std::vector<int64_t> cached_permutation;

// For AllLengthsFitIntoInt()
mutable std::optional<bool> cached_lengths_fit_into_int;
// For AllDimsFitIntoInt()
mutable std::optional<bool> cached_strides_fit_into_int;
};

template <class TElement>
Expand Down
Loading

0 comments on commit 21569c0

Please sign in to comment.