Add learner training report summary #1591

Merged (7 commits) on Apr 11, 2024
Changes from 4 commits
35 changes: 24 additions & 11 deletions burn-book/src/basic-workflow/training.md
@@ -1,11 +1,11 @@
 # Training
 
-We are now ready to write the necessary code to train our model on the MNIST dataset.
-We shall define the code for this training section in the file: `src/training.rs`.
+We are now ready to write the necessary code to train our model on the MNIST dataset. We shall
+define the code for this training section in the file: `src/training.rs`.
 
-Instead of a simple tensor, the model should output an item that can be understood by the learner, a struct whose
-responsibility is to apply an optimizer to the model. The output struct is used for all metrics
-calculated during the training. Therefore it should include all the necessary information to
+Instead of a simple tensor, the model should output an item that can be understood by the learner, a
+struct whose responsibility is to apply an optimizer to the model. The output struct is used for all
+metrics calculated during the training. Therefore it should include all the necessary information to
 calculate any metric that you want for a task.
 
 Burn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement
@@ -110,8 +110,14 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts beforehand to get an accurate learner summary
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
     config
         .save(format!("{artifact_dir}/config.json"))
         .expect("Config should be saved successfully");
@@ -152,6 +158,12 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
     model_trained
         .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
         .expect("Trained model should be saved successfully");
+
+    let summary = LearnerSummary::new(
+        artifact_dir,
+        &[AccuracyMetric::<B>::NAME, LossMetric::<B>::NAME],
+    );
+    println!("{}", summary);
 }
 ```

@@ -181,8 +193,9 @@ Once the learner is created, we can simply call `fit` and provide the training a
 dataloaders. For the sake of simplicity in this example, we employ the test set as the validation
 set; however, we do not recommend this practice for actual usage.
 
-Finally, the trained model is returned by the `fit` method, and the only remaining task is saving
-the trained weights using the `CompactRecorder`. This recorder employs the `MessagePack` format with
-`gzip` compression, `f16` for floats and `i16` for integers. Other recorders are available, offering
-support for various formats, such as `BinCode` and `JSON`, with or without compression. Any backend,
-regardless of precision, can load recorded data of any kind.
+Finally, the trained model is returned by the `fit` method. The trained weights are then saved using
+the `CompactRecorder`. This recorder employs the `MessagePack` format with `gzip` compression, `f16`
+for floats and `i16` for integers. Other recorders are available, offering support for various
+formats, such as `BinCode` and `JSON`, with or without compression. Any backend, regardless of
+precision, can load recorded data of any kind. Once the weights have been saved, we use the
+`LearnerSummary` to display the training report summary.
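Worth noting from the book change above: `create_artifact_dir` wipes any previous artifacts so the learner summary reflects only the current run. A standalone sketch of that behavior (std-only, using a hypothetical temp-directory path for illustration, not tied to Burn):

```rust
use std::fs;

/// Remove any existing artifacts, then recreate an empty directory, so a fresh
/// run starts from a clean state. Errors are ignored on purpose: the directory
/// may simply not exist yet on the first run.
fn create_artifact_dir(artifact_dir: &str) {
    fs::remove_dir_all(artifact_dir).ok();
    fs::create_dir_all(artifact_dir).ok();
}

fn main() {
    // Hypothetical scratch location for the example.
    let dir = std::env::temp_dir().join("guide-artifacts-example");
    let dir = dir.to_str().unwrap().to_string();

    // Simulate a stale artifact left over from a previous run.
    fs::create_dir_all(&dir).unwrap();
    fs::write(format!("{dir}/model.mpk"), b"stale").unwrap();

    create_artifact_dir(&dir);

    // The directory exists again but is now empty.
    let entries = fs::read_dir(&dir).unwrap().count();
    println!("entries after reset: {entries}");
}
```

Without this reset, metric logs from earlier runs would be picked up by the summary and skew the reported min/max values.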
2 changes: 1 addition & 1 deletion crates/burn-train/src/learner/builder.rs
@@ -140,7 +140,7 @@ where
     where
         T: Adaptor<Me::Input>,
     {
-        self.metrics.register_metric_train(metric);
+        self.metrics.register_train_metric(metric);
         self
     }
 
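The rename above is purely for naming consistency (`register_train_metric` pairs with `register_valid_metric`). As a hedged illustration of the chaining style such builders use, here is a toy builder with hypothetical fields; it is not Burn's actual generic `LearnerBuilder`:

```rust
/// Toy stand-in for a learner builder that collects metric names.
/// (Illustrative only: Burn's real builder is generic over metric types.)
#[derive(Default)]
struct LearnerBuilder {
    train_metrics: Vec<String>,
    valid_metrics: Vec<String>,
}

impl LearnerBuilder {
    /// `register_train_metric` reads more consistently alongside
    /// `register_valid_metric` than the old `register_metric_train`.
    fn register_train_metric(mut self, name: &str) -> Self {
        self.train_metrics.push(name.to_string());
        self
    }

    fn register_valid_metric(mut self, name: &str) -> Self {
        self.valid_metrics.push(name.to_string());
        self
    }
}

fn main() {
    let builder = LearnerBuilder::default()
        .register_train_metric("Accuracy")
        .register_train_metric("Loss")
        .register_valid_metric("Loss");
    println!("train metrics: {:?}", builder.train_metrics);
}
```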
2 changes: 2 additions & 0 deletions crates/burn-train/src/learner/mod.rs
@@ -5,6 +5,7 @@ mod early_stopping;
 mod epoch;
 mod regression;
 mod step;
+mod summary;
 mod train_val;
 
 pub(crate) mod log;
@@ -16,5 +17,6 @@ pub use early_stopping::*;
 pub use epoch::*;
 pub use regression::*;
 pub use step::*;
+pub use summary::*;
 pub use train::*;
 pub use train_val::*;
190 changes: 190 additions & 0 deletions crates/burn-train/src/learner/summary.rs
@@ -0,0 +1,190 @@
use core::cmp::Ordering;
use std::{fmt::Display, path::Path};

use crate::{
    logger::FileMetricLogger,
    metric::store::{Aggregate, EventStore, LogEventStore, Split},
};

/// Contains the metric value at a given time.
pub struct MetricEntry {
    /// The step at which the metric was recorded (i.e., epoch).
    pub step: usize,
    /// The metric value.
    pub value: f64,
}

/// Contains the summary of recorded values for a given metric.
pub struct MetricSummary {
    /// The metric name.
    pub name: String,
    /// The metric entries.
    pub entries: Vec<MetricEntry>,
}

impl MetricSummary {
    fn new<E: EventStore>(
        event_store: &mut E,
        metric: &str,
        split: Split,
        num_epochs: usize,
    ) -> Self {
        // Inclusive range: epochs are 1-indexed, so the last epoch must be included.
        let entries = (1..=num_epochs)
            .filter_map(|epoch| {
                event_store
                    .find_metric(metric, epoch, Aggregate::Mean, split)
                    .map(|value| MetricEntry { step: epoch, value })
            })
            .collect::<Vec<_>>();

        Self {
            name: metric.to_string(),
            entries,
        }
    }
}

/// Contains the summary of recorded metrics for the training and validation steps.
pub struct SummaryMetrics {
    /// Training metrics summary.
    pub train: Vec<MetricSummary>,
    /// Validation metrics summary.
    pub valid: Vec<MetricSummary>,
}

/// Detailed training summary.
pub struct LearnerSummary {
    /// The number of epochs completed.
    pub epochs: usize,
    /// The summary of recorded metrics during training.
    pub metrics: SummaryMetrics,
}

impl LearnerSummary {
    /// Creates a new learner summary for the specified metrics.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory containing the training artifacts (checkpoints and logs).
    /// * `metrics` - The list of metrics to collect for the summary.
    pub fn new(directory: &str, metrics: &[&str]) -> Self {
        if !Path::new(directory).exists() {
            panic!("Artifact directory does not exist at: {}", directory);
        }
        let mut event_store = LogEventStore::default();

        let train_logger = FileMetricLogger::new(format!("{directory}/train").as_str());
        let valid_logger = FileMetricLogger::new(format!("{directory}/valid").as_str());

        // Number of recorded epochs
        let epochs = train_logger.epochs();

        event_store.register_logger_train(train_logger);
        event_store.register_logger_valid(valid_logger);

        let train_summary = metrics
            .iter()
            .map(|metric| MetricSummary::new(&mut event_store, metric, Split::Train, epochs))
            .collect::<Vec<_>>();

        let valid_summary = metrics
            .iter()
            .map(|metric| MetricSummary::new(&mut event_store, metric, Split::Valid, epochs))
            .collect::<Vec<_>>();

        Self {
            epochs,
            metrics: SummaryMetrics {
                train: train_summary,
                valid: valid_summary,
            },
        }
    }
}

impl Display for LearnerSummary {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Compute the max length for each column
        let split_train = "Train";
        let split_valid = "Valid";
        let max_split_len = "Split".len().max(split_train.len()).max(split_valid.len());
        let mut max_metric_len = "Metric".len();
        for metric in self.metrics.train.iter() {
            max_metric_len = max_metric_len.max(metric.name.len());
        }
        for metric in self.metrics.valid.iter() {
            max_metric_len = max_metric_len.max(metric.name.len());
        }

        // Summary header
        writeln!(
            f,
            "{:=>width_symbol$} Learner Summary {:=>width_symbol$}\nTotal Epochs: {epochs}\n\n",
            "",
            "",
            width_symbol = 24,
            epochs = self.epochs,
        )?;

        // Metrics table header
        writeln!(
            f,
            "| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|",
            "Split",
            "Metric",
            "",
            "",
            width_split = max_split_len,
            width_metric = max_metric_len,
        )?;

        // Table entries
        fn cmp_f64(a: &f64, b: &f64) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                _ => a.partial_cmp(b).unwrap(),
            }
        }

        let mut write_metrics_summary =
            |metrics: &[MetricSummary], split: &str| -> std::fmt::Result {
                for metric in metrics.iter() {
                    if metric.entries.is_empty() {
                        continue; // skip metrics with no recorded values
                    }

                    // Compute the min & max for each metric
                    let metric_min = metric
                        .entries
                        .iter()
                        .min_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();
                    let metric_max = metric
                        .entries
                        .iter()
                        .max_by(|a, b| cmp_f64(&a.value, &b.value))
                        .unwrap();

                    writeln!(
                        f,
                        "| {:<width_split$} | {:<width_metric$} | {:<9.3?}| {:<9?}| {:<9.3?}| {:<9?}|",
                        split,
                        metric.name,
                        metric_min.value,
                        metric_min.step,
                        metric_max.value,
                        metric_max.step,
                        width_split = max_split_len,
                        width_metric = max_metric_len,
                    )?;
                }

                Ok(())
            };

        write_metrics_summary(&self.metrics.train, split_train)?;
        write_metrics_summary(&self.metrics.valid, split_valid)?;

        Ok(())
    }
}
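The `cmp_f64` helper above exists because `f64` is only `PartialOrd`, so `min_by`/`max_by` need an explicit total order. The sketch below isolates that comparator; note that treating NaN as the greatest value keeps NaN out of the minimum, but it also means `max_by` will surface a NaN if one was ever recorded:

```rust
use std::cmp::Ordering;

/// Total ordering over f64 that treats NaN as the largest value,
/// so `min_by` never selects a NaN entry.
fn cmp_f64(a: &f64, b: &f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so partial_cmp cannot fail here.
        _ => a.partial_cmp(b).unwrap(),
    }
}

fn main() {
    let values = [0.42, f64::NAN, 0.17, 0.98];
    // min_by skips the NaN (it compares as greatest)...
    let min = *values.iter().min_by(|a, b| cmp_f64(a, b)).unwrap();
    // ...but max_by will pick the NaN precisely because it compares as greatest.
    let max = *values.iter().max_by(|a, b| cmp_f64(a, b)).unwrap();
    println!("min = {min}, max = {max}");
}
```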
52 changes: 42 additions & 10 deletions crates/burn-train/src/logger/metric.rs
@@ -1,6 +1,8 @@
 use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
-use crate::metric::MetricEntry;
-use std::collections::HashMap;
+use crate::metric::{MetricEntry, NumericEntry};
+use std::{collections::HashMap, fs};
+
+const EPOCH_PREFIX: &str = "epoch-";
 
 /// Metric logger.
 pub trait MetricLogger: Send {
@@ -19,7 +21,7 @@ pub trait MetricLogger: Send {
     fn end_epoch(&mut self, epoch: usize);
 
     /// Read the logs for an epoch.
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String>;
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
 }
 
 /// The file metric logger.
@@ -47,14 +49,44 @@ impl FileMetricLogger {
         }
     }
 
+    /// Number of epochs recorded.
+    pub(crate) fn epochs(&self) -> usize {
+        let mut max_epoch = 0;
+
+        for path in fs::read_dir(&self.directory).unwrap() {
+            let path = path.unwrap();
+
+            if fs::metadata(path.path()).unwrap().is_dir() {
+                let dir_name = path.file_name().into_string().unwrap();
+
+                if !dir_name.starts_with(EPOCH_PREFIX) {
+                    continue;
+                }
+
+                let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
+
+                if let Some(epoch) = epoch {
+                    if epoch > max_epoch {
+                        max_epoch = epoch;
+                    }
+                }
+            }
+        }
+
+        max_epoch
+    }
+
+    fn epoch_directory(&self, epoch: usize) -> String {
+        format!("{}/{}{}", self.directory, EPOCH_PREFIX, epoch)
+    }
     fn file_path(&self, name: &str, epoch: usize) -> String {
-        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        let directory = self.epoch_directory(epoch);
         let name = name.replace(' ', "_");
 
         format!("{directory}/{name}.log")
     }
     fn create_directory(&self, epoch: usize) {
-        let directory = format!("{}/epoch-{}", self.directory, epoch);
+        let directory = self.epoch_directory(epoch);
         std::fs::create_dir_all(directory).ok();
     }
 }
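`epochs()` above infers the epoch count from the `epoch-<n>` directory names rather than from any stored counter. A self-contained sketch of that scan (std-only; it uses `strip_prefix` instead of `replace`, and a hypothetical temp directory for the demo):

```rust
use std::fs;
use std::path::Path;

const EPOCH_PREFIX: &str = "epoch-";

/// Scan `directory` for `epoch-<n>` subdirectories and return the largest `n`
/// (0 when nothing has been recorded yet).
fn epochs(directory: &Path) -> usize {
    let mut max_epoch = 0;
    for entry in fs::read_dir(directory).unwrap() {
        let entry = entry.unwrap();
        if !entry.path().is_dir() {
            continue;
        }
        let dir_name = entry.file_name().into_string().unwrap();
        // Ignore unrelated directories and non-numeric suffixes.
        if let Some(n) = dir_name
            .strip_prefix(EPOCH_PREFIX)
            .and_then(|s| s.parse::<usize>().ok())
        {
            max_epoch = max_epoch.max(n);
        }
    }
    max_epoch
}

fn main() {
    let root = std::env::temp_dir().join("epoch-scan-example");
    for n in [1, 2, 10] {
        fs::create_dir_all(root.join(format!("{EPOCH_PREFIX}{n}"))).unwrap();
    }
    // Numeric comparison means `epoch-10` wins, even though it sorts
    // before `epoch-2` lexically.
    println!("recorded epochs: {}", epochs(&root));
}
```

Parsing the suffix as `usize` (rather than comparing names as strings) is what makes the scan robust past epoch 9.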
@@ -88,7 +120,7 @@ impl MetricLogger for FileMetricLogger {
         self.epoch = epoch + 1;
     }
 
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
         if let Some(value) = self.loggers.get(name) {
             value.sync()
         }
@@ -104,7 +136,7 @@ impl MetricLogger for FileMetricLogger {
                 if value.is_empty() {
                     None
                 } else {
-                    match value.parse::<f64>() {
+                    match NumericEntry::deserialize(value) {
                         Ok(value) => Some(value),
                         Err(err) => {
                             log::error!("{err}");
@@ -117,7 +149,7 @@ impl MetricLogger for FileMetricLogger {
             .collect();
 
         if errors {
-            Err("Parsing float errors".to_string())
+            Err("Parsing numeric entry errors".to_string())
         } else {
            Ok(data)
        }
@@ -154,7 +186,7 @@ impl MetricLogger for InMemoryMetricLogger {
         }
     }
 
-    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<f64>, String> {
+    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
         let values = match self.values.get(name) {
             Some(values) => values,
             None => return Ok(Vec::new()),
@@ -164,7 +196,7 @@ impl MetricLogger for InMemoryMetricLogger {
         match self.loggers.get(epoch - 1) {
             Some(logger) => Ok(logger
                 .values
                 .iter()
-                .filter_map(|value| value.parse::<f64>().ok())
+                .filter_map(|value| NumericEntry::deserialize(value).ok())
                 .collect()),
             None => Ok(Vec::new()),
         }
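`NumericEntry`'s serialization format is not visible in this diff. Purely to illustrate why a structured entry beats a bare `f64` (it can carry an aggregation count alongside the mean, which the summary needs for correct averaging), here is a hypothetical version assuming a simple `value,count` encoding; the real burn-train type may differ:

```rust
/// Hypothetical numeric log entry: either a single value or a mean aggregated
/// over `count` samples. (The real `NumericEntry` in burn-train may differ;
/// this is only an illustration of the idea.)
#[derive(Debug, PartialEq)]
enum NumericEntry {
    Value(f64),
    Aggregated { mean: f64, count: usize },
}

impl NumericEntry {
    /// Parse "3.5" as a plain value and "3.5,10" as an aggregated mean.
    fn deserialize(s: &str) -> Result<Self, String> {
        match s.split_once(',') {
            // No comma: a plain value, exactly what `parse::<f64>` handled before.
            None => s
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|e| e.to_string()),
            // "mean,count": an aggregated entry.
            Some((mean, count)) => {
                let mean = mean.parse::<f64>().map_err(|e| e.to_string())?;
                let count = count.parse::<usize>().map_err(|e| e.to_string())?;
                Ok(NumericEntry::Aggregated { mean, count })
            }
        }
    }
}

fn main() {
    println!("{:?}", NumericEntry::deserialize("0.95"));
    println!("{:?}", NumericEntry::deserialize("0.87,128"));
}
```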