Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][Fusion] Restrict types of fusable command groups #12556

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
for (auto &RawCmd : InputKernels) {
auto *KernelCmd = static_cast<ExecCGCommand *>(RawCmd);
auto &CG = KernelCmd->getCG();
assert(CG.getType() == CG::Kernel);
assert(KernelCmd->isFusable());
auto *KernelCG = static_cast<CGExecKernel *>(&CG);

auto KernelName = KernelCG->MKernelName;
Expand Down
32 changes: 29 additions & 3 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,12 @@ bool Command::isHostTask() const {
CG::CGTYPE::CodeplayHostTask);
}

bool Command::isFusable() const {
return (MType == CommandType::RUN_CG) &&
((static_cast<const ExecCGCommand *>(this))->getCG().getType() ==
CG::CGTYPE::Kernel);
}

static void flushCrossQueueDeps(const std::vector<EventImplPtr> &EventImpls,
const QueueImplPtr &Queue) {
for (auto &EventImpl : EventImpls) {
Expand Down Expand Up @@ -1825,7 +1831,7 @@ void UpdateHostRequirementCommand::emitInstrumentationData() {
#endif
}

static std::string cgTypeToString(detail::CG::CGTYPE Type) {
static std::string_view cgTypeToString(detail::CG::CGTYPE Type) {
switch (Type) {
case detail::CG::Kernel:
return "Kernel";
Expand All @@ -1845,6 +1851,10 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
case detail::CG::CopyPtrToAcc:
return "copy ptr to acc";
break;
case detail::CG::Barrier:
return "barrier";
case detail::CG::BarrierWaitlist:
return "barrier waitlist";
case detail::CG::CopyUSM:
return "copy usm";
break;
Expand All @@ -1863,6 +1873,8 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
case detail::CG::Fill2DUSM:
return "fill 2d usm";
break;
case detail::CG::AdviseUSM:
return "advise usm";
case detail::CG::Memset2DUSM:
return "memset 2d usm";
break;
Expand All @@ -1872,6 +1884,16 @@ static std::string cgTypeToString(detail::CG::CGTYPE Type) {
case detail::CG::CopyFromDeviceGlobal:
return "copy from device_global";
break;
case detail::CG::ReadWriteHostPipe:
return "read_write host pipe";
case detail::CG::ExecCommandBuffer:
return "exec command buffer";
case detail::CG::CopyImage:
return "copy image";
case detail::CG::SemaphoreWait:
return "semaphore wait";
case detail::CG::SemaphoreSignal:
return "semaphore signal";
default:
return "unknown";
break;
Expand Down Expand Up @@ -2102,7 +2124,7 @@ void ExecCGCommand::emitInstrumentationData() {
KernelCG->getKernelName(), MAddress, FromSource);
} break;
default:
KernelName = cgTypeToString(MCommandGroup->getType());
KernelName = getTypeString();
break;
}

Expand Down Expand Up @@ -2150,7 +2172,7 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
break;
}
default:
Stream << "CG type: " << cgTypeToString(MCommandGroup->getType()) << "\\n";
Stream << "CG type: " << getTypeString() << "\\n";
break;
}

Expand All @@ -2165,6 +2187,10 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
}
}

std::string_view ExecCGCommand::getTypeString() const {
return cgTypeToString(MCommandGroup->getType());
}

