Skip to content

Commit

Permalink
feat: Added support of FasterRCNN for PyTorch (#691)
Browse files Browse the repository at this point in the history
* feat: Added support of faster rcnn for Pytorch

* style: Removed unused import

* test: Added unittest for Faster RCNN

* refactor: Reflected changes on training script
  • Loading branch information
fg-mindee committed Dec 10, 2021
1 parent 6360bde commit c9fbe35
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 64 deletions.
1 change: 1 addition & 0 deletions doctr/models/obj_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .faster_rcnn import *
4 changes: 4 additions & 0 deletions doctr/models/obj_detection/faster_rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from doctr.file_utils import is_tf_available, is_torch_available

if not is_tf_available() and is_torch_available():
from .pytorch import * # type: ignore[misc]
79 changes: 79 additions & 0 deletions doctr/models/obj_detection/faster_rcnn/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Any, Dict

from torchvision.models.detection import FasterRCNN, faster_rcnn

from ...utils import load_pretrained_params

__all__ = ['fasterrcnn_mobilenet_v3_large_fpn']


# Default configuration for each supported architecture, keyed by architecture name.
# `_fasterrcnn` reads 'mean', 'std', 'num_classes' and 'url' from the selected entry;
# the remaining keys are not read by the code in this module.
default_cfgs: Dict[str, Dict[str, Any]] = dict(
    fasterrcnn_mobilenet_v3_large_fpn=dict(
        input_shape=(3, 1024, 1024),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        anchor_sizes=[32, 64, 128, 256, 512],
        anchor_aspect_ratios=(0.5, 1., 2.),
        num_classes=5,
        # No checkpoint URL is configured for this architecture.
        url=None,
    ),
)


def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
    """Build the torchvision Faster R-CNN variant named `arch` and initialise its weights.

    Args:
        arch: key into both `default_cfgs` and the `torchvision.models.detection.faster_rcnn`
            module namespace
        pretrained: if True, parameters are loaded with `load_pretrained_params` from the
            architecture's configured URL; if False, the model is initialised from the
            torchvision pretrained checkpoint with every `roi_heads.*` entry left out
        **kwargs: keyword arguments forwarded to the torchvision builder; they take
            precedence over the defaults set below

    Returns:
        the instantiated model
    """

    cfg = default_cfgs[arch]

    # Defaults for the torchvision builder; caller-supplied kwargs override them
    build_args = dict(
        image_mean=cfg['mean'],
        image_std=cfg['std'],
        box_detections_per_img=150,
        box_score_thresh=0.15,
        box_positive_fraction=0.35,
        box_nms_thresh=0.2,
        rpn_nms_thresh=0.2,
        num_classes=cfg['num_classes'],
    )
    build_args = {**build_args, **kwargs}

    # Instantiate without any torchvision weights: they are handled explicitly below
    builder = faster_rcnn.__dict__[arch]
    model = builder(pretrained=False, pretrained_backbone=False, **build_args)

    if pretrained:
        load_pretrained_params(model, cfg['url'])
        return model

    # Reuse the torchvision checkpoint, minus the `roi_heads.*` entries
    # (presumably because the box predictor depends on `num_classes` -- TODO confirm)
    reference_weights = builder(pretrained=True).state_dict()
    transferable = {}
    for name, tensor in reference_weights.items():
        if name.startswith('roi_heads.'):
            continue
        transferable[name] = tensor

    # strict=False: the filtered-out keys are missing on purpose
    model.load_state_dict(transferable, strict=False)

    return model


def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN:
    """Faster-RCNN architecture with a MobileNet V3 backbone, as described in `"Faster R-CNN: Towards
    Real-Time Object Detection with Region Proposal Networks" <https://arxiv.org/pdf/1506.01497.pdf>`_.

    Example::
        >>> import torch
        >>> from doctr.models.obj_detection import fasterrcnn_mobilenet_v3_large_fpn
        >>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True).eval()
        >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
        >>> with torch.no_grad(): out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our object detection dataset
        **kwargs: extra keyword arguments forwarded to the underlying model builder

    Returns:
        object detection architecture
    """

    arch = 'fasterrcnn_mobilenet_v3_large_fpn'
    return _fasterrcnn(arch, pretrained, **kwargs)
182 changes: 118 additions & 64 deletions references/obj_detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
import datetime
import logging
import multiprocessing as mp
import time

import numpy as np
import torch
import torch.optim as optim
import torchvision
import wandb
from fastprogress.fastprogress import master_bar, progress_bar
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.ops import MultiScaleRoIAlign

from doctr import transforms as T
from doctr.datasets import DocArtefacts
from doctr.models import obj_detection
from doctr.utils import DetectionMetric


Expand All @@ -40,42 +41,66 @@ def convert_to_abs_coords(targets, img_shape):
return targets


def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=False):
    """Train `model` for one pass over `train_loader`.

    Args:
        model: detection model; in train mode it is called as ``model(images, targets)``
            and must return a dict whose values are loss tensors
        train_loader: iterable yielding ``(images, targets)`` batches, where `targets` is a
            list of dicts of tensors
        optimizer: optimizer stepped once per batch
        scheduler: learning-rate scheduler, stepped once after the last batch
        mb: fastprogress master bar, used as parent of the per-batch progress bar
        amp: if True, run the forward pass under `torch.cuda.amp.autocast` and scale the
            loss with a `torch.cuda.amp.GradScaler`
    """

    # The scaler is only needed (and only created) for mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if amp else None

    model.train()
    train_iter = iter(train_loader)
    # Iterate over the batches of the dataset
    for images, targets in progress_bar(train_iter, parent=mb):

        targets = convert_to_abs_coords(targets, images.shape)
        if torch.cuda.is_available():
            images = images.cuda()
            targets = [{k: v.cuda() for k, v in t.items()} for t in targets]

        # Exactly one forward / backward / optimizer step per batch
        optimizer.zero_grad()
        if amp:
            with torch.cuda.amp.autocast():
                loss_dict = model(images, targets)
                loss = sum(loss_dict.values())
            scaler.scale(loss).backward()
            # Update the params
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            loss.backward()
            optimizer.step()

        mb.child.comment = f'Training loss: {loss.item()}'

    # NOTE(review): stepped once per epoch here; the original indentation is not
    # recoverable from this view -- confirm against the repository
    scheduler.step()


