[burn-fusion] save all execution plans for any trigger #1143

Merged · 14 commits · Jan 16, 2024
4 changes: 4 additions & 0 deletions backend-comparison/src/persistence/base.rs
@@ -51,6 +51,10 @@ pub fn save<B: Backend>(
         .join("burn")
         .join("backend-comparison");

+    for bench in benches.iter() {
+        println!("{bench}");
+    }
+
     if !cache_dir.exists() {
         fs::create_dir_all(&cache_dir)?;
     }
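
The added loop prints every benchmark result before the cache directory is created, which only compiles because the result type implements Display. A minimal sketch of that pattern, with a hypothetical BenchmarkRecord standing in for the real benchmark result type:

use std::fmt;

// Hypothetical stand-in for a benchmark result; the real type in
// backend-comparison carries more fields.
struct BenchmarkRecord {
    name: String,
    median_us: f64,
}

impl fmt::Display for BenchmarkRecord {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {:.2} us", self.name, self.median_us)
    }
}

fn main() {
    let benches = vec![
        BenchmarkRecord { name: "matmul".into(), median_us: 412.5 },
        BenchmarkRecord { name: "conv2d".into(), median_us: 980.1 },
    ];

    // Same shape as the loop in the diff: `println!("{bench}")`
    // resolves through the Display impl.
    for bench in benches.iter() {
        println!("{bench}");
    }
}
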
23 changes: 15 additions & 8 deletions burn-fusion/src/backend.rs
@@ -1,6 +1,6 @@
 use crate::{
     client::FusionClient,
-    stream::{Context, TensorOpsDescription},
+    stream::{Context, OperationDescription},
     FusionClientLocator, FusionTensor,
 };
 use burn_tensor::{backend::Backend, Device, Shape};
@@ -70,7 +70,7 @@ pub struct OptimizationProperties {
 }

 /// The fusion operation abstraction allows implementations to fuse many
-/// [tensor operations](TensorOpsDescription) into one, improving the performance of the backend.
+/// [tensor operations](OperationDescription) into one, improving the performance of the backend.
 ///
 ///
 /// # Notes
@@ -79,19 +79,25 @@ pub struct OptimizationProperties {
 /// the speed and efficiency of the computational graph. It doesn't mean that all registered
 /// operations should be fused, but that another way of executing them is more efficient.
 ///
-/// Also, it is important to return (FusionStatus::Closed) when no more registered operations can
+/// Also, it is important to return (OptimizationStatus::Closed) when no more registered operations can
 /// improve the performance.
-pub trait OptimizationBuilder<B: FusionBackend>: Send {
-    /// Register a new [tensor operation](TensorOpsDescription).
-    fn register(&mut self, ops: &TensorOpsDescription);
+pub trait OptimizationBuilder<O>: Send {
+    /// Register a new [tensor operation](OperationDescription).
+    fn register(&mut self, operation: &OperationDescription);
     /// Finish the optimization and create a fusion operation.
-    fn build(&self) -> B::Optimization;
+    fn build(&self) -> O;
     /// Reset the state.
     fn reset(&mut self);
     /// Return the builder [status](OptimizationStatus).
     fn status(&self) -> OptimizationStatus;
     /// Return the builder [properties](OptimizationProperties).
     fn properties(&self) -> OptimizationProperties;
+    /// The number of operations fused.
+    fn len(&self) -> usize;
+    /// If no operations are fused.
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
 }

 /// The operation created from the [builder](OptimizationBuilder).
@@ -143,7 +149,8 @@ pub trait FusionBackend: Backend {
     type FusionClient: FusionClient<FusionBackend = Self>;

     /// The list of optimizations that will be used to optimize the computational graph.
-    fn optimizations(device: Device<Self>) -> Vec<Box<dyn OptimizationBuilder<Self>>>;
+    fn optimizations(device: Device<Self>)
+        -> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;

     /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::TensorPrimitive).
     fn float_tensor<const D: usize>(
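
The trait is now generic over the optimization type O instead of the whole backend, which is what lets FusionBackend::optimizations return Vec<Box<dyn OptimizationBuilder<Self::Optimization>>> and lets a builder be exercised without instantiating a backend. A minimal sketch of an implementation under the new shape; OperationDescription, OptimizationStatus, OptimizationProperties, and the element-wise builder below are simplified stand-ins, not burn's real definitions:

// Simplified stand-ins for the real burn-fusion types; only the
// `OptimizationBuilder` shape mirrors the diff.
#[derive(Clone, Debug)]
enum OperationDescription { Add, Relu, Other }

#[derive(Clone, Copy, PartialEq, Debug)]
enum OptimizationStatus { Open, Closed }

#[derive(Default)]
struct OptimizationProperties { score: u64, ready: bool }

trait OptimizationBuilder<O>: Send {
    fn register(&mut self, operation: &OperationDescription);
    fn build(&self) -> O;
    fn reset(&mut self);
    fn status(&self) -> OptimizationStatus;
    fn properties(&self) -> OptimizationProperties;
    fn len(&self) -> usize;
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

// Toy optimization: remembers how many element-wise operations it fuses.
struct ElemWise { num_fused: usize }

// Hypothetical builder that fuses consecutive Add/Relu operations and
// closes as soon as it sees anything else.
#[derive(Default)]
struct ElemWiseBuilder {
    operations: Vec<OperationDescription>,
    closed: bool,
}

impl OptimizationBuilder<ElemWise> for ElemWiseBuilder {
    fn register(&mut self, operation: &OperationDescription) {
        if self.closed {
            return;
        }
        match operation {
            OperationDescription::Add | OperationDescription::Relu => {
                self.operations.push(operation.clone());
            }
            // No further registered operation can improve this fusion.
            OperationDescription::Other => self.closed = true,
        }
    }

    fn build(&self) -> ElemWise {
        ElemWise { num_fused: self.operations.len() }
    }

    fn reset(&mut self) {
        self.operations.clear();
        self.closed = false;
    }

    fn status(&self) -> OptimizationStatus {
        if self.closed { OptimizationStatus::Closed } else { OptimizationStatus::Open }
    }

    fn properties(&self) -> OptimizationProperties {
        OptimizationProperties { score: self.len() as u64, ready: self.len() > 1 }
    }

    fn len(&self) -> usize {
        self.operations.len()
    }
}

// What a backend can now return under the new signature (the Device
// parameter is omitted in this sketch):
fn optimizations() -> Vec<Box<dyn OptimizationBuilder<ElemWise>>> {
    vec![Box::new(ElemWiseBuilder::default())]
}

fn main() {
    let mut builders = optimizations();
    builders[0].register(&OperationDescription::Add);
    builders[0].register(&OperationDescription::Relu);
    builders[0].register(&OperationDescription::Other);
    assert_eq!(builders[0].len(), 2);
    assert_eq!(builders[0].status(), OptimizationStatus::Closed);
}
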
10 changes: 5 additions & 5 deletions burn-fusion/src/client/base.rs
@@ -1,5 +1,5 @@
 use crate::{
-    stream::{Ops, TensorOpsDescription},
+    stream::{Operation, OperationDescription},
     FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
 };
 use burn_tensor::{
@@ -14,11 +14,11 @@ pub trait FusionClient: Send + Sync + Clone {

     /// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
     fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
-    /// Register a new [tensor operation description](TensorOpsDescription).
-    fn register<O: Ops<Self::FusionBackend> + 'static>(
+    /// Register a new [tensor operation description](OperationDescription).
+    fn register<O: Operation<Self::FusionBackend> + 'static>(
         &self,
-        description: TensorOpsDescription,
-        ops: O,
+        description: OperationDescription,
+        operation: O,
     );
     /// Register all lazy computation.
     fn drain(&self);
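
The renamed trait separates what a computation is (the OperationDescription, which the stream keeps so it can search for fusion opportunities) from how to run it eagerly (the Operation, consumed exactly once through self: Box<Self>). A minimal sketch of that split under simplified stand-in types; Client, HandleContainer, and AddOp are illustrations, not burn's real definitions:

// Simplified stand-ins for the real burn-fusion types.
#[derive(Clone, Debug)]
struct OperationDescription {
    name: &'static str,
}

struct HandleContainer; // placeholder for the real tensor-handle storage

// Mirrors the renamed trait: an executable fallback, consumed once.
trait Operation {
    fn execute(self: Box<Self>, handles: &mut HandleContainer);
}

// The client keeps the description for optimization analysis and the
// boxed operation as the eager fallback.
struct Client {
    stream: Vec<(OperationDescription, Box<dyn Operation>)>,
}

impl Client {
    fn register<O: Operation + 'static>(
        &mut self,
        description: OperationDescription,
        operation: O,
    ) {
        self.stream.push((description, Box::new(operation)));
    }

    // Drain: execute everything that was lazily registered.
    fn drain(&mut self, handles: &mut HandleContainer) {
        for (description, operation) in self.stream.drain(..) {
            println!("executing {}", description.name);
            operation.execute(handles);
        }
    }
}

struct AddOp;

impl Operation for AddOp {
    fn execute(self: Box<Self>, _handles: &mut HandleContainer) {
        // The real implementation would fetch inputs from the handle
        // container and run the backend kernel here.
    }
}

fn main() {
    let mut client = Client { stream: Vec::new() };
    client.register(OperationDescription { name: "add" }, AddOp);
    client.drain(&mut HandleContainer);
}
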
15 changes: 10 additions & 5 deletions burn-fusion/src/client/mutex.rs
@@ -1,5 +1,8 @@
 use super::FusionClient;
-use crate::{stream::TensorOpsDescription, FusionBackend, FusionServer, FusionTensor, Handle};
+use crate::{
+    stream::{Operation, OperationDescription},
+    FusionBackend, FusionServer, FusionTensor, Handle,
+};
 use burn_tensor::ops::FloatElem;
 use spin::Mutex;
 use std::sync::Arc;
@@ -38,12 +41,14 @@ where
         }
     }

-    fn register<O: crate::stream::Ops<Self::FusionBackend> + 'static>(
+    fn register<O: Operation<Self::FusionBackend> + 'static>(
         &self,
-        description: TensorOpsDescription,
-        ops: O,
+        description: OperationDescription,
+        operation: O,
     ) {
-        self.server.lock().register(description, Box::new(ops))
+        self.server
+            .lock()
+            .register(description, Box::new(operation))
     }

     fn drain(&self) {
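
Design-wise, the client stays cheap to clone because every clone shares one server behind an Arc, and register type-erases the concrete operation with Box::new before taking the lock. A sketch of the same pattern, assuming simplified Server and MutexClient types and using std::sync::Mutex rather than the crate's spin::Mutex so it runs without extra dependencies:

use std::sync::{Arc, Mutex};

trait Operation: Send {
    fn execute(self: Box<Self>);
}

#[derive(Default)]
struct Server {
    queue: Vec<Box<dyn Operation>>,
}

impl Server {
    fn register(&mut self, operation: Box<dyn Operation>) {
        self.queue.push(operation);
    }
}

// Cloning the client clones the Arc, so every handle shares one server.
#[derive(Clone, Default)]
struct MutexClient {
    server: Arc<Mutex<Server>>,
}

impl MutexClient {
    fn register<O: Operation + 'static>(&self, operation: O) {
        // Same shape as the diff: lock the shared server, then hand it
        // the type-erased operation.
        self.server.lock().unwrap().register(Box::new(operation));
    }
}

struct NoopOp;

impl Operation for NoopOp {
    fn execute(self: Box<Self>) {}
}

fn main() {
    let client = MutexClient::default();
    client.register(NoopOp);
    println!("queued: {}", client.server.lock().unwrap().queue.len());
}
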
16 changes: 8 additions & 8 deletions burn-fusion/src/ops/binary.rs
@@ -7,10 +7,10 @@ macro_rules! binary_float_ops {
     ) => {
         #[derive(new)]
         struct $name<const D: usize> {
-            desc: BinaryOpsDescription,
+            desc: BinaryOperationDescription,
         }

-        impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
+        impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
             fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
                 let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
                 let rhs = handles.get_float_tensor(&self.desc.rhs);
@@ -31,10 +31,10 @@ macro_rules! binary_float_cmp_ops {
     ) => {
         #[derive(new)]
         struct $name<const D: usize> {
-            desc: BinaryOpsDescription,
+            desc: BinaryOperationDescription,
         }

-        impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
+        impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
             fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
                 let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
                 let rhs = handles.get_float_tensor(&self.desc.rhs);
@@ -55,10 +55,10 @@ macro_rules! binary_int_cmp_ops {
     ) => {
         #[derive(new)]
         struct $name<const D: usize> {
-            desc: BinaryOpsDescription,
+            desc: BinaryOperationDescription,
         }

-        impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
+        impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
             fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
                 let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
                 let rhs = handles.get_int_tensor(&self.desc.rhs);
@@ -89,10 +89,10 @@ macro_rules! binary_int_ops {
     ) => {
         #[derive(new)]
         struct $name<const D: usize> {
-            desc: BinaryOpsDescription,
+            desc: BinaryOperationDescription,
         }

-        impl<const D: usize, B: FusionBackend> Ops<B> for $name<D> {
+        impl<const D: usize, B: FusionBackend> Operation<B> for $name<D> {
             fn execute(self: Box<Self>, handles: &mut $crate::HandleContainer<B>) {
                 let lhs = handles.get_int_tensor::<D>(&self.desc.lhs);
                 let rhs = handles.get_int_tensor(&self.desc.rhs);
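
Each of these macros stamps out a struct that owns the operation description plus an Operation impl whose execute fetches both inputs, applies the backend function, and registers the output. Roughly what an invocation naming a hypothetical AddOps could expand to; AddOps, float_add, and all the placeholder types below are illustrative assumptions, not the macro's literal expansion:

// Simplified placeholders for the real burn-fusion types.
struct TensorDescription;
struct FloatTensor<const D: usize>;

struct BinaryOperationDescription {
    lhs: TensorDescription,
    rhs: TensorDescription,
    out: TensorDescription,
}

struct HandleContainer;

impl HandleContainer {
    fn get_float_tensor<const D: usize>(&mut self, _desc: &TensorDescription) -> FloatTensor<D> {
        FloatTensor
    }
    fn register_float_tensor<const D: usize>(
        &mut self,
        _desc: &TensorDescription,
        _tensor: FloatTensor<D>,
    ) {
    }
}

trait Operation {
    fn execute(self: Box<Self>, handles: &mut HandleContainer);
}

// Roughly the shape generated for one binary float operation.
struct AddOps<const D: usize> {
    desc: BinaryOperationDescription,
}

impl<const D: usize> Operation for AddOps<D> {
    fn execute(self: Box<Self>, handles: &mut HandleContainer) {
        let lhs = handles.get_float_tensor::<D>(&self.desc.lhs);
        let rhs = handles.get_float_tensor::<D>(&self.desc.rhs);
        // Stands in for the backend call (e.g. B::float_add in the macro).
        let output = float_add::<D>(lhs, rhs);
        handles.register_float_tensor::<D>(&self.desc.out, output);
    }
}

fn float_add<const D: usize>(_lhs: FloatTensor<D>, _rhs: FloatTensor<D>) -> FloatTensor<D> {
    FloatTensor
}

fn main() {
    let op = Box::new(AddOps::<2> {
        desc: BinaryOperationDescription {
            lhs: TensorDescription,
            rhs: TensorDescription,
            out: TensorDescription,
        },
    });
    let mut handles = HandleContainer;
    op.execute(&mut handles);
}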