-
Notifications
You must be signed in to change notification settings - Fork 221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TunaNet Integration: MI250x #2421
Changes from 2 commits
37aed92
4a4f2b2
99d61a6
a542497
5b62d80
c0adce9
4dfef67
a7c864c
915659f
d4428ba
9bcf58b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,6 +145,47 @@ class Model | |
virtual std::vector<float> ToFeatures(const ProblemDescription& problem) const = 0; | ||
}; | ||
|
||
class Gfx90aModel : public Model | ||
{ | ||
public: | ||
Gfx90aModel() : Model("gfx90a") {} | ||
bool IsProblemSupported(const ProblemDescription& problem, | ||
const ConvolutionContext& ctx) const override | ||
CAHEK7 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return true; | ||
} | ||
|
||
protected: | ||
std::vector<float> ToFeatures(const ProblemDescription& problem) const override | ||
{ | ||
const bool isFwd = problem.GetDirection() == conv::Direction::Forward; | ||
std::vector<float> features = { | ||
static_cast<float>(isFwd ? problem.GetInChannels_() : problem.GetOutChannels_()), | ||
static_cast<float>(isFwd ? problem.GetInHeight_() : problem.GetOutHeight_()), | ||
static_cast<float>(isFwd ? problem.GetInWidth_() : problem.GetOutWidth_()), | ||
static_cast<float>(isFwd ? problem.GetOutChannels_() : problem.GetInChannels_()), | ||
static_cast<float>(isFwd ? problem.GetOutHeight_() : problem.GetInHeight_()), | ||
static_cast<float>(isFwd ? problem.GetOutWidth_() : problem.GetInWidth_()), | ||
static_cast<float>(problem.GetWeightsHeight_()), | ||
static_cast<float>(problem.GetWeightsWidth_()), | ||
static_cast<float>(problem.GetPadH()), | ||
static_cast<float>(problem.GetPadW()), | ||
static_cast<float>(problem.GetKernelStrideH()), | ||
static_cast<float>(problem.GetKernelStrideW()), | ||
static_cast<float>(problem.GetDilationH()), | ||
static_cast<float>(problem.GetDilationW()), | ||
static_cast<float>(problem.GetOutBatchSize_()), | ||
static_cast<float>(metadata.EncodePrecision(problem.GetInDataType())), | ||
static_cast<float>(metadata.EncodeDirection(problem.GetDirection())), | ||
static_cast<float>(problem.GetGroupCount())}; | ||
|
||
// normalize | ||
for(size_t i = 0; i < features.size(); ++i) | ||
features[i] = (features[i] - metadata.features_mean[i]) / metadata.features_std[i]; | ||
|
||
return features; | ||
} | ||
}; | ||
class Gfx908Model : public Model | ||
{ | ||
public: | ||
|
@@ -255,7 +296,12 @@ class Gfx908Model : public Model | |
} | ||
}; | ||
|
||
std::unique_ptr<Model> GetModel(const std::string&) { return std::make_unique<Gfx908Model>(); } | ||
std::unique_ptr<Model> GetModel(const std::string& device) { | ||
if (device == "gfx90a") | ||
return std::make_unique<Gfx90aModel>(); | ||
else | ||
return std::make_unique<Gfx908Model>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to maintain the previous behavior, but on going through the code again, I think it was a bug. I have fixed it to return a nullpointer if the device is not gfx90a or gfx908 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @CAHEK7 Done. I now return a nullptr if the device is not supported There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @CAHEK7 I have revered this back after discussing with JD. If we have a model specifically trained for a given GPU, we use it. If we don't, we use the model for MI100. The assumption is that the fastest solver remains the same no matter the GPU (this is actually not necessarily the case, but we still want to have a heuristic even if we do not have a specially designed model). |
||
} | ||
|
||
std::vector<uint64_t> PredictSolver(const ProblemDescription& problem, | ||
const ExecutionContext& ctx, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,269 @@ | ||
{ | ||
"generated_on": "15 Sep 2023, 09:39:00", | ||
"database": "tuna_net2", | ||
"gpu": { | ||
"arch": "gfx90a", | ||
"num_cu": "110" | ||
}, | ||
"golden_v": null, | ||
"kernels": "hip", | ||
"num_inputs": 18, | ||
"num_algos": 12, | ||
"num_solvers": 23, | ||
"num_outputs": 36, | ||
"encodings": { | ||
"Direction": { | ||
"B": 0, | ||
"W": 1, | ||
"F": 2 | ||
}, | ||
"Precision": { | ||
"FP32": 0, | ||
"FP16": 1, | ||
"BF16": 2 | ||
}, | ||
"Layout": { | ||
"NCHW": 0 | ||
}, | ||
"algorithm": { | ||
"miopenConvolutionBwdDataAlgoDirect": 0, | ||
"miopenConvolutionBwdWeightsAlgoImplicitGEMM": 1, | ||
"miopenConvolutionBwdDataAlgoWinograd": 2, | ||
"miopenConvolutionBwdWeightsAlgoWinograd": 3, | ||
"miopenConvolutionBwdDataAlgoImplicitGEMM": 4, | ||
"miopenConvolutionBwdWeightsAlgoDirect": 5, | ||
"miopenConvolutionFwdAlgoImplicitGEMM": 6, | ||
"miopenConvolutionFwdAlgoWinograd": 7, | ||
"miopenConvolutionBwdDataAlgoGEMM": 8, | ||
"miopenConvolutionBwdWeightsAlgoGEMM": 9, | ||
"miopenConvolutionFwdAlgoDirect": 10, | ||
"miopenConvolutionFwdAlgoGEMM": 11 | ||
}, | ||
"solver": { | ||
"ConvDirectNaiveConvBwd": 0, | ||
"ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC": 1, | ||
"ConvBinWinogradRxSf2x3g1": 2, | ||
"ConvBinWinogradRxSf2x3": 3, | ||
"ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC": 4, | ||
"ConvDirectNaiveConvWrw": 5, | ||
"ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC": 6, | ||
"GemmBwd1x1_stride1": 7, | ||
"GemmWrw1x1_stride1": 8, | ||
"ConvDirectNaiveConvFwd": 9, | ||
"ConvBinWinogradRxSf3x2": 10, | ||
"GemmFwd1x1_0_1": 11, | ||
"GemmBwdRest": 12, | ||
"ConvWinograd3x3MultipassWrW<3-6>": 13, | ||
"GemmFwd1x1_0_2": 14, | ||
"ConvWinograd3x3MultipassWrW<3-5>": 15, | ||
"ConvCkIgemmFwdV6r1DlopsNchw": 16, | ||
"GemmBwd1x1_stride2": 17, | ||
"GemmWrwUniversal": 18, | ||
"GemmFwdRest": 19, | ||
"ConvWinograd3x3MultipassWrW<3-2>": 20, | ||
"ConvWinograd3x3MultipassWrW<3-4>": 21, | ||
"ConvWinograd3x3MultipassWrW<5-4>": 22 | ||
} | ||
}, | ||
"stats": { | ||
"overall": { | ||
"features": { | ||
"mean": { | ||
"Inp_0": 325.4046325683594, | ||
"Inp_2": 209.39981079101562, | ||
"Inp_3": 222.79859924316406, | ||
"Out_0": 305.1923828125, | ||
"Out_2": 148.5453338623047, | ||
"Out_3": 158.0012969970703, | ||
"Fil_1": 2.183664083480835, | ||
"Fil_2": 2.188920259475708, | ||
"Pad_1": 0.5829629898071289, | ||
"Pad_2": 0.5821362733840942, | ||
"Str_1": 1.3522303104400635, | ||
"Str_2": 1.3522303104400635, | ||
"Dil_1": 1.0099201202392578, | ||
"Dil_2": 1.0099201202392578, | ||
"BatchSize": 21.316818237304688, | ||
"Precision": 0.10632874816656113, | ||
"Direction": 1.0000261068344116, | ||
"GroupSize": 1.4064335823059082 | ||
}, | ||
"std": { | ||
"Inp_0": 394.53448486328125, | ||
"Inp_2": 247.1642303466797, | ||
"Inp_3": 261.4166564941406, | ||
"Out_0": 388.3568420410156, | ||
"Out_2": 131.25115966796875, | ||
"Out_3": 138.9573974609375, | ||
"Fil_1": 1.852091670036316, | ||
"Fil_2": 1.8830686807632446, | ||
"Pad_1": 1.0488911867141724, | ||
"Pad_2": 1.0482264757156372, | ||
"Str_1": 0.4814859628677368, | ||
"Str_2": 0.4814859628677368, | ||
"Dil_1": 0.5027242302894592, | ||
"Dil_2": 0.5027242302894592, | ||
"BatchSize": 99.3274917602539, | ||
"Precision": 0.39628151059150696, | ||
"Direction": 0.8165163397789001, | ||
"GroupSize": 8.734230995178223 | ||
} | ||
}, | ||
"gt": { | ||
"mean": { | ||
"solver": 6.611753463745117, | ||
"algorithm": 5.844751834869385, | ||
"solverTime": 0.8747590780258179 | ||
}, | ||
"std": { | ||
"solver": 5.101718902587891, | ||
"algorithm": 3.8398962020874023, | ||
"solverTime": 8.71385669708252 | ||
} | ||
} | ||
}, | ||
"train": { | ||
"features": { | ||
"mean": { | ||
"Inp_0": 325.3140563964844, | ||
"Inp_2": 209.40371704101562, | ||
"Inp_3": 222.7877655029297, | ||
"Out_0": 305.2548522949219, | ||
"Out_2": 148.56141662597656, | ||
"Out_3": 158.00404357910156, | ||
"Fil_1": 2.183530569076538, | ||
"Fil_2": 2.1888225078582764, | ||
"Pad_1": 0.5826140642166138, | ||
"Pad_2": 0.5817989706993103, | ||
"Str_1": 1.3521391153335571, | ||
"Str_2": 1.3521391153335571, | ||
"Dil_1": 1.0094914436340332, | ||
"Dil_2": 1.0094914436340332, | ||
"BatchSize": 21.304414749145508, | ||
"Precision": 0.10644834488630295, | ||
"Direction": 1.0000931024551392, | ||
"GroupSize": 1.406988263130188 | ||
}, | ||
"std": { | ||
"Inp_0": 394.18212890625, | ||
"Inp_2": 247.01255798339844, | ||
"Inp_3": 261.5453796386719, | ||
"Out_0": 388.3930358886719, | ||
"Out_2": 131.1672821044922, | ||
"Out_3": 139.0129852294922, | ||
"Fil_1": 1.8518221378326416, | ||
"Fil_2": 1.8826038837432861, | ||
"Pad_1": 1.0420876741409302, | ||
"Pad_2": 1.041419267654419, | ||
"Str_1": 0.48132652044296265, | ||
"Str_2": 0.48132652044296265, | ||
"Dil_1": 0.488539457321167, | ||
"Dil_2": 0.488539457321167, | ||
"BatchSize": 98.52013397216797, | ||
"Precision": 0.3965224325656891, | ||
"Direction": 0.8164235353469849, | ||
"GroupSize": 8.81924057006836 | ||
} | ||
}, | ||
"gt": { | ||
"mean": { | ||
"solver": 6.609450340270996, | ||
"algorithm": 5.8430094718933105, | ||
"solverTime": 0.8733705282211304 | ||
}, | ||
"std": { | ||
"solver": 5.101729393005371, | ||
"algorithm": 3.8403995037078857, | ||
"solverTime": 8.770111083984375 | ||
} | ||
} | ||
}, | ||
"test": { | ||
"features": { | ||
"mean": { | ||
"Inp_0": 327.8919982910156, | ||
"Inp_2": 209.28514099121094, | ||
"Inp_3": 222.74436950683594, | ||
"Out_0": 304.9135437011719, | ||
"Out_2": 148.28598022460938, | ||
"Out_3": 157.90194702148438, | ||
"Fil_1": 2.1848666667938232, | ||
"Fil_2": 2.189800500869751, | ||
"Pad_1": 0.5861029028892517, | ||
"Pad_2": 0.5851719379425049, | ||
"Str_1": 1.353050708770752, | ||
"Str_2": 1.353050708770752, | ||
"Dil_1": 1.0137779712677002, | ||
"Dil_2": 1.0137779712677002, | ||
"BatchSize": 21.428457260131836, | ||
"Precision": 0.10525237768888474, | ||
"Direction": 0.9994227886199951, | ||
"GroupSize": 1.4014410972595215 | ||
}, | ||
"std": { | ||
"Inp_0": 397.6885681152344, | ||
"Inp_2": 248.52734375, | ||
"Inp_3": 260.25750732421875, | ||
"Out_0": 388.0343017578125, | ||
"Out_2": 132.00457763671875, | ||
"Out_3": 138.4573211669922, | ||
"Fil_1": 1.8545325994491577, | ||
"Fil_2": 1.8872634172439575, | ||
"Pad_1": 1.1082494258880615, | ||
"Pad_2": 1.10761559009552, | ||
"Str_1": 0.4829222559928894, | ||
"Str_2": 0.4829222559928894, | ||
"Dil_1": 0.6158485412597656, | ||
"Dil_2": 0.6158485412597656, | ||
"BatchSize": 106.31908416748047, | ||
"Precision": 0.39410850405693054, | ||
"Direction": 0.817358672618866, | ||
"GroupSize": 7.928310871124268 | ||
} | ||
}, | ||
"gt": { | ||
"mean": { | ||
"solver": 6.632482528686523, | ||
"algorithm": 5.860433101654053, | ||
"solverTime": 0.8872132897377014 | ||
}, | ||
"std": { | ||
"solver": 5.10162353515625, | ||
"algorithm": 3.835364580154419, | ||
"solverTime": 8.19027042388916 | ||
} | ||
} | ||
} | ||
}, | ||
"conv_params_used_as_features": [ | ||
"Inp_0", | ||
"Inp_2", | ||
"Inp_3", | ||
"Out_0", | ||
"Out_2", | ||
"Out_3", | ||
"Fil_1", | ||
"Fil_2", | ||
"Pad_1", | ||
"Pad_2", | ||
"Str_1", | ||
"Str_2", | ||
"Dil_1", | ||
"Dil_2", | ||
"BatchSize", | ||
"Precision", | ||
"Direction", | ||
"GroupSize" | ||
], | ||
"redundant_columns": { | ||
"SpatialDim": 2.0, | ||
"Inp_1": 1.0, | ||
"Out_1": 1.0, | ||
"Fil_0": 1.0, | ||
"Pad_0": 0.0, | ||
"Str_0": 1.0, | ||
"Dil_0": 1.0, | ||
"BiasFlag": 0.0, | ||
"Layout": 0.0 | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a minor question — should it be `final`? I guess yes, and I guess all of the other models derived from `Model` should be `final` as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I'll mark it final.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@CAHEK7 Done