Skip to content

Commit

Permalink
Merge pull request #1902 from nakajee/release/rocm-rel-6.1
Browse files Browse the repository at this point in the history
Hotfix: Fix WorkspaceCheck implementation when used in rocBLAS
  • Loading branch information
nakajee authored Mar 15, 2024
2 parents d0314ce + e61b297 commit be9f7da
Show file tree
Hide file tree
Showing 31 changed files with 313 additions and 118 deletions.
38 changes: 38 additions & 0 deletions HostLibraryTests/LibYamlToMsgpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

################################################################################
#
# Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################

import sys
import yaml
import msgpack

if __name__ == "__main__":
args = sys.argv[1:]
infile = args[0]
outfile = args[1]
with open(infile) as f:
data = yaml.load(f)
with open(outfile, 'wb') as f:
msgpack.dump(data, f)
Binary file modified HostLibraryTests/configs/SolutionLibraries/KernelsLite.dat.gz
Binary file not shown.
Binary file modified HostLibraryTests/configs/SolutionLibraries/KernelsLite.yaml.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
31 changes: 31 additions & 0 deletions HostLibraryTests/configs/SolutionLibraries/readme
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
Sample libraries can be rebuilt using TensileCreateLibrary and rocBLAS build.

To rebuild rocBLAS_Full, run rocBLAS build script.
Sample library currently includes gfx803, gfx900, gfx906, and gfx908.
To build yaml version, include the --no-msgpack flag.

./install.sh -dc -t ~/tensile -a "gfx803;gfx900;gfx906;gfx908" --merge-architectures --no-lazy-library-loading
./install.sh -dc -t ~/tensile -a "gfx803;gfx900;gfx906;gfx908" --merge-architectures --no-lazy-library-loading --no-msgpack

SampleTensileKernels are small samples written manually.
To update, make any required updates to SampleTensileKernels.yaml and call the script to convert to msgpack

cd HostLibraryTests
./LibYamlToMsgpack.py configs/SolutionLibraries/SampleTensileKernels.yaml configs/SolutionLibraries/SampleTensileKernels.dat

Other libs can be rebuilt by calling TensileCreateLibrary.

KernelsLite:
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../HostLibraryTests/configs/lite_configs/ . HIP
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../HostLibraryTests/configs/lite_configs/ . HIP
KernelsLiteMixed:
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../HostLibraryTests/configs/lite_configs_mixed/ . HIP
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../HostLibraryTests/configs/lite_configs_mixed/ . HIP
KernelsLiteNavi:
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../Tensile/Source/lib/configs/lite_configs/ . HIP
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../Tensile/Source/lib/configs/lite_configs/ . HIP
KernelsTileLite:
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=yaml ../HostLibraryTests/configs/tile_aware_selection/ . HIP
../Tensile/bin/TensileCreateLibrary --merge-files --code-object-version=default --library-format=msgpack ../HostLibraryTests/configs/tile_aware_selection/ . HIP

All libraries are checked in as .gz to reduce checkout size
Binary file not shown.
Binary file modified HostLibraryTests/configs/SolutionLibraries/rocBLAS_Full.yaml.gz
Binary file not shown.
2 changes: 2 additions & 0 deletions HostLibraryTests/llvm/LLVMYAMLContraction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ TEST(LLVMYAMLContractionTest, Simple)
"index: 0\n"
"hardwarePredicate: { type: TruePred }\n"
"problemPredicate: { type: TruePred }\n"
"taskPredicate: { type: TruePred }\n"
"debugKernel: false\n"
"problemType:\n"
" operationIdentifier: foo\n"
Expand Down Expand Up @@ -119,6 +120,7 @@ TEST(LLVMYAMLContractionTest, ContractionLibrary)
" index: 0\n"
" hardwarePredicate: { type: TruePred }\n"
" problemPredicate: { type: TruePred }\n"
" taskPredicate: { type: TruePred }\n"
" debugKernel: false\n"
" problemType:\n"
" operationIdentifier: foo\n"
Expand Down
4 changes: 4 additions & 0 deletions HostLibraryTests/sample_library.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ solutions:

hardwarePredicate: { type: TruePred }
problemPredicate: { type: TruePred }
taskPredicate: { type: TruePred }
info: {}
debugKernel: false
index: 0
Expand Down Expand Up @@ -51,6 +52,7 @@ solutions:

hardwarePredicate: { type: TruePred }
problemPredicate: { type: TruePred }
taskPredicate: { type: TruePred }
info: {}
debugKernel: false
index: 1
Expand Down Expand Up @@ -79,6 +81,7 @@ solutions:

hardwarePredicate: { type: TruePred }
problemPredicate: { type: TruePred }
taskPredicate: { type: TruePred }
info: {}
debugKernel: false
index: 2
Expand Down Expand Up @@ -107,6 +110,7 @@ solutions:

hardwarePredicate: { type: TruePred }
problemPredicate: { type: TruePred }
taskPredicate: { type: TruePred }
info: {}
debugKernel: false
index: 3
Expand Down
21 changes: 17 additions & 4 deletions Tensile/Contractions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
################################################################################
#
# Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -354,9 +354,6 @@ def FromOriginalKeyPair(cls, pair):

return cls(tag, index=index, value=value)

if key == "_WorkspaceSizePerElemC" and value > 0:
return cls("WorkspaceCheck", index=0, value=value)

if key.startswith('Assert'):
raise RuntimeError("Unknown assertion key: {}".format(key))

Expand Down Expand Up @@ -446,6 +443,19 @@ def FromOriginalState(cls, d, problemType, morePreds=[]):
predicates = [p for p in map(cls.FromOriginalKeyPair, d.items()) if p is not None] + extraPreds
return cls.And(predicates)