@torch.no_grad()
def evaluate(model, val_loader, metric, amp=False):
    """Run `model` over every batch of `val_loader` and accumulate `metric`.

    Args:
        model: detection model; in eval mode it is called as ``model(images)`` and must
            return one dict per image holding 'labels' and 'boxes' tensors
        val_loader: iterable yielding ``(images, targets)`` batches, where each target is
            a dict holding 'boxes' and 'labels' tensors
        metric: object exposing ``reset()``, ``update(gt_boxes, pred_boxes, gt_labels,
            pred_labels)`` and ``summary()``
        amp: if True, run inference under `torch.cuda.amp.autocast`

    Returns:
        the value of ``metric.summary()``, which callers unpack as
        ``(recall, precision, mean_iou)``
    """
    model.eval()
    metric.reset()

    # Each batch is consumed exactly once by the loop itself. An extra
    # `next(val_iter)` inside the body would drop every other batch and raise
    # StopIteration when the number of batches is odd.
    for images, targets in val_loader:
        targets = convert_to_abs_coords(targets, images.shape)
        if torch.cuda.is_available():
            images = images.cuda()

        if amp:
            with torch.cuda.amp.autocast():
                output = model(images)
        else:
            output = model(images)

        # Compute metric
        pred_labels = np.concatenate([o['labels'].cpu().numpy() for o in output])
        pred_boxes = np.concatenate([o['boxes'].cpu().numpy() for o in output])
        gt_boxes = np.concatenate([o['boxes'].cpu().numpy() for o in targets])
        gt_labels = np.concatenate([o['labels'].cpu().numpy() for o in targets])
        metric.update(gt_boxes, pred_boxes, gt_labels, pred_labels)

    return metric.summary()


def main(args):
Expand All @@ -87,31 +112,33 @@ def main(args):

torch.backends.cudnn.benchmark = True

# Filter keys
state_dict = {
k: v for k, v in torchvision.models.detection.__dict__[args.arch](pretrained=True).state_dict().items()
if not k.startswith('roi_heads.')
}
defaults = {"min_size": 800, "max_size": 1300,
"box_fg_iou_thresh": 0.5,
"box_bg_iou_thresh": 0.5,
"box_detections_per_img": 150, "box_score_thresh": 0.15, "box_positive_fraction": 0.35,
"box_nms_thresh": 0.2,
"rpn_pre_nms_top_n_train": 2000, "rpn_pre_nms_top_n_test": 1000,
"rpn_post_nms_top_n_train": 2000, "rpn_post_nms_top_n_test": 1000,
"rpn_nms_thresh": 0.2,
"rpn_batch_size_per_image": 250
}
kwargs = {**defaults}

