Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added support of FasterRCNN for PyTorch #691

Merged
merged 4 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doctr/models/obj_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .faster_rcnn import *
4 changes: 4 additions & 0 deletions doctr/models/obj_detection/faster_rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from doctr.file_utils import is_tf_available, is_torch_available

if not is_tf_available() and is_torch_available():
from .pytorch import * # type: ignore[misc]
79 changes: 79 additions & 0 deletions doctr/models/obj_detection/faster_rcnn/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Any, Dict

from torchvision.models.detection import FasterRCNN, faster_rcnn

from ...utils import load_pretrained_params

__all__ = ['fasterrcnn_mobilenet_v3_large_fpn']


# Per-architecture configuration, keyed by the builder name exported in `__all__`.
# `_fasterrcnn` reads 'mean', 'std', 'num_classes' and 'url' from an entry; the
# remaining keys are not read anywhere in this file.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "fasterrcnn_mobilenet_v3_large_fpn": {
        # (channels, height, width)
        "input_shape": (3, 1024, 1024),
        # Per-channel normalisation statistics forwarded as image_mean / image_std
        "mean": (0.485, 0.456, 0.406),
        "std": (0.229, 0.224, 0.225),
        "anchor_sizes": [32, 64, 128, 256, 512],
        "anchor_aspect_ratios": (0.5, 1.0, 2.0),
        "num_classes": 5,
        # No checkpoint URL is registered for this architecture
        "url": None,
    },
}


def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
    """Instantiate a torchvision Faster R-CNN and initialise its weights.

    Args:
        arch: key of `default_cfgs`, also the name of the torchvision builder to call
        pretrained: if True, load the parameters referenced by the config URL; otherwise
            initialise from the torchvision builder called with `pretrained=True`,
            leaving out every `roi_heads.*` entry
        **kwargs: keyword arguments forwarded to the torchvision builder; they take
            precedence over the defaults assembled here

    Returns:
        the configured FasterRCNN model
    """

    cfg = default_cfgs[arch]

    # Defaults first, caller overrides last
    build_kwargs = {
        "image_mean": cfg['mean'],
        "image_std": cfg['std'],
        "box_detections_per_img": 150,
        "box_score_thresh": 0.15,
        "box_positive_fraction": 0.35,
        "box_nms_thresh": 0.2,
        "rpn_nms_thresh": 0.2,
        "num_classes": cfg['num_classes'],
        **kwargs,
    }

    # Build the model without any torchvision weights
    builder = faster_rcnn.__dict__[arch]
    model = builder(pretrained=False, pretrained_backbone=False, **build_kwargs)

    if pretrained:
        # NOTE(review): the registered URL is currently None; the behaviour then depends
        # on `load_pretrained_params`, which is not visible here. Confirm.
        load_pretrained_params(model, cfg['url'])
        return model

    # Reuse the torchvision parameters, except those under `roi_heads.`
    reference_params = builder(pretrained=True).state_dict()
    kept_params = {
        name: tensor for name, tensor in reference_params.items()
        if not name.startswith('roi_heads.')
    }
    # strict=False because the `roi_heads.*` keys are missing on purpose
    model.load_state_dict(kept_params, strict=False)

    return model


def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN:
    """Faster-RCNN architecture with a MobileNet V3 backbone as described in `"Faster R-CNN: Towards Real-Time
    Object Detection with Region Proposal Networks" <https://arxiv.org/pdf/1506.01497.pdf>`_.

    Example::
        >>> import torch
        >>> from doctr.models.obj_detection import fasterrcnn_mobilenet_v3_large_fpn
        >>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True).eval()
        >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
        >>> with torch.no_grad(): out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our object detection dataset
        **kwargs: extra keyword arguments passed through to the model builder, overriding its defaults

    Returns:
        object detection architecture
    """

    arch_name = 'fasterrcnn_mobilenet_v3_large_fpn'
    return _fasterrcnn(arch_name, pretrained, **kwargs)
182 changes: 118 additions & 64 deletions references/obj_detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
import datetime
import logging
import multiprocessing as mp
import time

import numpy as np
import torch
import torch.optim as optim
import torchvision
import wandb
from fastprogress.fastprogress import master_bar, progress_bar
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.ops import MultiScaleRoIAlign

from doctr import transforms as T
from doctr.datasets import DocArtefacts
from doctr.models import obj_detection
from doctr.utils import DetectionMetric


