-
Notifications
You must be signed in to change notification settings - Fork 420
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added support of FasterRCNN for PyTorch (#691)
* feat: Added support of faster rcnn for Pytorch * style: Removed unnused import * test: Added unittest for Faster RCNN * refactor: Reflected changes on training script
- Loading branch information
Showing
5 changed files
with
235 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .faster_rcnn import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from doctr.file_utils import is_tf_available, is_torch_available | ||
|
||
if not is_tf_available() and is_torch_available(): | ||
from .pytorch import * # type: ignore[misc] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (C) 2021, Mindee. | ||
|
||
# This program is licensed under the Apache License version 2. | ||
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details. | ||
|
||
from typing import Any, Dict | ||
|
||
from torchvision.models.detection import FasterRCNN, faster_rcnn | ||
|
||
from ...utils import load_pretrained_params | ||
|
||
__all__ = ['fasterrcnn_mobilenet_v3_large_fpn'] | ||
|
||
|
||
default_cfgs: Dict[str, Dict[str, Any]] = { | ||
'fasterrcnn_mobilenet_v3_large_fpn': { | ||
'input_shape': (3, 1024, 1024), | ||
'mean': (0.485, 0.456, 0.406), | ||
'std': (0.229, 0.224, 0.225), | ||
'anchor_sizes': [32, 64, 128, 256, 512], | ||
'anchor_aspect_ratios': (0.5, 1., 2.), | ||
'num_classes': 5, | ||
'url': None, | ||
}, | ||
} | ||
|
||
|
||
def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN: | ||
|
||
_kwargs = { | ||
"image_mean": default_cfgs[arch]['mean'], | ||
"image_std": default_cfgs[arch]['std'], | ||
"box_detections_per_img": 150, | ||
"box_score_thresh": 0.15, | ||
"box_positive_fraction": 0.35, | ||
"box_nms_thresh": 0.2, | ||
"rpn_nms_thresh": 0.2, | ||
"num_classes": default_cfgs[arch]['num_classes'], | ||
} | ||
|
||
# Build the model | ||
_kwargs.update(kwargs) | ||
model = faster_rcnn.__dict__[arch](pretrained=False, pretrained_backbone=False, **_kwargs) | ||
|
||
if pretrained: | ||
# Load pretrained parameters | ||
load_pretrained_params(model, default_cfgs[arch]['url']) | ||
else: | ||
# Filter keys | ||
state_dict = { | ||
k: v for k, v in faster_rcnn.__dict__[arch](pretrained=True).state_dict().items() | ||
if not k.startswith('roi_heads.') | ||
} | ||
|
||
# Load state dict | ||
model.load_state_dict(state_dict, strict=False) | ||
|
||
return model | ||
|
||
|
||
def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN: | ||
"""Faster-RCNN architecture with a MobileNet V3 backbone as described in `"Faster R-CNN: Towards Real-Time | ||
Object Detection with Region Proposal Networks" <https://arxiv.org/pdf/1506.01497.pdf>`_. | ||
Example:: | ||
>>> import torch | ||
>>> from doctr.models.obj_detection import fasterrcnn_mobilenet_v3_large_fpn | ||
>>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True).eval() | ||
>>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) | ||
>>> with torch.no_grad(): out = model(input_tensor) | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on our object detection dataset | ||
Returns: | ||
object detection architecture | ||
""" | ||
|
||
return _fasterrcnn('fasterrcnn_mobilenet_v3_large_fpn', pretrained, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.