Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove INT8x4 support #2441

Merged
merged 12 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/datatypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ typedef enum {
miopenFloat = 1,
miopenInt32 = 2,
miopenInt8 = 3,
miopenInt8x4 = 4,
/* Value 4 is reserved. */
miopenBFloat16 = 5,
atamazov marked this conversation as resolved.
Show resolved Hide resolved
} miopenDataType_t;
```
Expand All @@ -22,7 +22,6 @@ Type descriptions:
* `miopenFloat` - 32-bit floating point
* `miopenInt32` - 32-bit integer, used primarily for `int8` convolution outputs
* `miopenInt8` - 8-bit integer, currently only supported by `int8` convolution forward path, tensor set, tensor copy, tensor cast, tensor transform, tensor transpose, and im2col.
* `miopenInt8x4` - 8-bit 4 element vector type used primarily with `int8` convolutions forward path.
* `miopenBFloat16` - brain float fp-16 (8-bit exponent, 7-bit fraction), currently only supported by convolutions, tensor set, and tensor copy.


Expand Down
11 changes: 5 additions & 6 deletions include/miopen/miopen.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,11 @@ MIOPEN_DECLARE_OBJECT(miopenReduceTensorDescriptor);
*/
typedef enum
{
miopenHalf = 0, /*!< 16-bit floating point (Fully supported) */
miopenFloat = 1, /*!< 32-bit floating point (Fully supported) */
miopenInt32 = 2, /*!< 32-bit int point (Partially supported) */
miopenInt8 = 3, /*!< 8-bit int point (Partially supported) */
miopenInt8x4 =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removal is necessary to remove useless cases from the library (otherwise tidy checks would fail).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted. Why: #2441 (comment)

averinevg marked this conversation as resolved.
Show resolved Hide resolved
4, /*!< Pack of four 8-bit int points in NCHW_VECT_C format (Partially supported) */
miopenHalf = 0, /*!< 16-bit floating point (Fully supported) */
miopenFloat = 1, /*!< 32-bit floating point (Fully supported) */
miopenInt32 = 2, /*!< 32-bit int point (Partially supported) */
miopenInt8 = 3, /*!< 8-bit int point (Partially supported) */
atamazov marked this conversation as resolved.
Show resolved Hide resolved
miopenInt8x4 = 4, /*!< Pack of four Int8 in NCHW_VECT_C format (Support discontinued) */
miopenBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction)
(Partially supported) */
miopenDouble = 6, /*!< 64-bit floating point (Partially supported) */
Expand Down
2 changes: 1 addition & 1 deletion src/check_numerics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ std::string GetKernelName(miopenDataType_t data_type)
case miopenBFloat8: return {"check_numerics_bf8"};
case miopenInt32:
case miopenInt8:
case miopenInt8x4:
case miopenInt8x4: // Support discontinued.
case miopenDouble:
default: return {""};
}
Expand Down
2 changes: 1 addition & 1 deletion src/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ ConvolutionDescriptor::GetForwardOutputTensorWithLayout(const TensorDescriptor&
std::vector<std::size_t> out_strides;
tensor_layout_to_strides(
out_lens, default_layout, yLayout, xDesc.GetVectorLength(), out_strides);
return {(xDesc.GetType() == miopenInt8 || xDesc.GetType() == miopenInt8x4
return {(xDesc.GetType() == miopenInt8
? (yType)
: xDesc.GetType()), // TODO: This function overrides the output type with
// essentially the input which is incorrect.
Expand Down
52 changes: 15 additions & 37 deletions src/gemm_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@
/// "disabled expansion of recursive macro" injected by rocblas headers.
#define AVOID_ROCBLAS_WRAPPERS_204 (MIOPEN_ROCBLAS_VERSION_FLAT >= 2004000)

/// Maintain API compatibility with various rocBLAS version
#define USE_GEMM_FLAGS_PACK_INT8X4 \
((MIOPEN_ROCBLAS_VERSION_FLAT >= 2038000) && (MIOPEN_ROCBLAS_VERSION_FLAT < 4000000))

/// Maintain API compatibility for versions not supporting FP16 alternate implementations
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (MIOPEN_ROCBLAS_VERSION_FLAT >= 2043000)
/// Some 2.42 versions have rocblas_gemm_flags_fp16_alt_impl, but
Expand Down Expand Up @@ -110,7 +106,7 @@ static inline rocblas_datatype rocBlasComputeType(const miopen::GemmDescriptor&
{
// Complex compute types are only supported in newer version of the API
assert(desc.dataType == desc.a_cast_type && desc.dataType == desc.b_cast_type);
if(desc.dataType == miopenInt8 || desc.dataType == miopenInt8x4)
if(desc.dataType == miopenInt8)
return rocblas_datatype::rocblas_datatype_i32_r;
else
return rocblas_datatype::rocblas_datatype_f32_r;
Expand Down Expand Up @@ -441,7 +437,6 @@ miopenStatus_t CallGemm(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -473,12 +468,7 @@ miopenStatus_t CallGemm(const Handle& handle,
rocBlasComputeType(gemm_desc), // rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
break;
case miopenInt32: break;
Expand Down Expand Up @@ -622,9 +612,9 @@ miopenStatus_t CallGemm(const Handle& handle,
};
break;

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}
Expand Down Expand Up @@ -695,7 +685,6 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -731,12 +720,7 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
break;
case miopenInt32: break;
Expand Down Expand Up @@ -895,10 +879,10 @@ miopenStatus_t CallGemmStridedBatched(const Handle& handle,
break;
}

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
}
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}

Expand Down Expand Up @@ -971,7 +955,6 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,

switch(gemm_desc.dataType)
{
case miopenInt8x4:
case miopenInt8: {
assert(gemm_desc.k % 4 == 0);

Expand Down Expand Up @@ -1005,12 +988,7 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
rocBlasComputeType(gemm_desc), // rocblas_datatype::rocblas_datatype_i32_r,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
#if USE_GEMM_FLAGS_PACK_INT8X4
rocblas_gemm_flags_pack_int8x4
#else
0
#endif
);
0);
}
}
break;
Expand Down Expand Up @@ -1166,10 +1144,10 @@ miopenStatus_t CallGemmStridedBatchedSequential(const Handle& handle,
break;
}

case miopenInt8x4:
case miopenDouble: {
MIOPEN_THROW(miopenStatusBadParm,
"miopenDouble data type not supported by MIOpenGEMM.");
}
MIOPEN_THROW(miopenStatusBadParm, "Unknown or unsupported data type.");
};
break;
}

Expand Down Expand Up @@ -1199,7 +1177,7 @@ GemmDescriptor CreateGemmDescriptorConvFwd(const TensorDescriptor& wDesc,
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#endif

Expand Down Expand Up @@ -1354,7 +1332,7 @@ GemmDescriptor CreateGemmDescriptorConvCNHWFwd(const TensorDescriptor& wDesc,
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#endif

Expand Down Expand Up @@ -1458,7 +1436,7 @@ GemmDescriptor CreateGemmStridedBatchedDescriptorConv1x1Fwd(const TensorDescript
{
#ifndef NDEBUG
assert(wDesc.GetType() == xDesc.GetType());
if(wDesc.GetType() != miopenInt8 && wDesc.GetType() != miopenInt8x4)
if(wDesc.GetType() != miopenInt8)
assert(wDesc.GetType() == yDesc.GetType());
#else
(void)yDesc;
Expand Down
6 changes: 5 additions & 1 deletion src/hip/batched_transpose_sol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,11 @@ BatchedTransposeSolution::BatchedTransposeSolution(const ExecutionContext& ctx,
uint32_t width_)
: data_type(data_type_), batch(batch_), height(height_), width(width_)
{
if(data_type == miopenInt8x4 || data_type == miopenDouble)
if(!(data_type == miopenHalf //
|| data_type == miopenFloat //
|| data_type == miopenInt32 //
|| data_type == miopenInt8 //
|| data_type == miopenBFloat16))
amberhassaan marked this conversation as resolved.
Show resolved Hide resolved
MIOPEN_THROW("These data type are not supported");
num_cu = ctx.GetStream().GetMaxComputeUnits();
std::size_t data_size = miopen::GetTypeSize(data_type);
Expand Down
12 changes: 6 additions & 6 deletions src/include/miopen/datatype.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ inline std::string GetDataType(miopenDataType_t type)
type_str = "bfloat16";
}
break;
case miopenInt8x4:
case miopenInt8x4: {
type_str = "UNSUPPORTED_TYPE";
}
break;
case miopenInt8: {
type_str = "int8_t";
}
Expand Down Expand Up @@ -137,7 +140,6 @@ inline KernelBuildParameters GetDataTypeKBP(miopenDataType_t type)
int use_fp16x8 = 0;
int use_fp32 = 0;
int use_int8 = 0;
int use_int8x4 = 0;
int use_int32 = 0;
int use_bfp16 = 0;
int use_fp64 = 0;
Expand All @@ -150,15 +152,14 @@ inline KernelBuildParameters GetDataTypeKBP(miopenDataType_t type)
case miopenHalf: use_fp16 = 1; break;
case miopenFloat: use_fp32 = 1; break;
case miopenInt8: use_int8 = 1; break;
case miopenInt8x4: use_int8x4 = 1; break;
case miopenBFloat16: use_bfp16 = 1; break;
case miopenInt32: use_int32 = 1; break;
case miopenDouble: use_fp64 = 1; break;
case miopenFloat8: use_fp8 = 1; break;
case miopenBFloat8: use_bfp8 = 1; break;
case miopenInt8x4: // fallthrough
default:
MIOPEN_THROW(
"Only float, half, bfloat16, int8, int8x4, float8, bfloat8 data type is supported.");
MIOPEN_THROW("Only float, half, bfloat16, int8, float8, bfloat8 data types are supported.");
atamazov marked this conversation as resolved.
Show resolved Hide resolved
break;
}

Expand All @@ -168,7 +169,6 @@ inline KernelBuildParameters GetDataTypeKBP(miopenDataType_t type)
{"MIOPEN_USE_FP16x8", use_fp16x8},
{"MIOPEN_USE_FP32", use_fp32},
{"MIOPEN_USE_INT8", use_int8},
{"MIOPEN_USE_INT8x4", use_int8x4},
{"MIOPEN_USE_BFP16", use_bfp16},
{"MIOPEN_USE_INT32", use_int32},
{"MIOPEN_USE_RNE_BFLOAT16", use_rne_bfloat16},
Expand Down
4 changes: 2 additions & 2 deletions src/include/miopen/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ inline std::size_t GetTypeSize(miopenDataType_t d)
case miopenFloat: return 4;
case miopenHalf:
case miopenBFloat16: return 2;
case miopenInt8x4:
case miopenInt8x4: break;
case miopenInt8:
case miopenFloat8:
case miopenBFloat8: return 1;
case miopenDouble: return 8;
}
MIOPEN_THROW("Unknown data type");
MIOPEN_THROW("Unknown or unsupported data type");
}

template <class X, class Y>
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/visit_float.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ void visit_float(miopenDataType_t t, F f)
}
case miopenFloat8:
case miopenBFloat8:
case miopenInt8x4:
case miopenInt8: {
f(as_float<int8_t>{});
break;
Expand All @@ -92,6 +91,7 @@ void visit_float(miopenDataType_t t, F f)
f(as_float<double>{});
break;
}
case miopenInt8x4: MIOPEN_THROW("miopenInt8x4: Support discontinued.");
}
}

Expand Down
6 changes: 0 additions & 6 deletions src/kernels/MIOpenIm2d2Col.cl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif
Expand All @@ -58,8 +54,6 @@

#if MIOPEN_USE_INT8 || MIOPEN_USE_FP8 || MIOPEN_USE_BFP8
typedef char data_t;
#elif MIOPEN_USE_INT8x4
typedef uint data_t;
#elif MIOPEN_USE_INT32
typedef int data_t;
#elif(MIOPEN_USE_FP16 || MIOPEN_USE_BFP16)
Expand Down
6 changes: 0 additions & 6 deletions src/kernels/MIOpenIm3d2Col.cl
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,12 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif

#if MIOPEN_USE_INT8
typedef char data_t;
#elif MIOPEN_USE_INT8x4
typedef uint data_t;
#elif MIOPEN_USE_INT32
typedef int data_t;
#elif(MIOPEN_USE_FP16 || MIOPEN_USE_BFP16)
Expand Down
6 changes: 1 addition & 5 deletions src/kernels/MIOpenSubTensorOpWithScalarKernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,13 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#ifndef MIOPEN_USE_INT32
#define MIOPEN_USE_INT32 0
#endif

#include "float_types.h"

#if MIOPEN_USE_INT8 == 1 || MIOPEN_USE_INT8x4 == 1
#if MIOPEN_USE_INT8 == 1
#define _FLOAT char
#endif

Expand Down
6 changes: 1 addition & 5 deletions src/kernels/MIOpenSubTensorOpWithSubTensorKernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#if MIOPEN_USE_INT8 == 1 || MIOPEN_USE_INT8x4 == 1
#if MIOPEN_USE_INT8 == 1
#define _FLOAT char
#ifndef FLT_MAX
#define MAX_VAL 127 /* max value */
Expand Down
6 changes: 1 addition & 5 deletions src/kernels/MIOpenSubTensorOpWithTransformKernel.cl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@
#define MIOPEN_USE_INT8 0
#endif

#ifndef MIOPEN_USE_INT8x4
#define MIOPEN_USE_INT8x4 0
#endif

#if MIOPEN_USE_INT8 == 1 || MIOPEN_USE_INT8x4 == 1
#if MIOPEN_USE_INT8 == 1
#define _FLOAT char
#ifndef FLT_MAX
#define MAX_VAL 127 /* max value */
Expand Down
Loading