Expand All @@ -40,42 +41,66 @@ def convert_to_abs_coords(targets, img_shape):
return targets


def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=False):
    """Train the model for one pass over `train_loader`.

    Args:
        model: detection model; called as `model(images, targets)` and expected to
            return a dict of losses
        train_loader: iterable yielding `(images, targets)` batches
        optimizer: optimizer updating the model parameters
        scheduler: learning-rate scheduler, stepped once after the batch loop
        mb: fastprogress master bar used as parent of the batch progress bar
        amp: if True, run the forward pass under `torch.cuda.amp.autocast` and scale
            the loss with a `GradScaler`
    """

    if amp:
        scaler = torch.cuda.amp.GradScaler()

    model.train()
    # Iterate over the batches of the dataset
    for images, targets in progress_bar(iter(train_loader), parent=mb):

        targets = convert_to_abs_coords(targets, images.shape)
        if torch.cuda.is_available():
            images = images.cuda()
            targets = [{k: v.cuda() for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        if amp:
            # Only the forward pass and loss reduction run under autocast
            with torch.cuda.amp.autocast():
                loss_dict = model(images, targets)
                loss = sum(loss_dict.values())
            scaler.scale(loss).backward()
            # Update the params
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            loss.backward()
            optimizer.step()

        mb.child.comment = f'Training loss: {loss.item()}'

    # NOTE(review): stepped once per epoch (outside the batch loop); the source view has
    # lost its indentation, so confirm this placement against the repository.
    scheduler.step()


@torch.no_grad()
def evaluate(model, val_loader, metric, amp=False):
    """Run the model over the validation loader and summarise the detection metric.

    Args:
        model: detection model; called as `model(images)` and expected to return one
            dict with 'labels' and 'boxes' tensors per image
        val_loader: iterable yielding `(images, targets)` batches
        metric: metric object exposing `reset()`, `update(...)` and `summary()`
        amp: if True, run the forward pass under `torch.cuda.amp.autocast`

    Returns:
        whatever `metric.summary()` returns (callers unpack it as
        `recall, precision, mean_iou`)
    """
    model.eval()
    metric.reset()
    # Iterate the loader directly: an extra `next()` inside the loop body would skip
    # every other batch and raise StopIteration on an odd number of batches.
    for images, targets in val_loader:

        targets = convert_to_abs_coords(targets, images.shape)
        if torch.cuda.is_available():
            images = images.cuda()

        if amp:
            with torch.cuda.amp.autocast():
                output = model(images)
        else:
            output = model(images)

        # Compute metric
        pred_labels = np.concatenate([o['labels'].cpu().numpy() for o in output])
        pred_boxes = np.concatenate([o['boxes'].cpu().numpy() for o in output])
        gt_boxes = np.concatenate([o['boxes'].cpu().numpy() for o in targets])
        gt_labels = np.concatenate([o['labels'].cpu().numpy() for o in targets])
        metric.update(gt_boxes, pred_boxes, gt_labels, pred_labels)

    return metric.summary()


def main(args):
Expand All @@ -87,31 +112,33 @@ def main(args):

torch.backends.cudnn.benchmark = True

# Filter keys
state_dict = {
k: v for k, v in torchvision.models.detection.__dict__[args.arch](pretrained=True).state_dict().items()
if not k.startswith('roi_heads.')
}
defaults = {"min_size": 800, "max_size": 1300,
"box_fg_iou_thresh": 0.5,
"box_bg_iou_thresh": 0.5,
"box_detections_per_img": 150, "box_score_thresh": 0.15, "box_positive_fraction": 0.35,
"box_nms_thresh": 0.2,
"rpn_pre_nms_top_n_train": 2000, "rpn_pre_nms_top_n_test": 1000,
"rpn_post_nms_top_n_train": 2000, "rpn_post_nms_top_n_test": 1000,
"rpn_nms_thresh": 0.2,
"rpn_batch_size_per_image": 250
}
kwargs = {**defaults}

model = torchvision.models.detection.__dict__[args.arch](pretrained=False, num_classes=5, **kwargs)
model.load_state_dict(state_dict, strict=False)
model.roi_heads.box_roi_pool = MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7),
sampling_ratio=2)
anchor_sizes = ((16), (64), (128), (264))
aspect_ratios = ((0.5, 1.0, 2.0, 3.0,)) * len(anchor_sizes)
model.rpn.anchor_generator.sizes = anchor_sizes
model.rpn.anchor_generator.aspect_ratios = aspect_ratios
st = time.time()
val_set = DocArtefacts(
train=False,
download=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
)
val_loader = DataLoader(
val_set,
batch_size=args.batch_size,
drop_last=False,
num_workers=args.workers,
sampler=SequentialSampler(val_set),
pin_memory=torch.cuda.is_available(),
collate_fn=val_set.collate_fn,
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")

# Load doctr model
model = obj_detection.__dict__[args.arch](pretrained=args.pretrained, num_classes=5)

# Resume weights
if isinstance(args.resume, str):
print(f"Resuming {args.resume}")
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint)

