Skip to content

Commit

Permalink
feat: Enabled backbone pretraining by default (#435)
Browse files Browse the repository at this point in the history
* refactor: Added top silent include_top arg

* feat: Enabled backbone pretraining and FPN dynamic sizing

* refactor: Refactored config of models

* refactor: Refactored LinkNet

* feat: Updated model zoos

* docs: Updated documentation

* test: Updated unittests

* refactor: Removed unused imports
  • Loading branch information
fg-mindee committed Aug 26, 2021
1 parent 0676769 commit bcd9f9e
Show file tree
Hide file tree
Showing 17 changed files with 251 additions and 151 deletions.
2 changes: 2 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Detection models
Models expect a TensorFlow tensor as input and produces one in return. DocTR includes implementations and pretrained versions of the following models:

.. autofunction:: doctr.models.detection.db_resnet50
.. autofunction:: doctr.models.detection.db_mobilenet_v3_large
.. autofunction:: doctr.models.detection.linknet16

Detection predictors
Expand Down Expand Up @@ -119,6 +120,7 @@ Models expect a TensorFlow tensor as input and produces one in return. DocTR inc


.. autofunction:: doctr.models.recognition.crnn_vgg16_bn
.. autofunction:: doctr.models.recognition.crnn_mobilenet_v3_large
.. autofunction:: doctr.models.recognition.sar_vgg16_bn
.. autofunction:: doctr.models.recognition.sar_resnet31
.. autofunction:: doctr.models.recognition.master
Expand Down
2 changes: 2 additions & 0 deletions doctr/models/backbones/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class ResNet(Sequential):
conv_seq: wether to add a conv_sequence after each stage
pooling: pooling to add after each stage (if None, no pooling)
input_shape: shape of inputs
include_top: whether the classifier head should be instantiated
"""

def __init__(
Expand All @@ -134,6 +135,7 @@ def __init__(
Optional[Tuple[int, int]]
],
input_shape: Tuple[int, int, int] = (640, 640, 3),
include_top: bool = False,
) -> None:

_layers = [
Expand Down
2 changes: 2 additions & 0 deletions doctr/models/backbones/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ class VGG(Sequential):
planes: number of output channels in each stage
rect_pools: whether pooling square kernels should be replace with rectangular ones
input_shape: shapes of the input tensor
include_top: whether the classifier head should be instantiated
"""
def __init__(
self,
num_blocks: Tuple[int, int, int, int, int],
planes: Tuple[int, int, int, int, int],
rect_pools: Tuple[bool, bool, bool, bool, bool],
input_shape: Tuple[int, int, int] = (512, 512, 3),
include_top: bool = False,
) -> None:

_layers = []
Expand Down
40 changes: 23 additions & 17 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.deform_conv import DeformConv2d
from torchvision.models import resnet34, resnet50, mobilenet_v3_large
from torchvision.models import resnet34, resnet50
from typing import List, Dict, Any, Optional

from .base import DBPostProcessor, _DBNet
from ...backbones import mobilenet_v3_large
from ...utils import load_pretrained_params

__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3']
__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3_large']


default_cfgs: Dict[str, Dict[str, Any]] = {
'db_resnet50': {
'backbone': resnet50,
'backbone_submodule': None,
'fpn_layers': ['layer1', 'layer2', 'layer3', 'layer4'],
'fpn_channels': [256, 512, 1024, 2048],
'input_shape': (3, 1024, 1024),
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
Expand All @@ -32,17 +32,15 @@
'backbone': resnet34,
'backbone_submodule': None,
'fpn_layers': ['layer1', 'layer2', 'layer3', 'layer4'],
'fpn_channels': [64, 128, 256, 512],
'input_shape': (3, 1024, 1024),
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
'url': None,
},
'db_mobilenet_v3': {
'db_mobilenet_v3_large': {
'backbone': mobilenet_v3_large,
'backbone_submodule': 'features',
'fpn_layers': ['3', '6', '12', '16'],
'fpn_channels': [24, 40, 112, 960],
'input_shape': (3, 1024, 1024),
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
Expand Down Expand Up @@ -102,7 +100,6 @@ class DBNet(_DBNet, nn.Module):
def __init__(
self,
feat_extractor: IntermediateLayerGetter,
fpn_channels: List[int],
head_chans: int = 256,
deform_conv: bool = False,
num_classes: int = 1,
Expand All @@ -113,14 +110,21 @@ def __init__(
super().__init__()
self.cfg = cfg

if len(feat_extractor.return_layers) != len(fpn_channels):
raise AssertionError

conv_layer = DeformConv2d if deform_conv else nn.Conv2d

self.rotated_bbox = rotated_bbox

self.feat_extractor = feat_extractor
# Identify the number of channels for the head initialization
_is_training = self.feat_extractor.training
self.feat_extractor = self.feat_extractor.eval()
with torch.no_grad():
out = self.feat_extractor(torch.zeros((1, 3, 224, 224)))
fpn_channels = [v.shape[1] for _, v in out.items()]

if _is_training:
self.feat_extractor = self.feat_extractor.train()

self.fpn = FeaturePyramidNetwork(fpn_channels, head_chans, deform_conv)
# Conv1 map to channels

Expand Down Expand Up @@ -245,8 +249,10 @@ def compute_loss(
return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss


def _dbnet(arch: str, pretrained: bool, pretrained_backbone: bool = False, **kwargs: Any) -> DBNet:
def _dbnet(arch: str, pretrained: bool, pretrained_backbone: bool = True, **kwargs: Any) -> DBNet:

# Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50
pretrained_backbone = pretrained_backbone and not arch.split('_')[1].startswith('resnet')
pretrained_backbone = pretrained_backbone and not pretrained

# Feature extractor
Expand All @@ -259,7 +265,7 @@ def _dbnet(arch: str, pretrained: bool, pretrained_backbone: bool = False, **kwa
)

# Build the model
model = DBNet(feat_extractor, default_cfgs[arch]['fpn_channels'], cfg=default_cfgs[arch], **kwargs)
model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]['url'])
Expand Down Expand Up @@ -309,14 +315,14 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
return _dbnet('db_resnet50', pretrained, **kwargs)


def db_mobilenet_v3(pretrained: bool = False, **kwargs: Any) -> DBNet:
def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 backbone.
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
Example::
>>> import torch
>>> from doctr.models import db_mobilenet_v3
>>> model = db_mobilenet_v3(pretrained=True)
>>> from doctr.models import db_mobilenet_v3_large
>>> model = db_mobilenet_v3_large(pretrained=True)
>>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
>>> out = model(input_tensor)
Expand All @@ -327,4 +333,4 @@ def db_mobilenet_v3(pretrained: bool = False, **kwargs: Any) -> DBNet:
text detection architecture
"""

return _dbnet('db_mobilenet_v3', pretrained, **kwargs)
return _dbnet('db_mobilenet_v3_large', pretrained, **kwargs)
78 changes: 41 additions & 37 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,32 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from typing import List, Tuple, Optional, Any, Dict

from ... import backbones
from ...backbones import mobilenet_v3_large
from doctr.utils.repr import NestedObject
from doctr.models.utils import IntermediateLayerGetter, load_pretrained_params, conv_sequence
from .base import DBPostProcessor, _DBNet

__all__ = ['DBNet', 'db_resnet50', 'db_mobilenet_v3_small']
__all__ = ['DBNet', 'db_resnet50', 'db_mobilenet_v3_large']


default_cfgs: Dict[str, Dict[str, Any]] = {
'db_resnet50': {
'mean': (0.798, 0.785, 0.772),
'std': (0.264, 0.2749, 0.287),
'backbone': 'ResNet50',
'backbone': ResNet50,
'fpn_layers': ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
'fpn_channels': 128,
'input_shape': (1024, 1024, 3),
'rotated_bbox': False,
'url': 'https://github.com/mindee/doctr/releases/download/v0.2.0/db_resnet50-adcafc63.zip',
},
'db_mobilenet_v3_small': {
'db_mobilenet_v3_large': {
'mean': (0.798, 0.785, 0.772),
'std': (0.264, 0.2749, 0.287),
'backbone': 'mobilenet_v3_small',
'fpn_layers': ["inverted_0", "inverted_2", "inverted_7", "final_block"],
'fpn_channels': 128,
'backbone': mobilenet_v3_large,
'fpn_layers': ["inverted_2", "inverted_5", "inverted_11", "final_block"],
'input_shape': (1024, 1024, 3),
'rotated_bbox': False,
'url': None,
},
}
Expand Down Expand Up @@ -113,14 +110,16 @@ class DBNet(_DBNet, keras.Model, NestedObject):
Args:
feature extractor: the backbone serving as feature extractor
fpn_channels: number of channels each extracted feature maps is mapped to
rotated_bbox: whether the segmentation map can include rotated bounding boxes
cfg: the configuration dict of the model
"""

_children_names: List[str] = ['feat_extractor', 'fpn', 'probability_head', 'threshold_head', 'postprocessor']

def __init__(
self,
feature_extractor: IntermediateLayerGetter,
fpn_channels: int = 128,
fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea
rotated_bbox: bool = False,
cfg: Optional[Dict[str, Any]] = None,
) -> None:
Expand Down Expand Up @@ -246,30 +245,31 @@ def call(
return out


def _db_resnet(arch: str, pretrained: bool, input_shape: Tuple[int, int, int] = None, **kwargs: Any) -> DBNet:
def _db_resnet(
arch: str,
pretrained: bool,
pretrained_backbone: bool = False,
input_shape: Tuple[int, int, int] = None,
**kwargs: Any
) -> DBNet:

pretrained_backbone = pretrained_backbone and not pretrained

# Patch the config
_cfg = deepcopy(default_cfgs[arch])
_cfg['input_shape'] = input_shape or _cfg['input_shape']
_cfg['fpn_channels'] = kwargs.get('fpn_channels', _cfg['fpn_channels'])
_cfg['rotated_bbox'] = kwargs.get('rotated_bbox', _cfg['rotated_bbox'])

# Feature extractor
resnet = tf.keras.applications.__dict__[_cfg['backbone']](
include_top=False,
weights=None,
input_shape=_cfg['input_shape'],
pooling=None,
)

feat_extractor = IntermediateLayerGetter(
resnet,
_cfg['backbone'](
include_top=False,
weights='imagenet' if pretrained_backbone else None,
input_shape=_cfg['input_shape'],
pooling=None,
),
_cfg['fpn_layers'],
)

kwargs['fpn_channels'] = _cfg['fpn_channels']
kwargs['rotated_bbox'] = _cfg['rotated_bbox']

# Build the model
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
# Load pretrained parameters
Expand All @@ -279,26 +279,30 @@ def _db_resnet(arch: str, pretrained: bool, input_shape: Tuple[int, int, int] =
return model


def _db_mobilenet(arch: str, pretrained: bool, input_shape: Tuple[int, int, int] = None, **kwargs: Any) -> DBNet:
def _db_mobilenet(
arch: str,
pretrained: bool,
pretrained_backbone: bool = True,
input_shape: Tuple[int, int, int] = None,
**kwargs: Any
) -> DBNet:

pretrained_backbone = pretrained_backbone and not pretrained

# Patch the config
_cfg = deepcopy(default_cfgs[arch])
_cfg['input_shape'] = input_shape or _cfg['input_shape']
_cfg['fpn_channels'] = kwargs.get('fpn_channels', _cfg['fpn_channels'])
_cfg['rotated_bbox'] = kwargs.get('rotated_bbox', _cfg['rotated_bbox'])

# Feature extractor
feat_extractor = IntermediateLayerGetter(
backbones.__dict__[_cfg['backbone']](
_cfg['backbone'](
input_shape=_cfg['input_shape'],
include_top=False,
pretrained=pretrained_backbone,
),
_cfg['fpn_layers'],
)

kwargs['fpn_channels'] = _cfg['fpn_channels']
kwargs['rotated_bbox'] = _cfg['rotated_bbox']

# Build the model
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
# Load pretrained parameters
Expand Down Expand Up @@ -329,14 +333,14 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
return _db_resnet('db_resnet50', pretrained, **kwargs)


def db_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> DBNet:
def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a mobilenet v3 small backbone.
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a mobilenet v3 large backbone.
Example::
>>> import tensorflow as tf
>>> from doctr.models import db_resnet50
>>> model = db_mobilenet_v3_small(pretrained=True)
>>> from doctr.models import db_mobilenet_v3_large
>>> model = db_mobilenet_v3_large(pretrained=True)
>>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Expand All @@ -347,4 +351,4 @@ def db_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> DBNet:
text detection architecture
"""

return _db_mobilenet('db_mobilenet_v3_small', pretrained, **kwargs)
return _db_mobilenet('db_mobilenet_v3_large', pretrained, **kwargs)
Loading

0 comments on commit bcd9f9e

Please sign in to comment.