model = torchvision.models.detection.__dict__[args.arch](pretrained=False, num_classes=5, **kwargs)
model.load_state_dict(state_dict, strict=False)
model.roi_heads.box_roi_pool = MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7),
sampling_ratio=2)
anchor_sizes = ((16), (64), (128), (264))
aspect_ratios = ((0.5, 1.0, 2.0, 3.0,)) * len(anchor_sizes)
model.rpn.anchor_generator.sizes = anchor_sizes
model.rpn.anchor_generator.aspect_ratios = aspect_ratios
st = time.time()
val_set = DocArtefacts(
train=False,
download=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
)
val_loader = DataLoader(
val_set,
batch_size=args.batch_size,
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(val_set),
pin_memory=torch.cuda.is_available(),
collate_fn=val_set.collate_fn,
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")

# Load doctr model
model = obj_detection.__dict__[args.arch](pretrained=args.pretrained, num_classes=5)

# Resume weights
if isinstance(args.resume, str):
print(f"Resuming {args.resume}")
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint)

# GPU
if isinstance(args.device, int):
if not torch.cuda.is_available():
Expand All @@ -126,33 +153,54 @@ def main(args):
if torch.cuda.is_available():
torch.cuda.set_device(args.device)
model = model.cuda()
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
lr=args.lr, weight_decay=args.weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.7)
train_set = DocArtefacts(train=True, download=True)
val_set = DocArtefacts(train=False, download=True)
train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.workers,
sampler=RandomSampler(train_set), pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
drop_last=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size, num_workers=args.workers,
sampler=SequentialSampler(val_set), pin_memory=torch.cuda.is_available(),
collate_fn=val_set.collate_fn,
drop_last=False)

# Metrics
metric = DetectionMetric(iou_thresh=0.5)

if args.test_only:
print("Running evaluation")
recall, precision, mean_iou = evaluate(model, val_loader, metric)
recall, precision, mean_iou = evaluate(model, val_loader, metric, amp=args.amp)
print(f"Recall: {recall:.2%} | Precision: {precision:.2%} |IoU: {mean_iou:.2%}")
return

st = time.time()
# Load both train and val data generators
train_set = DocArtefacts(
train=True,
download=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
)

train_loader = DataLoader(
train_set,
batch_size=args.batch_size,
drop_last=True,
num_workers=args.workers,
sampler=RandomSampler(train_set),
pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")

# Backbone freezing
if args.freeze_backbone:
    for p in model.backbone.parameters():
        # `requires_grad_` is the in-place switch on tensors/parameters; the previous
        # spelling (`reguires_grad_`) does not exist and raised AttributeError
        p.requires_grad_(False)

# Optimizer
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
lr=args.lr, weight_decay=args.weight_decay)
# Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.7)

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:

run = wandb.init(
name=exp_name,
project="object-detection",
Expand All @@ -167,32 +215,36 @@ def main(args):
"framework": "pytorch",
"scheduler": args.sched,
"pretrained": args.pretrained,
"amp": args.amp,
}
)

mb = master_bar(range(args.epochs))
max_score = 0.

for epoch in mb:
fit_one_epoch(model, train_loader, optimizer, scheduler, mb)
recall, precision, mean_iou = evaluate(model, val_loader, metric)
fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=args.amp)
# Validation loop at the end of each epoch
recall, precision, mean_iou = evaluate(model, val_loader, metric, amp=args.amp)
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.

mb.write(
f"Epoch {epoch + 1}/{args.epochs} - "
f"Recall: {recall:.2%} | Precision: {precision:.2%} "
f"|IoU: {mean_iou:.2%}")
if f1_score > max_score:
print(f"Validation metric increased {max_score:.6} --> {f1_score:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
max_score = f1_score
log_msg = f"Epoch {epoch + 1}/{args.epochs} - "
if any(val is None for val in (recall, precision, mean_iou)):
log_msg += "Undefined metric value, caused by empty GTs or predictions"
else:
log_msg += f"Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%}"
mb.write(log_msg)
# W&B
if args.wb:
wandb.log({
'recall': recall,
'precision': precision,
'iou': mean_iou,
})
if f1_score > max_score:
print(f"Validation metric increased {max_score:.6} --> {f1_score:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
max_score = f1_score

if args.wb:
run.finish()
Expand All @@ -207,17 +259,19 @@ def parse_args():
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training')
parser.add_argument('--device', default=None, type=int, help='device')
# parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W')
parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (SGD)')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--freeze-backbone', dest='freeze_backbone', action='store_true',
help='freeze model backbone for fine-tuning')
parser.add_argument('--wb', dest='wb', action='store_true',
help='Log to Weights & Biases')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()
return args

Expand Down
Loading

0 comments on commit c9fbe35

Please sign in to comment.