# GPU
if isinstance(args.device, int):
if not torch.cuda.is_available():
Expand All @@ -126,33 +153,54 @@ def main(args):
if torch.cuda.is_available():
torch.cuda.set_device(args.device)
model = model.cuda()
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
lr=args.lr, weight_decay=args.weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.7)
train_set = DocArtefacts(train=True, download=True)
val_set = DocArtefacts(train=False, download=True)
train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.workers,
sampler=RandomSampler(train_set), pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
drop_last=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size, num_workers=args.workers,
sampler=SequentialSampler(val_set), pin_memory=torch.cuda.is_available(),
collate_fn=val_set.collate_fn,
drop_last=False)

# Metrics
metric = DetectionMetric(iou_thresh=0.5)

if args.test_only:
print("Running evaluation")
recall, precision, mean_iou = evaluate(model, val_loader, metric)
recall, precision, mean_iou = evaluate(model, val_loader, metric, amp=args.amp)
print(f"Recall: {recall:.2%} | Precision: {precision:.2%} |IoU: {mean_iou:.2%}")
return

st = time.time()
# Load both train and val data generators
train_set = DocArtefacts(
train=True,
download=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
)

train_loader = DataLoader(
train_set,
batch_size=args.batch_size,
drop_last=True,
num_workers=args.workers,
sampler=RandomSampler(train_set),
pin_memory=torch.cuda.is_available(),
collate_fn=train_set.collate_fn,
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")

# Backbone freezing: exclude the backbone parameters from gradient computation so the
# optimizer (which filters on `requires_grad`) only updates the remaining parameters
if args.freeze_backbone:
    for p in model.backbone.parameters():
        p.requires_grad_(False)

# Optimizer
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
lr=args.lr, weight_decay=args.weight_decay)
# Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.7)

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:

run = wandb.init(
name=exp_name,
project="object-detection",
Expand All @@ -167,32 +215,36 @@ def main(args):
"framework": "pytorch",
"scheduler": args.sched,
"pretrained": args.pretrained,
"amp": args.amp,
}
)

mb = master_bar(range(args.epochs))
max_score = 0.

for epoch in mb:
    fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=args.amp)
    # Validation loop at the end of each epoch
    recall, precision, mean_iou = evaluate(model, val_loader, metric, amp=args.amp)

    # The metric values may be None (see the "Undefined metric value" branch below), so
    # guard the F1 computation instead of letting `None + None` raise a TypeError
    metrics_missing = precision is None or recall is None
    if not metrics_missing and (precision + recall) > 0:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = 0.

    # Keep the checkpoint with the best F1 score seen so far
    if f1_score > max_score:
        print(f"Validation metric increased {max_score:.6} --> {f1_score:.6}: saving state...")
        torch.save(model.state_dict(), f"./{exp_name}.pt")
        max_score = f1_score

    log_msg = f"Epoch {epoch + 1}/{args.epochs} - "
    if any(val is None for val in (recall, precision, mean_iou)):
        log_msg += "Undefined metric value, caused by empty GTs or predictions"
    else:
        log_msg += f"Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%}"
    mb.write(log_msg)
    # W&B
    if args.wb:
        wandb.log({
            'recall': recall,
            'precision': precision,
            'iou': mean_iou,
        })

if args.wb:
run.finish()
Expand All @@ -207,17 +259,19 @@ def parse_args():
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training')
parser.add_argument('--device', default=None, type=int, help='device')
# parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W')
parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (SGD)')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--freeze-backbone', dest='freeze_backbone', action='store_true',
help='freeze model backbone for fine-tuning')
parser.add_argument('--wb', dest='wb', action='store_true',
help='Log to Weights & Biases')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()
return args

Expand Down
Loading