Skip to content

Commit

Permalink
[Enhancement] Add checks on workspace params (#2498)
Browse files Browse the repository at this point in the history
* added checks on workspace params

* addressed review comments

* fix release build warning
  • Loading branch information
amberhassaan authored Nov 17, 2023
1 parent 7275429 commit aa878a8
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ static inline void ValidateGroupCount(const TensorDescriptor& x,
MIOPEN_THROW(miopenStatusBadParm, "Invalid group number");
}

static inline void ValidateWorkspace(Data_t workSpace, const size_t workSpaceSize)
{

[[maybe_unused]] bool x = (workSpace != nullptr);
[[maybe_unused]] bool y = (workSpaceSize != 0);

assert(((x && y) || (!x && !y)) && "workspace pointer and size don't match. Either both should "
"be zero or both should be non-zero");

/// \todo could add a check here that workSpace points to GPU memory
}

static Invoker PrepareInvoker(ExecutionContext ctx,
const conv::ProblemDescription& problem,
const NetworkConfig& config,
Expand Down Expand Up @@ -260,6 +272,7 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(x == nullptr || w == nullptr || y == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -496,6 +509,7 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);

const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};
ValidateTensors(tensors);
Expand Down Expand Up @@ -813,6 +827,7 @@ void ConvolutionDescriptor::ConvolutionForwardImmediate(Handle& handle,
const solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};

ValidateTensors(tensors);
Expand Down Expand Up @@ -847,6 +862,7 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(dx == nullptr || w == nullptr || dy == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -946,6 +962,7 @@ void ConvolutionDescriptor::ConvolutionBackwardData(Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);

auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

Expand Down Expand Up @@ -1017,6 +1034,7 @@ void ConvolutionDescriptor::ConvolutionBackwardImmediate(Handle& handle,
solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
auto tensors = ConvBwdTensors{dyDesc, dy, wDesc, w, dxDesc, dx};

ValidateTensors(tensors);
Expand Down Expand Up @@ -1057,6 +1075,7 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle,
bool exhaustiveSearch) const
{
MIOPEN_LOG_I("requestAlgoCount = " << requestAlgoCount << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
if(x == nullptr || dw == nullptr || dy == nullptr)
MIOPEN_THROW(miopenStatusBadParm, "Buffers cannot be NULL");
if(returnedAlgoCount == nullptr)
Expand Down Expand Up @@ -1154,6 +1173,7 @@ void ConvolutionDescriptor::ConvolutionBackwardWeights(const Handle& handle,
size_t workSpaceSize) const
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
decltype(auto) tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
ValidateTensors(tensors);
ValidateAlphaBeta(alpha, beta);
Expand Down Expand Up @@ -1221,6 +1241,7 @@ void ConvolutionDescriptor::ConvolutionWrwImmediate(Handle& handle,
solver::Id solver_id) const
{
MIOPEN_LOG_I("solver_id = " << solver_id.ToString() << ", workspace = " << workSpaceSize);
ValidateWorkspace(workSpace, workSpaceSize);
auto tensors = ConvWrwTensors{dyDesc, dy, xDesc, x, dwDesc, dw};
ValidateTensors(tensors);

Expand Down

0 comments on commit aa878a8

Please sign in to comment.