Add learner training report summary #1591

Merged: 7 commits, Apr 11, 2024
29 changes: 18 additions & 11 deletions burn-book/src/basic-workflow/training.md
@@ -1,11 +1,11 @@
# Training

We are now ready to write the necessary code to train our model on the MNIST dataset.
We shall define the code for this training section in the file: `src/training.rs`.
We are now ready to write the necessary code to train our model on the MNIST dataset. We shall
define the code for this training section in the file: `src/training.rs`.

Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose
responsibility is to apply an optimizer to the model. The output struct is used for all metrics
calculated during the training. Therefore it should include all the necessary information to
Instead of a simple tensor, the model should output an item that can be understood by the learner, a
struct whose responsibility is to apply an optimizer to the model. The output struct is used for all
metrics calculated during the training. Therefore it should include all the necessary information to
calculate any metric that you want for a task.
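
To make that concrete, here is a rough sketch of such an output item for the MNIST example this chapter builds on (the `Model` type and its `forward` method are assumed from earlier sections of the book): the forward pass bundles the loss, the logits, and the targets into a `ClassificationOutput`, which is all the learner needs to compute loss and accuracy metrics.

```rust
use burn::{
    nn::loss::CrossEntropyLossConfig,
    tensor::{backend::Backend, Int, Tensor},
    train::ClassificationOutput,
};

impl<B: Backend> Model<B> {
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        // Logits for each class, plus a cross-entropy loss against the targets.
        let output = self.forward(images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        // Loss, predictions and targets are enough to derive any classification metric.
        ClassificationOutput::new(loss, output, targets)
    }
}
```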

Burn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement
@@ -110,8 +110,14 @@ pub struct TrainingConfig {
pub learning_rate: f64,
}

pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
fn create_artifact_dir(artifact_dir: &str) {
// Remove existing artifacts beforehand to get an accurate learner summary
std::fs::remove_dir_all(artifact_dir).ok();
std::fs::create_dir_all(artifact_dir).ok();
}

pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
create_artifact_dir(artifact_dir);
config
.save(format!("{artifact_dir}/config.json"))
.expect("Config should be saved successfully");
@@ -141,6 +147,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(
config.model.init::<B>(&device),
config.optimizer.init(),
@@ -181,8 +188,8 @@ Once the learner is created, we can simply call `fit` and provide the training a
dataloaders. For the sake of simplicity in this example, we employ the test set as the validation
set; however, we do not recommend this practice for actual usage.

Finally, the trained model is returned by the `fit` method, and the only remaining task is saving
the trained weights using the `CompactRecorder`. This recorder employs the `MessagePack` format with
`gzip` compression, `f16` for floats and `i16` for integers. Other recorders are available, offering
support for various formats, such as `BinCode` and `JSON`, with or without compression. Any backend,
regardless of precision, can load recorded data of any kind.
Finally, the trained model is returned by the `fit` method. The trained weights are then saved using
the `CompactRecorder`. This recorder employs the `MessagePack` format with half precision, `f16` for
floats and `i16` for integers. Other recorders are available, offering support for various formats,
such as `BinCode` and `JSON`, with or without compression. Any backend, regardless of precision, can
load recorded data of any kind.
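
Putting the fragments above together, a condensed sketch of the updated `train` function might read as follows (the dataloader setup is elided, and the specific metrics shown, `AccuracyMetric` and `LossMetric`, are taken from the book's MNIST example rather than from this diff):

```rust
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
    // Start from a clean artifact directory so the learner summary reflects only this run.
    create_artifact_dir(artifact_dir);
    config
        .save(format!("{artifact_dir}/config.json"))
        .expect("Config should be saved successfully");

    // ... seeding, batchers and dataloaders as in the rest of the chapter ...

    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary() // new: print a training report once `.fit()` finishes
        .build(
            config.model.init::<B>(&device),
            config.optimizer.init(),
            config.learning_rate,
        );

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    model_trained
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}
```
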
2 changes: 2 additions & 0 deletions crates/burn-train/src/learner/base.rs
@@ -2,6 +2,7 @@ use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy
use crate::components::LearnerComponents;
use crate::learner::EarlyStoppingStrategy;
use crate::metric::store::EventStoreClient;
use crate::LearnerSummaryConfig;
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
@@ -26,6 +27,7 @@ pub struct Learner<LC: LearnerComponents> {
pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
pub(crate) event_processor: LC::EventProcessor,
pub(crate) event_store: Arc<EventStoreClient>,
pub(crate) summary: Option<LearnerSummaryConfig>,
}

#[derive(new)]
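
The `LearnerSummaryConfig` stored here lives in the new `summary` module, which is not shown in this excerpt. Based purely on how the builder populates it below, a minimal sketch of its shape might be (field types and visibility are assumptions):

```rust
/// Sketch only: where the training artifacts and metric event store live,
/// and which metric names to include in the end-of-training report.
pub struct LearnerSummaryConfig {
    pub(crate) directory: String,
    pub(crate) metrics: Vec<String>,
}
```
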
29 changes: 27 additions & 2 deletions crates/burn-train/src/learner/builder.rs
@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::sync::Arc;

use super::log::install_file_logger;
@@ -14,7 +15,7 @@ use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::LearnerCheckpointer;
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
@@ -53,6 +54,8 @@ where
num_loggers: usize,
checkpointer_strategy: Box<dyn CheckpointingStrategy>,
early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
summary_metrics: HashSet<String>,
summary: bool,
}

impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
@@ -94,6 +97,8 @@ where
.build(),
),
early_stopping: None,
summary_metrics: HashSet::new(),
summary: false,
}
}

@@ -140,7 +145,7 @@ where
where
T: Adaptor<Me::Input>,
{
self.metrics.register_metric_train(metric);
self.metrics.register_train_metric(metric);
self
}

@@ -174,6 +179,7 @@ where
Me: Metric + crate::metric::Numeric + 'static,
T: Adaptor<Me::Input>,
{
self.summary_metrics.insert(Me::NAME.to_string());
self.metrics.register_train_metric_numeric(metric);
self
}
@@ -186,6 +192,7 @@ where
where
V: Adaptor<Me::Input>,
{
self.summary_metrics.insert(Me::NAME.to_string());
self.metrics.register_valid_metric_numeric(metric);
self
}
@@ -266,6 +273,14 @@ where
self
}

/// Enable the training summary report.
///
/// The summary will be displayed at the end of `.fit()`.
pub fn summary(mut self) -> Self {
self.summary = true;
self
}

/// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
/// The [learning rate scheduler](LrScheduler) can also be a simple
/// [learning rate](burn_core::LearningRate).
@@ -320,6 +335,15 @@ where
LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
});

let summary = if self.summary {
Some(LearnerSummaryConfig {
directory: self.directory,
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
})
} else {
None
};

Learner {
model,
optim,
@@ -333,6 +357,7 @@ where
devices: self.devices,
interrupter: self.interrupter,
early_stopping: self.early_stopping,
summary,
}
}

2 changes: 2 additions & 0 deletions crates/burn-train/src/learner/mod.rs
@@ -5,6 +5,7 @@ mod early_stopping;
mod epoch;
mod regression;
mod step;
mod summary;
mod train_val;

pub(crate) mod log;
@@ -16,5 +17,6 @@ pub use early_stopping::*;
pub use epoch::*;
pub use regression::*;
pub use step::*;
pub use summary::*;
pub use train::*;
pub use train_val::*;