// SYCL has a parallel_for_work_group variant where the only NDRange
// characteristics set by a user is the number of work groups. This does not
// map to the OpenCL clEnqueueNDRangeAPI, which requires global work size to
Expand Down
3 changes: 3 additions & 0 deletions sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ class Command {

bool isHostTask() const;

bool isFusable() const;

protected:
QueueImplPtr MQueue;
EventImplPtr MEvent;
Expand Down Expand Up @@ -648,6 +650,7 @@ class ExecCGCommand : public Command {

void printDot(std::ostream &Stream) const final;
void emitInstrumentationData() final;
std::string_view getTypeString() const;

detail::CG &getCG() const { return *MCommandGroup; }

Expand Down
118 changes: 64 additions & 54 deletions sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "detail/config.hpp"
#include <detail/context_impl.hpp>
#include <detail/event_impl.hpp>
#include <sstream>
#include <sycl/feature_test.hpp>
#if SYCL_EXT_CODEPLAY_KERNEL_FUSION
#include <detail/jit_compiler.hpp>
Expand Down Expand Up @@ -949,66 +950,75 @@ Scheduler::GraphBuildResult Scheduler::GraphBuilder::addCG(
if (!NewCmd)
throw runtime_error("Out of host memory", PI_ERROR_OUT_OF_HOST_MEMORY);

// Host tasks cannot participate in fusion. They take the regular route. If
// they create any requirement or event dependency on any of the kernels in
// the fusion list, this will lead to cancellation of the fusion in the
// GraphProcessor.
// Only device kernel command groups can participate in fusion. Otherwise,
// command groups take the regular route. If they create any requirement or
// event dependency on any of the kernels in the fusion list, this will lead
// to cancellation of the fusion in the GraphProcessor.
auto QUniqueID = std::hash<sycl::detail::queue_impl *>()(Queue.get());
if (isInFusionMode(QUniqueID) && !NewCmd->isHostTask()) {
auto *FusionCmd = findFusionList(QUniqueID)->second.get();

bool dependsOnFusion = false;
for (auto Ev = Events.begin(); Ev != Events.end();) {
auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand());
if (!EvDepCmd) {
continue;
}
// Handle event dependencies on any commands part of another active
// fusion.
if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
printFusionWarning("Aborting fusion because of event dependency from a "
"different fusion");
cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
}
// Check if this command depends on the placeholder command for the fusion
// itself participates in.
if (EvDepCmd == FusionCmd) {
Ev = Events.erase(Ev);
dependsOnFusion = true;
} else {
++Ev;
if (isInFusionMode(QUniqueID)) {
victor-eds marked this conversation as resolved.
Show resolved Hide resolved
if (NewCmd->isFusable()) {
auto *FusionCmd = findFusionList(QUniqueID)->second.get();

bool dependsOnFusion = false;
for (auto Ev = Events.begin(); Ev != Events.end();) {
auto *EvDepCmd = static_cast<Command *>((*Ev)->getCommand());
if (!EvDepCmd) {
continue;
}
// Handle event dependencies on any commands part of another active
// fusion.
if (EvDepCmd->getQueue() != Queue && isPartOfActiveFusion(EvDepCmd)) {
printFusionWarning(
"Aborting fusion because of event dependency from a "
"different fusion");
cancelFusion(EvDepCmd->getQueue(), ToEnqueue);
}
// Check if this command depends on the placeholder command for the
// fusion itself participates in.
if (EvDepCmd == FusionCmd) {
Ev = Events.erase(Ev);
dependsOnFusion = true;
} else {
++Ev;
}
}
}

// If this command has an explicit event dependency on the placeholder
// command for this fusion (because it used depends_on on the event returned
// by submitting another kernel to this fusion earlier), add a dependency on
// all the commands in the fusion list so far.
if (dependsOnFusion) {
for (auto *Cmd : FusionCmd->getFusionList()) {
Events.push_back(Cmd->getEvent());
// If this command has an explicit event dependency on the placeholder
// command for this fusion (because it used depends_on on the event
// returned by submitting another kernel to this fusion earlier), add a
// dependency on all the commands in the fusion list so far.
if (dependsOnFusion) {
for (auto *Cmd : FusionCmd->getFusionList()) {
Events.push_back(Cmd->getEvent());
}
}
}

// Add the kernel to the graph, but delay the enqueue of any auxiliary
// commands (e.g., allocations) resulting from that process by adding them
// to the list of auxiliary commands of the fusion command.
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
FusionCmd->auxiliaryCommands());

// Set the fusion command, so we recognize when another command depends on a
// kernel in the fusion list.
FusionCmd->addToFusionList(NewCmd.get());
NewCmd->MFusionCmd = FusionCmd;
std::vector<Command *> ToCleanUp;
// Add an event dependency from the fusion placeholder command to the new
// kernel.
auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
if (ConnectionCmd) {
FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
// Add the kernel to the graph, but delay the enqueue of any auxiliary
// commands (e.g., allocations) resulting from that process by adding them
// to the list of auxiliary commands of the fusion command.
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
isInteropHostTask(NewCmd.get()), Reqs, Events,
Queue, FusionCmd->auxiliaryCommands());

// Set the fusion command, so we recognize when another command depends on
// a kernel in the fusion list.
FusionCmd->addToFusionList(NewCmd.get());
NewCmd->MFusionCmd = FusionCmd;
std::vector<Command *> ToCleanUp;
// Add an event dependency from the fusion placeholder command to the new
// kernel.
auto ConnectionCmd = FusionCmd->addDep(NewCmd->getEvent(), ToCleanUp);
if (ConnectionCmd) {
FusionCmd->auxiliaryCommands().push_back(ConnectionCmd);
}
return {NewCmd.release(), FusionCmd->getEvent(), false};
} else {
std::string s;
std::stringstream ss(s);
ss << "Not fusing '" << NewCmd->getTypeString()
<< "' command group. Can only fuse device kernel command groups.";
printFusionWarning(ss.str());
}
return {NewCmd.release(), FusionCmd->getEvent(), false};
}
createGraphForCommand(NewCmd.get(), NewCmd->getCG(),
isInteropHostTask(NewCmd.get()), Reqs, Events, Queue,
Expand Down
120 changes: 120 additions & 0 deletions sycl/test-e2e/KernelFusion/non-kernel-cg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// RUN: %{build} -fsycl-embed-ir -o %t.out
// RUN: env SYCL_RT_WARNING_LEVEL=2 %{run} %t.out
// XFAIL: hip

// COM: Test fails on hip due to unsupported CG kinds being tested. This test
// only checks fusion does not crash on non-kernel CG (target independent test),
// so having multiple CG kinds has higher priority than running the test on all
// backends.

// Test non-kernel device command groups are not fused

#include <sycl/sycl.hpp>

using namespace sycl;

int main() {
constexpr size_t dataSize = 512;
constexpr float Pattern{10};

queue q{ext::codeplay::experimental::property::queue::enable_fusion{}};
ext::codeplay::experimental::fusion_wrapper fw(q);

constexpr size_t count = 64;
auto *dst = malloc_device<float>(count, q);
auto *src = malloc_device<float>(count, q);

{
// CHECK: Not fusing 'copy acc to ptr' command group. Can only fuse device kernel command groups.
buffer<float> src(dataSize);
std::shared_ptr<float> dst(new float[dataSize]);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc(src, cgh, read_only);
cgh.copy(acc, dst);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'copy ptr to acc' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
std::shared_ptr<float> src(new float[dataSize]);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc(dst, cgh, write_only);
cgh.copy(src, acc);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'copy acc to acc' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
buffer<float> src(dataSize);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc0(src, cgh, read_only);
accessor acc1(dst, cgh, write_only);
cgh.copy(acc0, acc1);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'barrier' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'barrier waitlist' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
buffer<float> src(dataSize);
std::vector<event> event_list;
event_list.push_back(q.submit([&](handler &cgh) {
accessor acc0(src, cgh, read_only);
accessor acc1(dst, cgh, write_only);
cgh.copy(acc0, acc1);
}));
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.ext_oneapi_barrier(event_list); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'fill' command group. Can only fuse device kernel command groups.
buffer<float> dst(dataSize);
fw.start_fusion();
q.submit([&](handler &cgh) {
accessor acc(dst, cgh, write_only);
cgh.fill(acc, Pattern);
});
fw.complete_fusion();
}

{
// CHECK: Not fusing 'copy usm' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.memcpy(dst, src, count); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'fill usm' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.fill(dst, Pattern, count); });
fw.complete_fusion();
}

{
// CHECK: Not fusing 'prefetch usm' command group. Can only fuse device kernel command groups.
fw.start_fusion();
q.submit([&](handler &cgh) { cgh.prefetch(dst, count); });
fw.complete_fusion();
}

free(src, q);
free(dst, q);
}
Loading