class TaskPredicate(Properties.Predicate):
@classmethod
def FromOriginalKeyPair(cls, pair):
(key, value) = pair
if key == "_WorkspaceSizePerElemC" and value > 0:
return cls("WorkspaceCheck")
return None

@classmethod
def FromOriginalState(cls, d, problemType, morePreds=[]):
predicates = [p for p in map(cls.FromOriginalKeyPair, d.items()) if p is not None]
return cls.And(predicates)

class SizeMapping:
StateKeys = ['workGroup',
'macroTile',
Expand Down Expand Up @@ -514,6 +524,7 @@ class Solution:
'problemType',
'hardwarePredicate',
'problemPredicate',
'taskPredicate',
'sizeMapping',
'debugKernel',
'libraryLogicIndex',
Expand All @@ -537,6 +548,7 @@ def FromOriginalState(cls, d, deviceInfo=None):
rv.problemType = ProblemType.FromOriginalState(d['ProblemType'])

rv.problemPredicate = ProblemPredicate.FromOriginalState(d, rv.problemType)
rv.taskPredicate = TaskPredicate.FromOriginalState(d, rv.problemType)

if 'DebugKernel' in d:
rv.debugKernel = d['DebugKernel']
Expand Down Expand Up @@ -579,6 +591,7 @@ def __init__(self, **kwargs):
self.problemType = None
self.hardwarePredicate = Hardware.HardwarePredicate('TruePred')
self.problemPredicate = ProblemPredicate('TruePred')
self.taskPredicate = TaskPredicate('TruePred')
self.sizeMapping = None
self.debugKernel = False
self.libraryLogicIndex = {}
Expand Down
8 changes: 5 additions & 3 deletions Tensile/Source/client/source/SolutionIterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2020-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -88,15 +88,17 @@ namespace Tensile

// Test if the persistent kernel is eligible for the current hw and solution
m_problem.checkPersistentKernelEligibility(solution, *m_hardware);
m_problem.checkRequiredWorkspaceSize(solution, *m_hardware);
if(!(*solution.problemPredicate)(m_problem))
Task task(*m_hardware, m_problem, solution);
if(!(*solution.problemPredicate)(m_problem) || !(*solution.taskPredicate)(task))
{
m_reporter->report(ResultKey::Validation, "DID_NOT_SATISFY_ASSERTS");
if(m_reporter->logAtLevel(LogLevel::Verbose))
{
std::ostringstream msg;
solution.problemPredicate->debugEval(m_problem, msg);
msg << std::endl;
solution.taskPredicate->debugEval(task, msg);
msg << std::endl;
m_reporter->log(LogLevel::Verbose, msg.str());
}

Expand Down
10 changes: 1 addition & 9 deletions Tensile/Source/lib/include/Tensile/ContractionProblem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -820,19 +820,12 @@ namespace Tensile

void checkPersistentKernelEligibility(ContractionSolution const& solution,
Hardware const& hardware);
void checkRequiredWorkspaceSize(ContractionSolution const& solution,
Hardware const& hardware);

bool getPersistentKernelEligibility() const
{
return m_eligibleForPK;
}

size_t getRequiredWorkspaceSize() const
{
return m_requiredWorkspaceSize;
}

private:
TensorDescriptor m_a;
TensorDescriptor m_b;
Expand Down Expand Up @@ -860,7 +853,6 @@ namespace Tensile
bool m_fp16AltImpl = false;
bool m_fp16AltImplRound = false;
bool m_stochasticRounding = false;
size_t m_requiredWorkspaceSize = 0;
DataType m_f32XdlMathOp = DataType::Float;
ArithmeticUnit m_arithmeticUnit = ArithmeticUnit::Any;
KernelLanguage m_kernelLanguage = KernelLanguage::Any;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -1236,45 +1236,6 @@ namespace Tensile
}
};

struct WorkspaceCheck : public Predicate_CRTP<WorkspaceCheck, ContractionProblem>
{
enum
{
HasIndex = true,
HasValue = true
};
size_t index;
size_t value;

WorkspaceCheck() = default;
WorkspaceCheck(size_t index, size_t value)
: index(index)
, value(value)
{
}

static std::string Type()
{
return "WorkspaceCheck";
}

virtual bool operator()(ContractionProblem const& problem) const override
{
return problem.getRequiredWorkspaceSize() <= problem.workspaceSize();
}

virtual bool debugEval(ContractionProblem const& problem,
std::ostream& stream) const override
{
bool rv = (*this)(problem);

stream << *this << ": (" << problem.getRequiredWorkspaceSize()
<< " <= " << problem.workspaceSize() << ") == " << rv;

return rv;
}
};

struct PersistentKernelCheck
: public Predicate_CRTP<PersistentKernelCheck, ContractionProblem>
{
Expand Down
5 changes: 4 additions & 1 deletion Tensile/Source/lib/include/Tensile/ContractionSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2019-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,6 +37,7 @@
#include <Tensile/ContractionProblem_fwd.hpp>
#include <Tensile/DataTypes.hpp>
#include <Tensile/Predicates.hpp>
#include <Tensile/Task.hpp>
#include <Tensile/Utils.hpp>

namespace Tensile
Expand Down Expand Up @@ -324,6 +325,8 @@ namespace Tensile
bool debugKernel = false;
bool kernelArgsLog = false;

std::shared_ptr<Predicates::Predicate<Task>> taskPredicate
= std::make_shared<Predicates::True<Task>>();
std::shared_ptr<Predicates::Predicate<Problem>> problemPredicate
= std::make_shared<Predicates::True<Problem>>();
std::shared_ptr<Predicates::Predicate<Hardware>> hardwarePredicate
Expand Down
Loading

0 comments on commit be9f7da

Please sign in to comment.