Skip to content

Commit

Permalink
Add learner summary expected results tests
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Apr 9, 2024
1 parent f411de7 commit 76422c3
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 18 deletions.
3 changes: 2 additions & 1 deletion burn-book/src/basic-workflow/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
let summary = LearnerSummary::new(
artifact_dir,
&[AccuracyMetric::<B>::NAME, LossMetric::<B>::NAME],
);
)
.expect("Summary artifacts should exist");
println!("{}", summary);
}
```
Expand Down
118 changes: 104 additions & 14 deletions crates/burn-train/src/learner/summary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,22 @@ impl MetricSummary {
metric: &str,
split: Split,
num_epochs: usize,
) -> Self {
let entries = (1..num_epochs)
) -> Option<Self> {
let entries = (1..=num_epochs)
.filter_map(|epoch| {
event_store
.find_metric(metric, epoch, Aggregate::Mean, split)
.map(|value| MetricEntry { step: epoch, value })
})
.collect::<Vec<_>>();

Self {
name: metric.to_string(),
entries,
if entries.is_empty() {
None
} else {
Some(Self {
name: metric.to_string(),
entries,
})
}
}
}
Expand Down Expand Up @@ -67,14 +71,23 @@ impl LearnerSummary {
///
/// * `directory` - The directory containing the training artifacts (checkpoints and logs).
/// * `metrics` - The list of metrics to collect for the summary.
pub fn new(directory: &str, metrics: &[&str]) -> Self {
if !Path::new(directory).exists() {
panic!("Artifact directory does not exist at: {}", directory);
pub fn new(directory: &str, metrics: &[&str]) -> Result<Self, String> {
let directory_path = Path::new(directory);
if !directory_path.exists() {
return Err(format!("Artifact directory does not exist at: {directory}"));
}
let train_dir = directory_path.join("train");
let valid_dir = directory_path.join("valid");
if !train_dir.exists() & !valid_dir.exists() {
return Err(format!(
"No training or validation artifacts found at: {directory}"
));
}

let mut event_store = LogEventStore::default();

let train_logger = FileMetricLogger::new(format!("{directory}/train").as_str());
let valid_logger = FileMetricLogger::new(format!("{directory}/valid").as_str());
let train_logger = FileMetricLogger::new(train_dir.to_str().unwrap());
let valid_logger = FileMetricLogger::new(valid_dir.to_str().unwrap());

// Number of recorded epochs
let epochs = train_logger.epochs();
Expand All @@ -84,21 +97,21 @@ impl LearnerSummary {

let train_summary = metrics
.iter()
.map(|metric| MetricSummary::new(&mut event_store, metric, Split::Train, epochs))
.filter_map(|metric| MetricSummary::new(&mut event_store, metric, Split::Train, epochs))
.collect::<Vec<_>>();

let valid_summary = metrics
.iter()
.map(|metric| MetricSummary::new(&mut event_store, metric, Split::Valid, epochs))
.filter_map(|metric| MetricSummary::new(&mut event_store, metric, Split::Valid, epochs))
.collect::<Vec<_>>();

Self {
Ok(Self {
epochs,
metrics: SummaryMetrics {
train: train_summary,
valid: valid_summary,
},
}
})
}
}

Expand Down Expand Up @@ -188,3 +201,80 @@ impl Display for LearnerSummary {
Ok(())
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a scratch artifact directory under the platform temporary directory.
    ///
    /// Any leftover directory with the same name (e.g. from a previously aborted run) is
    /// removed first so every test starts from a known state instead of depending on
    /// whatever a hard-coded `/tmp` path happens to contain.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(name);
        // Best effort: the directory usually does not exist, which is fine.
        std::fs::remove_dir_all(&dir).ok();
        dir
    }

    /// A missing artifact directory must be reported as an error.
    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_artifact_dir_should_exist() {
        // The directory is intentionally never created.
        let dir = scratch_dir("learner-summary-not-found");
        let _summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");
    }

    /// An artifact directory without `train` or `valid` sub-directories must be reported
    /// as an error.
    #[test]
    #[should_panic = "Summary artifacts should exist"]
    fn test_train_valid_artifacts_should_exist() {
        let dir = scratch_dir("test-learner-summary-empty");
        // A setup failure panics with a different message, so `should_panic` flags it
        // instead of silently testing the wrong condition.
        std::fs::create_dir_all(&dir).unwrap();

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"]);

        // Clean up before the expected panic so the directory is not left behind.
        std::fs::remove_dir_all(&dir).ok();
        let _summary = summary.expect("Summary artifacts should exist");
    }

    /// Epoch directories without any metric log yield a summary with no metric entries.
    #[test]
    fn test_summary_should_be_empty() {
        let dir = scratch_dir("test-learner-summary-empty-metrics");
        std::fs::create_dir_all(dir.join("train/epoch-1")).unwrap();
        std::fs::create_dir_all(dir.join("valid/epoch-1")).unwrap();

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        // No `Loss.log` was written, so nothing is collected for either split.
        assert_eq!(summary.metrics.train.len(), 0);
        assert_eq!(summary.metrics.valid.len(), 0);

        std::fs::remove_dir_all(&dir).unwrap();
    }

    /// Logged metric values are collected for both the train and valid splits.
    #[test]
    fn test_summary_should_be_collected() {
        let dir = scratch_dir("test-learner-summary");
        let train_dir = dir.join("train/epoch-1");
        let valid_dir = dir.join("valid/epoch-1");
        std::fs::create_dir_all(&train_dir).unwrap();
        std::fs::create_dir_all(&valid_dir).unwrap();

        std::fs::write(train_dir.join("Loss.log"), "1.0\n2.0").expect("Unable to write file");
        std::fs::write(valid_dir.join("Loss.log"), "1.0").expect("Unable to write file");

        let summary = LearnerSummary::new(dir.to_str().unwrap(), &["Loss"])
            .expect("Summary artifacts should exist");

        assert_eq!(summary.epochs, 1);

        // Only Loss metric
        assert_eq!(summary.metrics.train.len(), 1);
        assert_eq!(summary.metrics.valid.len(), 1);

        // Aggregated train metric entries for 1 epoch
        let train_metric = &summary.metrics.train[0];
        assert_eq!(train_metric.name, "Loss");
        assert_eq!(train_metric.entries.len(), 1);
        let entry = &train_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.5); // (1 + 2) / 2

        // Aggregated valid metric entries for 1 epoch
        let valid_metric = &summary.metrics.valid[0];
        assert_eq!(valid_metric.name, "Loss");
        assert_eq!(valid_metric.entries.len(), 1);
        let entry = &valid_metric.entries[0];
        assert_eq!(entry.step, 1); // epoch = 1
        assert_eq!(entry.value, 1.0);

        std::fs::remove_dir_all(&dir).unwrap();
    }
}
3 changes: 2 additions & 1 deletion examples/custom-image-dataset/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
let summary = LearnerSummary::new(
ARTIFACT_DIR,
&[AccuracyMetric::<B>::NAME, LossMetric::<B>::NAME],
);
)
.expect("Summary artifacts should exist");
println!("{}", summary);
}
3 changes: 2 additions & 1 deletion examples/guide/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, dev
let summary = LearnerSummary::new(
artifact_dir,
&[AccuracyMetric::<B>::NAME, LossMetric::<B>::NAME],
);
)
.expect("Summary artifacts should exist");
println!("{}", summary);
}
3 changes: 2 additions & 1 deletion examples/mnist/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
let summary = LearnerSummary::new(
ARTIFACT_DIR,
&[AccuracyMetric::<B>::NAME, LossMetric::<B>::NAME],
);
)
.expect("Summary artifacts should exist");
println!("{}", summary);
}

0 comments on commit 76422c3

Please sign in to comment.