Skip to content

Commit

Permalink
feat: Added GPU support for classification and improve memory pinning (#629)
Browse files Browse the repository at this point in the history

* refactor: Removed unnecessary metric reset

* feat: Added support of GPU for character classification

* feat: Improved memory pinning mechanism for CPU
  • Loading branch information
fg-mindee committed Nov 16, 2021
1 parent 62603cc commit 16d842e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 14 deletions.
29 changes: 27 additions & 2 deletions references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import datetime
import multiprocessing as mp
import time
import logging

import numpy as np
import torch
Expand Down Expand Up @@ -38,6 +39,10 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, m
for _ in progress_bar(range(len(train_loader)), parent=mb):
images, targets = next(train_iter)

if torch.cuda.is_available():
images = images.cuda()
targets = targets.cuda()

images = batch_transforms(images)

optimizer.zero_grad()
Expand Down Expand Up @@ -68,6 +73,11 @@ def evaluate(model, val_loader, batch_transforms, amp=False):
val_iter = iter(val_loader)
for images, targets in val_iter:
images = batch_transforms(images)

if torch.cuda.is_available():
images = images.cuda()
targets = targets.cuda()

if amp:
with torch.cuda.amp.autocast():
out = model(images)
Expand Down Expand Up @@ -113,7 +123,7 @@ def main(args):
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(val_set),
pin_memory=True,
pin_memory=torch.cuda.is_available(),
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")
Expand All @@ -129,6 +139,21 @@ def main(args):
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint)

# GPU
if isinstance(args.device, int):
if not torch.cuda.is_available():
raise AssertionError("PyTorch cannot access your GPU. Please investigate!")
if args.device >= torch.cuda.device_count():
raise ValueError("Invalid device index")
# Silent default switch to GPU if available
elif torch.cuda.is_available():
args.device = 0
else:
logging.warning("No accessible GPU, target device set to CPU.")
if torch.cuda.is_available():
torch.cuda.set_device(args.device)
model = model.cuda()

if args.test_only:
print("Running evaluation")
val_loss, acc = evaluate(model, val_loader, batch_transforms)
Expand Down Expand Up @@ -158,7 +183,7 @@ def main(args):
drop_last=True,
num_workers=args.workers,
sampler=RandomSampler(train_set),
pin_memory=True,
pin_memory=torch.cuda.is_available(),
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")
Expand Down
6 changes: 2 additions & 4 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def main(args):
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(val_set),
pin_memory=True,
pin_memory=torch.cuda.is_available(),
collate_fn=val_set.collate_fn,
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
Expand Down Expand Up @@ -184,7 +184,7 @@ def main(args):
drop_last=True,
num_workers=args.workers,
sampler=RandomSampler(train_set),
pin_memory=True,
pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
Expand Down Expand Up @@ -267,8 +267,6 @@ def main(args):
'precision': precision,
'mean_iou': mean_iou,
})
# Reset val metric
val_metric.reset()

if args.wb:
run.finish()
Expand Down
2 changes: 0 additions & 2 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ def main(args):
'precision': precision,
'mean_iou': mean_iou,
})
# Reset val metric
val_metric.reset()

if args.wb:
run.finish()
Expand Down
6 changes: 2 additions & 4 deletions references/recognition/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main(args):
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(val_set),
pin_memory=True,
pin_memory=torch.cuda.is_available(),
collate_fn=val_set.collate_fn,
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
Expand Down Expand Up @@ -194,7 +194,7 @@ def main(args):
drop_last=True,
num_workers=args.workers,
sampler=RandomSampler(train_set),
pin_memory=True,
pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
Expand Down Expand Up @@ -266,8 +266,6 @@ def main(args):
'exact_match': exact_match,
'partial_match': partial_match,
})
#reset val metric
val_metric.reset()

if args.wb:
run.finish()
Expand Down
2 changes: 0 additions & 2 deletions references/recognition/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ def main(args):
'exact_match': exact_match,
'partial_match': partial_match,
})
#reset val metric
val_metric.reset()

if args.wb:
run.finish()
Expand Down

0 comments on commit 16d842e

Please sign in to comment.