Skip to content

Commit

Permalink
Merge pull request #1905 from yenong-amd/release/rocm-rel-6.1
Browse files Browse the repository at this point in the history
Hotfix: Fix MasterSolutionLibrary indexing for multiple architecture build (#1888)
  • Loading branch information
nakajee authored Apr 18, 2024
2 parents be9f7da + 2b55ccf commit bf05992
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 11 deletions.
60 changes: 55 additions & 5 deletions Tensile/SolutionLibrary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
################################################################################

import itertools
import re

from . import Properties
from . import Hardware
Expand Down Expand Up @@ -248,6 +249,28 @@ def remapSolutionIndices(self, indexMap):

class MasterSolutionLibrary:
StateKeys = ["solutions", "library"]
ArchitectureSet = set()

@classmethod
def ArchitectureIndexMap(cls, architectureName):
    """Return the base solution index assigned to ``architectureName``.

    ``"fallback"`` maps to 0.  A name of the form ``gfx<hex digits>``
    (e.g. 'gfx803', 'gfx90a', 'gfx942', 'gfx1030', 'gfx1102') maps to the
    hexadecimal value of those digits shifted left by 18 bits.
    NOTE(review): the shift presumably reserves a private index range per
    architecture -- confirm the intended range size with the callers.

    Every value is handed out at most once: it is recorded in
    ``cls.ArchitectureSet`` and asking for the same architecture a second
    time is treated as an error.

    Raises:
        RuntimeError: if the name is neither ``"fallback"`` nor ``gfx``
            followed by at least one hex digit, or if its value has
            already been handed out.
    """
    archval = -1
    if architectureName == "fallback":
        archval = 0
    else:
        # Anchored at the start and requiring at least one digit, so a bare
        # "gfx" is rejected below instead of reaching int('', 16), which
        # would raise ValueError.
        archMatch = re.match(r"gfx([0-9a-f]+)", architectureName)
        if archMatch is not None:
            archval = int(archMatch.group(1), 16) << 18

    # Reject unknown names and duplicate architecture values.
    if archval < 0 or archval in cls.ArchitectureSet:
        raise RuntimeError("ERROR in architecture solution index mapping.")
    cls.ArchitectureSet.add(archval)
    return archval

@classmethod
def FixSolutionIndices(cls, solutions):
Expand Down Expand Up @@ -478,17 +501,44 @@ def applyNaming(self, naming=None):
s.name = OriginalSolution.getNameMin(s.originalSolution.getKernels()[0], naming)

def remapSolutionIndicesStartingFrom(self, curIndex):
    """Renumber all solutions consecutively, beginning at ``curIndex``.

    Solutions held by the lazy libraries are renumbered first (in
    ``self.lazyLibraries`` iteration order), followed by the solutions held
    directly in ``self.solutions``.  Each solution's ``index`` attribute is
    updated, the owning ``solutions`` dict is rebuilt keyed by the new
    index, and the matching ``library`` is told about the old -> new
    mapping through ``remapSolutionIndices``.
    """
    if self.lazyLibraries:
        lazyLibrary = {}
        for name, lib in self.lazyLibraries.items():
            reIndexMap = {}
            newSolutions = {}

            for s in lib.solutions.values():
                reIndexMap[s.index] = curIndex
                s.index = curIndex
                newSolutions[curIndex] = s
                curIndex += 1

            lib.solutions = newSolutions
            lib.library.remapSolutionIndices(reIndexMap)

            lazyLibrary[name] = lib
        self.lazyLibraries = lazyLibrary

    # Build the renumbered dict on the side and swap it in afterwards, so
    # self.solutions is never modified while it is being iterated.
    reIndexMap = {}
    newSolutions = {}
    for s in self.solutions.values():
        reIndexMap[s.index] = curIndex
        s.index = curIndex
        newSolutions[curIndex] = s
        curIndex += 1

    self.solutions = newSolutions
    self.library.remapSolutionIndices(reIndexMap)

def insert(self, other):
    """Add the contents of ``other`` to this library without renumbering.

    Lazy libraries from ``other`` are copied in by name (an entry with the
    same name is overwritten), every solution from ``other`` is stored
    under its own ``index`` attribute, and ``other.library`` is merged into
    ``self.library``.  ``other`` must be an instance of the same class.
    """
    assert self.__class__ == other.__class__

    self.lazyLibraries.update(other.lazyLibraries)

    # Key by s.index, not by other's dict key, so the stored key always
    # agrees with the solution's own index.
    for s in other.solutions.values():
        self.solutions[s.index] = s
    self.library.merge(other.library)

def merge(self, other, startIndex=0):
assert self.__class__ == other.__class__

Expand Down
7 changes: 7 additions & 0 deletions Tensile/Source/lib/source/UserDrivenTuningParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ namespace Tensile
strideB = 1;
strideC = 1;

if(b > 1)
{
strideA = ldA * (transA ? m : k);
strideB = ldB * (transB ? k : n);
strideC = ldC * n;
}

if(entries_n == 15)
{
// Expected layout: transA,transB,M,N,batch_count,K,alpha,beta,lda,ldb,ldc,input_type,output_type,compute_type,solution_index
Expand Down
37 changes: 31 additions & 6 deletions Tensile/TensileCreateLibrary.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,8 +911,7 @@ def generateLogicDataAndSolutions(logicFiles, args):
masterLibraries = {}
fullMasterLibrary = None

nextSolIndex = 0

nextSolIndex = {}
for logic in Utils.tqdm(libraries, "Processing logic data"):
(_, architectureName, _, solutionsForSchedule, _, newLibrary) = logic

Expand All @@ -923,10 +922,14 @@ def generateLogicDataAndSolutions(logicFiles, args):
masterLibraries[architectureName] = deepcopy(newLibrary)
masterLibraries[architectureName].version = args.version
elif globalParameters["SeparateArchitectures"] or globalParameters["LazyLibraryLoading"]:

if architectureName in masterLibraries:
nextSolIndex = masterLibraries[architectureName].merge(deepcopy(newLibrary), nextSolIndex)
nextSolIndex[architectureName] = masterLibraries[architectureName].merge(deepcopy(newLibrary), nextSolIndex[architectureName])
else:
masterLibraries[architectureName] = deepcopy(newLibrary)
archIndexMap = MasterSolutionLibrary.ArchitectureIndexMap(architectureName)
masterLibraries[architectureName].remapSolutionIndicesStartingFrom(archIndexMap)
nextSolIndex[architectureName] = archIndexMap
masterLibraries[architectureName].version = args.version
else:
if fullMasterLibrary is None:
Expand All @@ -944,8 +947,7 @@ def generateLogicDataAndSolutions(logicFiles, args):
if "fallback" in masterLibraries.keys():
for key, value in masterLibraries.items():
if key != "fallback":
value.merge(deepcopy(masterLibraries["fallback"]))

value.insert(deepcopy(masterLibraries["fallback"]))
masterLibraries.pop("fallback")

for _, masterLibrary in masterLibraries.items():
Expand Down Expand Up @@ -1017,6 +1019,23 @@ def WriteClientLibraryFromSolutions(solutionList, libraryWorkingPath, tensileSou

return (codeObjectFiles, newLibrary)

################################################################################
# Write Master Solution Index CSV
################################################################################
def writeMasterSolutionIndexCSV(outputPath, masterLibraries):
    """Write ``TensileMasterSolutionIndex.csv`` into ``<outputPath>/library``.

    One row is written per solution of every lazy library of every entry in
    ``masterLibraries`` (a mapping of architecture name -> master library),
    with the columns ``architectureName, libraryName, libraryIndex,
    solutionIndex, solutionName``.  ``libraryIndex`` is the key the solution
    is stored under in the lazy library; ``solutionIndex`` is the
    solution's own ``index`` attribute.

    An ``IOError`` while writing is reported through ``print1`` and
    otherwise ignored.
    """
    # Local import: the top-of-file imports are not visible from this block.
    import csv

    libraryPath = os.path.join(outputPath, "library")
    ensurePath(libraryPath)
    indexPath = os.path.join(libraryPath, "TensileMasterSolutionIndex.csv")
    try:
        with open(indexPath, "w") as indexFile:
            # csv.writer quotes fields containing commas or quotes, which the
            # previous ",".join did not.  lineterminator="\n" keeps ordinary
            # rows byte-identical to the previous output.
            writer = csv.writer(indexFile, lineterminator="\n")
            writer.writerow(["architectureName", "libraryName", "libraryIndex",
                             "solutionIndex", "solutionName"])
            for arch, lib in masterLibraries.items():
                for lazylibname, lazylibvals in lib.lazyLibraries.items():
                    for solidx, solution in lazylibvals.solutions.items():
                        # str() up front so that None is still written as
                        # "None" rather than csv's empty field.
                        writer.writerow([str(x) for x in
                                         (arch, lazylibname, solidx, solution.index, solution.name)])
    except IOError as err:
        print1("Error writing MasterSolutionIndex %s" % err)

################################################################################
# Tensile Create Library
################################################################################
Expand Down Expand Up @@ -1084,6 +1103,8 @@ def splitExtraParameters(par):
argParser.add_argument("--global-parameters", nargs="+", type=splitExtraParameters, default=[])
argParser.add_argument("--ignore-asm-cap-cache", dest="IgnoreAsmCapCache", action="store_true", default=False,
help="Ignore asm cap cache and derive the asm caps at runtime")
argParser.add_argument("--write-master-solution-index", dest="WriteMasterSolutionIndex", action="store_true",
default=False, help="Output master solution index in csv format.")
args = argParser.parse_args()

logicPath = args.LogicPath
Expand Down Expand Up @@ -1123,7 +1144,8 @@ def splitExtraParameters(par):
arguments["CpuThreads"] = args.CpuThreads
arguments["PrintLevel"] = args.PrintLevel
arguments["IgnoreAsmCapCache"] = args.IgnoreAsmCapCache

arguments["WriteMasterSolutionIndex"] = args.WriteMasterSolutionIndex

for key, value in args.global_parameters:
arguments[key] = value

Expand Down Expand Up @@ -1174,6 +1196,9 @@ def splitExtraParameters(par):
# Parse logicData, solutions, and masterLibraries from logic files
solutions, masterLibraries, fullMasterLibrary = generateLogicDataAndSolutions(logicFiles, args)

if globalParameters["LazyLibraryLoading"] and arguments["WriteMasterSolutionIndex"]:
writeMasterSolutionIndexCSV(outputPath, masterLibraries)

kernels, kernelHelperObjs, _ = generateKernelObjectsFromSolutions(solutions)

# if any kernels are assembly, append every ISA supported
Expand Down
71 changes: 71 additions & 0 deletions Tensile/Utilities/validate_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
################################################################################
#
# Copyright (C) 2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################

import argparse
import csv
import pathlib
import sys

from collections import defaultdict

def gather_data(indexFile):
    """Collect the distinct values of every column of a CSV index file.

    Args:
        indexFile: path of a CSV file whose first row holds the column
            names.

    Returns:
        A ``defaultdict(set)`` mapping each column name to the set of
        values seen in that column.  Because it is a ``defaultdict``,
        looking up a column that is not in the file yields an empty set.
    """
    indexData = defaultdict(set)
    # newline="" lets the csv module do its own newline handling, as the csv
    # documentation requires; otherwise newlines embedded in quoted fields
    # are rewritten before the reader sees them.
    with open(indexFile, "rt", newline="") as csvfile:
        indexreader = csv.DictReader(csvfile, delimiter=",")
        for row in indexreader:  # each row is {column1: value1, column2: value2, ...}
            for key, value in row.items():
                indexData[key].add(value)
    return indexData

if __name__ == "__main__":
    # Command-line entry point: compare the files found in a Tensile library
    # directory against the contents of its TensileMasterSolutionIndex.csv
    # and print the result of each comparison as True/False.
    argParser = argparse.ArgumentParser()
    argParser.add_argument("library_path", help="Tensile library path")
    args = argParser.parse_args()
    libraryPath = args.library_path
    libraryDir = pathlib.Path(libraryPath)
    indexFileName = "TensileMasterSolutionIndex.csv"

    # Check that the library path exists and is a directory
    if not libraryDir.is_dir():
        print(f"ERROR: {libraryPath} does not exist or is not a directory.")
        sys.exit(1)

    # Check that TensileMasterSolutionIndex.csv exists
    csvpath = libraryDir / indexFileName
    if not csvpath.is_file():
        print(f"ERROR: {csvpath} does not exist.")
        sys.exit(1)

    data = gather_data(csvpath)

    # List files in library path (names without their extension)
    datFiles = [f.stem for f in libraryDir.glob("*.dat")]
    coFiles = [f.stem for f in libraryDir.glob("*.co")]
    # .dat files are split by whether their name contains "_lazy_"
    lazyArchFiles = [f for f in datFiles if "_lazy_" in f]
    metaDataFiles = [f for f in datFiles if "_lazy_" not in f]
    nonfallback = {f for f in data['libraryName'] if "fallback" not in f}

    print(f"MetaData files should match library names in index file: {set(metaDataFiles) == data['libraryName']}")
    print(f"Asm files should match non-fallback library names in index file: {set(coFiles) == nonfallback}")
    print(f"Lazy library files should match number of architectures in index file: {len(set(lazyArchFiles)) == len(data['architectureName'])}")

0 comments on commit bf05992

Please sign in to comment.