From 16d842e8898b357c7c62d045120a903d1d1e3518 Mon Sep 17 00:00:00 2001
From: F-G Fernandez <76527547+fg-mindee@users.noreply.github.com>
Date: Tue, 16 Nov 2021 17:52:53 +0100
Subject: [PATCH] feat: Added GPU support for classification and improved
 memory pinning (#629)

* refactor: Removed unnecessary metric reset
* feat: Added GPU support for character classification
* feat: Improved memory pinning mechanism for CPU
---
 references/classification/train_pytorch.py | 29 ++++++++++++++++++++--
 references/detection/train_pytorch.py      |  6 ++---
 references/detection/train_tensorflow.py   |  2 --
 references/recognition/train_pytorch.py    |  6 ++---
 references/recognition/train_tensorflow.py |  2 --
 5 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/references/classification/train_pytorch.py b/references/classification/train_pytorch.py
index c22ec1fdf4..789849814f 100644
--- a/references/classification/train_pytorch.py
+++ b/references/classification/train_pytorch.py
@@ -10,6 +10,7 @@
 import datetime
 import multiprocessing as mp
 import time
+import logging
 
 import numpy as np
 import torch
@@ -38,6 +39,10 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, m
     for _ in progress_bar(range(len(train_loader)), parent=mb):
         images, targets = next(train_iter)
 
+        if torch.cuda.is_available():
+            images = images.cuda()
+            targets = targets.cuda()
+
         images = batch_transforms(images)
 
         optimizer.zero_grad()
@@ -68,6 +73,11 @@ def evaluate(model, val_loader, batch_transforms, amp=False):
     val_iter = iter(val_loader)
     for images, targets in val_iter:
         images = batch_transforms(images)
+
+        if torch.cuda.is_available():
+            images = images.cuda()
+            targets = targets.cuda()
+
         if amp:
             with torch.cuda.amp.autocast():
                 out = model(images)
@@ -113,7 +123,7 @@ def main(args):
         drop_last=False,
         num_workers=args.workers,
         sampler=SequentialSampler(val_set),
-        pin_memory=True,
+        pin_memory=torch.cuda.is_available(),
     )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
           f"{len(val_loader)} batches)")
@@ -129,6 +139,21 @@ def main(args):
         checkpoint = torch.load(args.resume, map_location='cpu')
         model.load_state_dict(checkpoint)
 
+    # GPU
+    if isinstance(args.device, int):
+        if not torch.cuda.is_available():
+            raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
+        if args.device >= torch.cuda.device_count():
+            raise ValueError("Invalid device index")
+    # Silent default switch to GPU if available
+    elif torch.cuda.is_available():
+        args.device = 0
+    else:
+        logging.warning("No accessible GPU, target device set to CPU.")
+    if torch.cuda.is_available():
+        torch.cuda.set_device(args.device)
+        model = model.cuda()
+
     if args.test_only:
         print("Running evaluation")
         val_loss, acc = evaluate(model, val_loader, batch_transforms)
@@ -158,7 +183,7 @@ def main(args):
         drop_last=True,
         num_workers=args.workers,
         sampler=RandomSampler(train_set),
-        pin_memory=True,
+        pin_memory=torch.cuda.is_available(),
     )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
          f"{len(train_loader)} batches)")
diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index 3f60ff858c..c6550bf355 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -120,7 +120,7 @@ def main(args):
         drop_last=False,
         num_workers=args.workers,
         sampler=SequentialSampler(val_set),
-        pin_memory=True,
+        pin_memory=torch.cuda.is_available(),
         collate_fn=val_set.collate_fn,
     )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
@@ -184,7 +184,7 @@ def main(args):
         drop_last=True,
         num_workers=args.workers,
         sampler=RandomSampler(train_set),
-        pin_memory=True,
+        pin_memory=torch.cuda.is_available(),
         collate_fn=train_set.collate_fn,
     )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
@@ -267,8 +267,6 @@ def main(args):
                 'precision': precision,
                 'mean_iou': mean_iou,
             })
-        # Reset val metric
-        val_metric.reset()
 
     if args.wb:
         run.finish()
diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py
index acc004ee7a..d42d8b0847 100644
--- a/references/detection/train_tensorflow.py
+++ b/references/detection/train_tensorflow.py
@@ -209,8 +209,6 @@ def main(args):
                 'precision': precision,
                 'mean_iou': mean_iou,
             })
-        # Reset val metric
-        val_metric.reset()
 
     if args.wb:
         run.finish()
diff --git a/references/recognition/train_pytorch.py b/references/recognition/train_pytorch.py
index 5594ef02da..6f07bea704 100644
--- a/references/recognition/train_pytorch.py
+++ b/references/recognition/train_pytorch.py
@@ -124,7 +124,7 @@ def main(args):
         drop_last=False,
         num_workers=args.workers,
         sampler=SequentialSampler(val_set),
-        pin_memory=True,
+        pin_memory=torch.cuda.is_available(),
         collate_fn=val_set.collate_fn,
     )
     print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
@@ -194,7 +194,7 @@ def main(args):
         drop_last=True,
         num_workers=args.workers,
         sampler=RandomSampler(train_set),
-        pin_memory=True,
+        pin_memory=torch.cuda.is_available(),
         collate_fn=train_set.collate_fn,
     )
     print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
@@ -266,8 +266,6 @@ def main(args):
                 'exact_match': exact_match,
                 'partial_match': partial_match,
             })
-        #reset val metric
-        val_metric.reset()
 
     if args.wb:
         run.finish()
diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py
index fb877e54d4..23ba83b1e7 100644
--- a/references/recognition/train_tensorflow.py
+++ b/references/recognition/train_tensorflow.py
@@ -211,8 +211,6 @@ def main(args):
                 'exact_match': exact_match,
                 'partial_match': partial_match,
             })
-        #reset val metric
-        val_metric.reset()
 
     if args.wb:
         run.finish()