Skip to content

Commit

Permalink
refactor: Refactored LinkNet (#733)
Browse files Browse the repository at this point in the history
* refactor: Refactored LinkNet

* docs: Updated documentation

* test: Updated unittests

* refactor: Reflected changes on linknet naming

* fix: Fixed LinkNet

* refactor: Refactored LinkNet TF
  • Loading branch information
fg-mindee committed Dec 22, 2021
1 parent eb35c5f commit 808081f
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 92 deletions.
2 changes: 1 addition & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ doctr.models.classification
doctr.models.detection
----------------------

.. autofunction:: doctr.models.detection.linknet16
.. autofunction:: doctr.models.detection.linknet_resnet18

.. autofunction:: doctr.models.detection.db_resnet50

Expand Down
108 changes: 36 additions & 72 deletions doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,25 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

from ...utils import load_pretrained_params
from .base import LinkNetPostProcessor, _LinkNet

__all__ = ['LinkNet', 'linknet16']
__all__ = ['LinkNet', 'linknet_resnet18']


default_cfgs: Dict[str, Dict[str, Any]] = {
'linknet16': {
'layout': [64, 128, 256, 512],
'linknet_resnet18': {
'backbone': resnet18,
'fpn_layers': ['layer1', 'layer2', 'layer3', 'layer4'],
'input_shape': (3, 1024, 1024),
'mean': (.5, .5, .5),
'std': (1., 1., 1.),
Expand All @@ -28,76 +30,35 @@
}


class LinkNetEncoder(nn.Module):
def __init__(self, in_chans: int, out_chans: int):
class LinkNetFPN(nn.Module):
def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None:
super().__init__()
self.stage1 = nn.Sequential(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(out_chans),
nn.ReLU(inplace=True),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_chans),
)

self.shortcut = nn.Sequential(
nn.Conv2d(in_chans, out_chans, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(out_chans),
) if in_chans != out_chans else None

self.stage2 = nn.Sequential(
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_chans),
nn.ReLU(inplace=True),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_chans),
nn.ReLU(inplace=True),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.stage1(x)
if self.shortcut is not None:
out += self.shortcut(x)
out = F.relu(out, inplace=True)
out = self.stage2(out) + out

return out


def linknet_backbone(layout: List[int], in_channels: int = 3, stem_channels: int = 64) -> nn.Sequential:
# Stem
_layers: List[nn.Module] = [
nn.Conv2d(in_channels, stem_channels, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(stem_channels),
nn.ReLU(inplace=True),
nn.MaxPool2d(2),
]
# Encoders
for in_chan, out_chan in zip([stem_channels] + layout[:-1], layout):
_layers.append(LinkNetEncoder(in_chan, out_chan))

return nn.Sequential(*_layers)
strides = [
1 if (in_shape[-1] == out_shape[-1]) else 2
for in_shape, out_shape in zip(layer_shapes[:-1], layer_shapes[1:])
]

chans = [shape[0] for shape in layer_shapes]

class LinkNetFPN(nn.Module):
def __init__(self, layout: List[int], in_channels: int = 64) -> None:
super().__init__()
_decoder_layers = [
self.decoder_block(out_chan, in_chan) for in_chan, out_chan in zip([in_channels] + layout[:-1], layout)
self.decoder_block(ochan, ichan, stride) for ichan, ochan, stride in zip(chans[:-1], chans[1:], strides)
]

self.decoders = nn.ModuleList(_decoder_layers)

@staticmethod
def decoder_block(in_chan: int, out_chan: int) -> nn.Sequential:
def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential:
"""Creates a LinkNet decoder block"""

mid_chan = in_chan // 4
return nn.Sequential(
nn.Conv2d(in_chan, in_chan // 4, kernel_size=1, bias=False),
nn.BatchNorm2d(in_chan // 4),
nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_chan),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(in_chan // 4, in_chan // 4, 3, padding=1, output_padding=1, stride=2, bias=False),
nn.BatchNorm2d(in_chan // 4),
nn.ConvTranspose2d(mid_chan, mid_chan, 3, padding=1, output_padding=stride - 1, stride=stride, bias=False),
nn.BatchNorm2d(mid_chan),
nn.ReLU(inplace=True),
nn.Conv2d(in_chan // 4, out_chan, kernel_size=1, bias=False),
nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=False),
nn.BatchNorm2d(out_chan),
nn.ReLU(inplace=True),
)
Expand Down Expand Up @@ -131,16 +92,20 @@ def __init__(
_is_training = self.feat_extractor.training
self.feat_extractor = self.feat_extractor.eval()
with torch.no_grad():
out = self.feat_extractor(torch.zeros((1, 3, 224, 224)))
fpn_channels = [v.shape[1] for _, v in out.items()]
in_shape = (3, 512, 512)
out = self.feat_extractor(torch.zeros((1, *in_shape)))
# Get the shapes of the extracted feature maps
_shapes = [v.shape[1:] for _, v in out.items()]
# Prepend the expected shapes of the first encoder
_shapes = [(_shapes[0][0], in_shape[1] // 4, in_shape[2] // 4)] + _shapes

if _is_training:
self.feat_extractor = self.feat_extractor.train()

self.fpn = LinkNetFPN(fpn_channels, fpn_channels[0])
self.fpn = LinkNetFPN(_shapes)

self.classifier = nn.Sequential(
nn.ConvTranspose2d(fpn_channels[0], 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
nn.ConvTranspose2d(_shapes[0][0], 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
Expand Down Expand Up @@ -233,13 +198,12 @@ def _linknet(arch: str, pretrained: bool, pretrained_backbone: bool = False, **k
pretrained_backbone = pretrained_backbone and not pretrained

# Build the feature extractor
backbone = linknet_backbone(default_cfgs[arch]['layout'])
backbone = default_cfgs[arch]['backbone']()
if pretrained_backbone:
load_pretrained_params(backbone, None)

feat_extractor = IntermediateLayerGetter(
backbone,
{str(layer): str(idx) for idx, layer in enumerate(range(4, 4 + len(default_cfgs[arch]['layout'])))},
{layer_name: str(idx) for idx, layer_name in enumerate(default_cfgs[arch]['fpn_layers'])},
)

# Build the model
Expand All @@ -251,14 +215,14 @@ def _linknet(arch: str, pretrained: bool, pretrained_backbone: bool = False, **k
return model


def linknet16(pretrained: bool = False, **kwargs: Any) -> LinkNet:
def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
<https://arxiv.org/pdf/1707.03718.pdf>`_.
Example::
>>> import torch
>>> from doctr.models import linknet16
>>> model = linknet16(pretrained=True).eval()
>>> from doctr.models import linknet_resnet18
>>> model = linknet_resnet18(pretrained=True).eval()
>>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
>>> with torch.no_grad(): out = model(input_tensor)
Expand All @@ -269,4 +233,4 @@ def linknet16(pretrained: bool = False, **kwargs: Any) -> LinkNet:
text detection architecture
"""

return _linknet('linknet16', pretrained, **kwargs)
return _linknet('linknet_resnet18', pretrained, **kwargs)
26 changes: 13 additions & 13 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from .base import LinkNetPostProcessor, _LinkNet

__all__ = ['LinkNet', 'linknet16']
__all__ = ['LinkNet', 'linknet_resnet18']


default_cfgs: Dict[str, Dict[str, Any]] = {
'linknet16': {
'linknet_resnet18': {
'mean': (0.798, 0.785, 0.772),
'std': (0.264, 0.2749, 0.287),
'input_shape': (1024, 1024, 3),
Expand All @@ -32,15 +32,15 @@
}


def decoder_block(in_chan: int, out_chan: int) -> Sequential:
def decoder_block(in_chan: int, out_chan: int, stride: int) -> Sequential:
"""Creates a LinkNet decoder block"""

return Sequential([
*conv_sequence(in_chan // 4, 'relu', True, kernel_size=1),
layers.Conv2DTranspose(
filters=in_chan // 4,
kernel_size=3,
strides=2,
strides=stride,
padding="same",
use_bias=False,
kernel_initializer='he_normal'
Expand All @@ -59,14 +59,14 @@ def __init__(
) -> None:

super().__init__()
self.encoder_1 = ResnetStage(num_blocks=2, output_channels=64, downsample=True)
self.encoder_1 = ResnetStage(num_blocks=2, output_channels=64, downsample=False)
self.encoder_2 = ResnetStage(num_blocks=2, output_channels=128, downsample=True)
self.encoder_3 = ResnetStage(num_blocks=2, output_channels=256, downsample=True)
self.encoder_4 = ResnetStage(num_blocks=2, output_channels=512, downsample=True)
self.decoder_1 = decoder_block(in_chan=64, out_chan=64)
self.decoder_2 = decoder_block(in_chan=128, out_chan=64)
self.decoder_3 = decoder_block(in_chan=256, out_chan=128)
self.decoder_4 = decoder_block(in_chan=512, out_chan=256)
self.decoder_1 = decoder_block(in_chan=64, out_chan=64, stride=1)
self.decoder_2 = decoder_block(in_chan=128, out_chan=64, stride=2)
self.decoder_3 = decoder_block(in_chan=256, out_chan=128, stride=2)
self.decoder_4 = decoder_block(in_chan=512, out_chan=256, stride=2)

def call(
self,
Expand Down Expand Up @@ -216,14 +216,14 @@ def _linknet(arch: str, pretrained: bool, input_shape: Tuple[int, int, int] = No
return model


def linknet16(pretrained: bool = False, **kwargs: Any) -> LinkNet:
def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
<https://arxiv.org/pdf/1707.03718.pdf>`_.
Example::
>>> import tensorflow as tf
>>> from doctr.models import linknet16
>>> model = linknet16(pretrained=True)
>>> from doctr.models import linknet_resnet18
>>> model = linknet_resnet18(pretrained=True)
>>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)
Expand All @@ -234,4 +234,4 @@ def linknet16(pretrained: bool = False, **kwargs: Any) -> LinkNet:
text detection architecture
"""

return _linknet('linknet16', pretrained, **kwargs)
return _linknet('linknet_resnet18', pretrained, **kwargs)
4 changes: 2 additions & 2 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@


if is_tf_available():
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet16']
ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18']
elif is_torch_available():
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet16']
ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18']


def _predictor(
Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/test_models_detection_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
["db_resnet34", (3, 512, 512), (1, 512, 512), True],
["db_resnet50", (3, 512, 512), (1, 512, 512), True],
["db_mobilenet_v3_large", (3, 512, 512), (1, 512, 512), True],
["linknet16", (3, 512, 512), (1, 512, 512), False],
["linknet_resnet18", (3, 512, 512), (1, 512, 512), False],
],
)
def test_detection_models(arch_name, input_shape, output_size, out_prob):
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob):
"db_resnet34",
"db_resnet50",
"db_mobilenet_v3_large",
"linknet16",
"linknet_resnet18",
],
)
def test_detection_zoo(arch_name):
Expand Down
4 changes: 2 additions & 2 deletions tests/tensorflow/test_models_detection_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
[
["db_resnet50", (512, 512, 3), (512, 512, 1), True],
["db_mobilenet_v3_large", (512, 512, 3), (512, 512, 1), True],
["linknet16", (512, 512, 3), (512, 512, 1), False],
["linknet_resnet18", (512, 512, 3), (512, 512, 1), False],
],
)
def test_detection_models(arch_name, input_shape, output_size, out_prob):
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_rotated_detectionpredictor(mock_pdf): # noqa: F811
[
"db_resnet50",
"db_mobilenet_v3_large",
"linknet16",
"linknet_resnet18",
],
)
def test_detection_zoo(arch_name):
Expand Down

0 comments on commit 808081f

Please sign in to comment.