Skip to content

Commit

Permalink
Composable kernel init integration v3 (#1097)
Browse files Browse the repository at this point in the history
* Squashed 'src/composable_kernel/' content from commit f6edda611

git-subtree-dir: src/composable_kernel
git-subtree-split: f6edda6119ebbb237dfa6270797b34f960d7b190

* add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files

* Squashed 'src/composable_kernel/' changes from f6edda611..5781adf5c

5781adf5c Update develop (#5) (#6)
97e6d514f Merge pull request #4 from ROCmSoftwarePlatform/separate_online_compile
7b1ec41e5 refactor
49c33aaea refactor
54b3e73d1 rename

git-subtree-dir: src/composable_kernel
git-subtree-split: 5781adf5cf4ac753e2e36da7385791775b744bf7

* fix

* refactor

* remove online compilation from CK

* refactor

* fix

* add ctest

* add c-style pointer cast

* vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast

* fix clang warning suppression

* tidy

* suppress cppcheck

* fix enum issue

* revert chagnes to hip build

* fix kernel filename

* update CK build script

* rename

* rename

* make innner product compatiable on gfx900

* Update src/include/miopen/solver/ck_utility_common.hpp

Co-authored-by: JD <Jehandad.Khan@amd.com>

* compiler parameter use stream

* use int instead of index_t in kernel wrapper

* DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element

* refactor

* refactor

* change cmakelist

* change ck common utility

* fix

Co-authored-by: JD <Jehandad.Khan@amd.com>
  • Loading branch information
Chao Liu and JehandadKhan authored Aug 19, 2021
1 parent 946ee3a commit d3ee8a8
Show file tree
Hide file tree
Showing 250 changed files with 32,248 additions and 821 deletions.
18 changes: 13 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ add_definitions(-DBOOST_ALL_NO_LIB=1)
find_package(Boost REQUIRED COMPONENTS ${BOOST_COMPONENTS})

find_path(HALF_INCLUDE_DIR half.hpp)
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")

option( MIOPEN_DEBUG_FIND_DB_CACHING "Use system find-db caching" ON)

Expand Down Expand Up @@ -563,11 +564,18 @@ enable_cppcheck(
*:*src/sqlite/*.cpp
*:*.cl
*:*src/kernels/*.h
knownConditionTrueFalse:*src/kernels/composable_kernel/*/*
redundantAssignment:*src/kernels/composable_kernel/*/*
unreadVariable:*src/kernels/composable_kernel/*/*
unusedScopedObject:*src/kernels/composable_kernel/*/*
wrongPrintfScanfArgNum:*src/kernels/composable_kernel/*/*
knownConditionTrueFalse:*src/kernels/static_composable_kernel/*/*
redundantAssignment:*src/kernels/static_composable_kernel/*/*
unreadVariable:*src/kernels/static_composable_kernel/*/*
unusedScopedObject:*src/kernels/static_composable_kernel/*/*
wrongPrintfScanfArgNum:*src/kernels/static_composable_kernel/*/*
knownConditionTrueFalse:*src/composable_kernel/composable_kernel/*/*
identicalConditionAfterEarlyExit:*src/composable_kernel/composable_kernel/*/*
duplicateExpression:*src/composable_kernel/composable_kernel/*/*
multiCondition:*src/composable_kernel/composable_kernel/*/*
unreadVariable:*src/composable_kernel/composable_kernel/*/*
unreadVariable:*src/composable_kernel/host/*/*
unreadVariable:*src/composable_kernel/external/*/*
unmatchedSuppression
FORCE
SOURCES
Expand Down
263 changes: 141 additions & 122 deletions fin/src/base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,51 +36,60 @@
#include <algorithm>
#include <stdexcept>

//
// Depending on the url parameter in base64_chars, one of
// two sets of base64 characters needs to be chosen.
// They differ in their last two characters.
//
static const char* base64_chars[2] = {
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789"
"+/",

"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789"
"-_"};

static unsigned int pos_of_char(const unsigned char chr) {
//
// Return the position of chr within base64_encode()
//

if (chr >= 'A' && chr <= 'Z') return chr - 'A';
else if (chr >= 'a' && chr <= 'z') return chr - 'a' + ('Z' - 'A') + 1;
else if (chr >= '0' && chr <= '9') return chr - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
else if (chr == '+' || chr == '-') return 62; // Be liberal with input and accept both url ('-') and non-url ('+') base 64 characters (
else if (chr == '/' || chr == '_') return 63; // Ditto for '/' and '_'
//
// Depending on the url parameter in base64_chars, one of
// two sets of base64 characters needs to be chosen.
// They differ in their last two characters.
//
static const char* base64_chars[2] = {"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789"
"+/",

"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789"
"-_"};

static unsigned int pos_of_char(const unsigned char chr)
{
//
// Return the position of chr within base64_encode()
//

if(chr >= 'A' && chr <= 'Z')
return chr - 'A';
else if(chr >= 'a' && chr <= 'z')
return chr - 'a' + ('Z' - 'A') + 1;
else if(chr >= '0' && chr <= '9')
return chr - '0' + ('Z' - 'A') + ('z' - 'a') + 2;
else if(chr == '+' || chr == '-')
return 62; // Be liberal with input and accept both url ('-') and non-url ('+') base 64
// characters (
else if(chr == '/' || chr == '_')
return 63; // Ditto for '/' and '_'
else
//
// 2020-10-23: Throw std::exception rather than const char*
//(Pablo Martin-Gomez, https://github.com/Bouska)
//
throw std::runtime_error("Input is not valid base64-encoded data.");
//
// 2020-10-23: Throw std::exception rather than const char*
//(Pablo Martin-Gomez, https://github.com/Bouska)
//
throw std::runtime_error("Input is not valid base64-encoded data.");
}

static std::string insert_linebreaks(std::string str, size_t distance) {
//
// Provided by https://github.com/JomaCorpFX, adapted by me.
//
if (!str.length()) {
static std::string insert_linebreaks(std::string str, size_t distance)
{
//
// Provided by https://github.com/JomaCorpFX, adapted by me.
//
if(!str.length())
{
return "";
}

size_t pos = distance;

while (pos < str.size()) {
while(pos < str.size())
{
str.insert(pos, "\n");
pos += distance + 1;
}
Expand All @@ -89,63 +98,75 @@ static std::string insert_linebreaks(std::string str, size_t distance) {
}

template <typename String, unsigned int line_length>
static std::string encode_with_line_breaks(String s) {
return insert_linebreaks(base64_encode(s, false), line_length);
static std::string encode_with_line_breaks(String s)
{
return insert_linebreaks(base64_encode(s, false), line_length);
}

template <typename String>
static std::string encode_pem(String s) {
return encode_with_line_breaks<String, 64>(s);
static std::string encode_pem(String s)
{
return encode_with_line_breaks<String, 64>(s);
}

template <typename String>
static std::string encode_mime(String s) {
return encode_with_line_breaks<String, 76>(s);
static std::string encode_mime(String s)
{
return encode_with_line_breaks<String, 76>(s);
}

template <typename String>
static std::string encode(String s, bool url) {
return base64_encode(reinterpret_cast<const unsigned char*>(s.data()), s.length(), url);
static std::string encode(String s, bool url)
{
return base64_encode(reinterpret_cast<const unsigned char*>(s.data()), s.length(), url);
}

std::string base64_encode(unsigned char const* bytes_to_encode, size_t in_len, bool url) {
std::string base64_encode(unsigned char const* bytes_to_encode, size_t in_len, bool url)
{

size_t len_encoded = (in_len +2) / 3 * 4;
size_t len_encoded = (in_len + 2) / 3 * 4;

unsigned char trailing_char = url ? '.' : '=';

//
// Choose set of base64 characters. They differ
// for the last two positions, depending on the url
// parameter.
// A bool (as is the parameter url) is guaranteed
// to evaluate to either 0 or 1 in C++ therefore,
// the correct character set is chosen by subscripting
// base64_chars with url.
//
//
// Choose set of base64 characters. They differ
// for the last two positions, depending on the url
// parameter.
// A bool (as is the parameter url) is guaranteed
// to evaluate to either 0 or 1 in C++ therefore,
// the correct character set is chosen by subscripting
// base64_chars with url.
//
const char* base64_chars_ = base64_chars[url];

std::string ret;
ret.reserve(len_encoded);

unsigned int pos = 0;

while (pos < in_len) {
while(pos < in_len)
{
ret.push_back(base64_chars_[(bytes_to_encode[pos + 0] & 0xfc) >> 2]);

if (pos+1 < in_len) {
ret.push_back(base64_chars_[((bytes_to_encode[pos + 0] & 0x03) << 4) + ((bytes_to_encode[pos + 1] & 0xf0) >> 4)]);

if (pos+2 < in_len) {
ret.push_back(base64_chars_[((bytes_to_encode[pos + 1] & 0x0f) << 2) + ((bytes_to_encode[pos + 2] & 0xc0) >> 6)]);
ret.push_back(base64_chars_[ bytes_to_encode[pos + 2] & 0x3f]);
}
else {
ret.push_back(base64_chars_[(bytes_to_encode[pos + 1] & 0x0f) << 2]);
ret.push_back(trailing_char);
}
if(pos + 1 < in_len)
{
ret.push_back(base64_chars_[((bytes_to_encode[pos + 0] & 0x03) << 4) +
((bytes_to_encode[pos + 1] & 0xf0) >> 4)]);

if(pos + 2 < in_len)
{
ret.push_back(base64_chars_[((bytes_to_encode[pos + 1] & 0x0f) << 2) +
((bytes_to_encode[pos + 2] & 0xc0) >> 6)]);
ret.push_back(base64_chars_[bytes_to_encode[pos + 2] & 0x3f]);
}
else
{
ret.push_back(base64_chars_[(bytes_to_encode[pos + 1] & 0x0f) << 2]);
ret.push_back(trailing_char);
}
}
else {
else
{

ret.push_back(base64_chars_[(bytes_to_encode[pos + 0] & 0x03) << 4]);
ret.push_back(trailing_char);
Expand All @@ -155,78 +176,81 @@ std::string base64_encode(unsigned char const* bytes_to_encode, size_t in_len, b
pos += 3;
}


return ret;
}

template <typename String>
static std::string decode(String encoded_string, bool remove_linebreaks) {
//
// decode(…) is templated so that it can be used with String = const std::string&
// or std::string_view (requires at least C++17)
//
static std::string decode(String encoded_string, bool remove_linebreaks)
{
//
// decode(…) is templated so that it can be used with String = const std::string&
// or std::string_view (requires at least C++17)
//

if (encoded_string.empty()) return std::string();
if(encoded_string.empty())
return std::string();

if (remove_linebreaks) {
if(remove_linebreaks)
{

std::string copy(encoded_string);
std::string copy(encoded_string);

copy.erase(std::remove(copy.begin(), copy.end(), '\n'), copy.end());
copy.erase(std::remove(copy.begin(), copy.end(), '\n'), copy.end());

return base64_decode(copy, false);
return base64_decode(copy, false);
}

size_t length_of_string = encoded_string.length();
size_t pos = 0;

//
// The approximate length (bytes) of the decoded string might be one or
// two bytes smaller, depending on the amount of trailing equal signs
// in the encoded string. This approximation is needed to reserve
// enough space in the string to be returned.
//
size_t pos = 0;

//
// The approximate length (bytes) of the decoded string might be one or
// two bytes smaller, depending on the amount of trailing equal signs
// in the encoded string. This approximation is needed to reserve
// enough space in the string to be returned.
//
size_t approx_length_of_decoded_string = length_of_string / 4 * 3;
std::string ret;
ret.reserve(approx_length_of_decoded_string);

while (pos < length_of_string) {
while(pos < length_of_string)
{

unsigned int pos_of_char_1 = pos_of_char(encoded_string[pos+1] );
unsigned int pos_of_char_1 = pos_of_char(encoded_string[pos + 1]);

ret.push_back(static_cast<std::string::value_type>( ( (pos_of_char(encoded_string[pos+0]) ) << 2 ) + ( (pos_of_char_1 & 0x30 ) >> 4)));
ret.push_back(static_cast<std::string::value_type>(
((pos_of_char(encoded_string[pos + 0])) << 2) + ((pos_of_char_1 & 0x30) >> 4)));

if (encoded_string[pos+2] != '=' && encoded_string[pos+2] != '.') { // accept URL-safe base 64 strings, too, so check for '.' also.
if(encoded_string[pos + 2] != '=' && encoded_string[pos + 2] != '.')
{ // accept URL-safe base 64 strings, too, so check for '.' also.

unsigned int pos_of_char_2 = pos_of_char(encoded_string[pos+2] );
ret.push_back(static_cast<std::string::value_type>( (( pos_of_char_1 & 0x0f) << 4) + (( pos_of_char_2 & 0x3c) >> 2)));
unsigned int pos_of_char_2 = pos_of_char(encoded_string[pos + 2]);
ret.push_back(static_cast<std::string::value_type>(((pos_of_char_1 & 0x0f) << 4) +
((pos_of_char_2 & 0x3c) >> 2)));

if (encoded_string[pos+3] != '=' && encoded_string[pos+3] != '.') {
ret.push_back(static_cast<std::string::value_type>( ( (pos_of_char_2 & 0x03 ) << 6 ) + pos_of_char(encoded_string[pos+3]) ));
}
}
if(encoded_string[pos + 3] != '=' && encoded_string[pos + 3] != '.')
{
ret.push_back(static_cast<std::string::value_type>(
((pos_of_char_2 & 0x03) << 6) + pos_of_char(encoded_string[pos + 3])));
}
}

pos += 4;
pos += 4;
}

return ret;
}

std::string base64_decode(std::string const& s, bool remove_linebreaks) {
return decode(s, remove_linebreaks);
std::string base64_decode(std::string const& s, bool remove_linebreaks)
{
return decode(s, remove_linebreaks);
}

std::string base64_encode(std::string const& s, bool url) {
return encode(s, url);
}
std::string base64_encode(std::string const& s, bool url) { return encode(s, url); }

std::string base64_encode_pem (std::string const& s) {
return encode_pem(s);
}
std::string base64_encode_pem(std::string const& s) { return encode_pem(s); }

std::string base64_encode_mime(std::string const& s) {
return encode_mime(s);
}
std::string base64_encode_mime(std::string const& s) { return encode_mime(s); }

#if __cplusplus >= 201703L
//
Expand All @@ -235,20 +259,15 @@ std::string base64_encode_mime(std::string const& s) {
// Provided by Yannic Bonenberger (https://github.com/Yannic)
//

std::string base64_encode(std::string_view s, bool url) {
return encode(s, url);
}
std::string base64_encode(std::string_view s, bool url) { return encode(s, url); }

std::string base64_encode_pem(std::string_view s) {
return encode_pem(s);
}
std::string base64_encode_pem(std::string_view s) { return encode_pem(s); }

std::string base64_encode_mime(std::string_view s) {
return encode_mime(s);
}
std::string base64_encode_mime(std::string_view s) { return encode_mime(s); }

std::string base64_decode(std::string_view s, bool remove_linebreaks) {
return decode(s, remove_linebreaks);
std::string base64_decode(std::string_view s, bool remove_linebreaks)
{
return decode(s, remove_linebreaks);
}

#endif // __cplusplus >= 201703L
#endif // __cplusplus >= 201703L
Loading

0 comments on commit d3ee8a8

Please sign in to comment.