[models] add ViTSTR TF and PT and update ViT to work as backbone #1055

Merged: 19 commits, Sep 21, 2022

Changes from all commits
1 change: 1 addition & 0 deletions README.md
@@ -164,6 +164,7 @@ Credits where it's due: this repository is implementing, among others, architect
- CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf).
- SAR: [Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf).
- MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf).
- ViTSTR: [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/pdf/2105.08582.pdf).


## More goodies
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -48,6 +48,7 @@ Text recognition models
* SAR from `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_
* CRNN from `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_
* MASTER from `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition" <https://arxiv.org/pdf/1910.02562.pdf>`_
* ViTSTR from `"Vision Transformer for Fast and Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_


Supported datasets
6 changes: 6 additions & 0 deletions docs/source/modules/models.rst
@@ -29,6 +29,8 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.magc_resnet31

.. autofunction:: doctr.models.classification.vit_s

.. autofunction:: doctr.models.classification.vit_b

.. autofunction:: doctr.models.classification.crop_orientation_predictor
@@ -67,6 +69,10 @@ doctr.models.recognition

.. autofunction:: doctr.models.recognition.master

.. autofunction:: doctr.models.recognition.vitstr_small

.. autofunction:: doctr.models.recognition.vitstr_base

.. autofunction:: doctr.models.recognition.recognition_predictor


65 changes: 54 additions & 11 deletions doctr/models/classification/vit/pytorch.py
@@ -15,11 +15,18 @@

from ...utils.pytorch import load_pretrained_params

__all__ = ["vit_b"]
__all__ = ["vit_s", "vit_b"]


default_cfgs: Dict[str, Dict[str, Any]] = {
"vit": {
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
},
"vit_s": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
@@ -57,25 +64,25 @@ class VisionTransformer(nn.Sequential):
<https://arxiv.org/pdf/2010.11929.pdf>`_.

Args:
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
"""

def __init__(
self,
d_model: int,
num_layers: int,
num_heads: int,
ffd_ratio: int,
input_shape: Tuple[int, int, int] = (3, 32, 32),
patch_size: Tuple[int, int] = (4, 4),
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
@@ -128,8 +135,40 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
<https://arxiv.org/pdf/2010.11929.pdf>`_.

>>> import torch
>>> from doctr.models import vit
>>> model = vit(pretrained=False)
>>> from doctr.models import vit_b
>>> model = vit_b(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)

Args:
pretrained: boolean, True if model is pretrained

Returns:
A feature extractor model
"""

return _vit(
"vit_b",
pretrained,
d_model=768,
num_layers=12,
num_heads=12,
ffd_ratio=4,
ignore_keys=["head.weight", "head.bias"],
**kwargs,
)


def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-S architecture
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

NOTE: unofficial config used in ViTSTR and ParSeq

>>> import torch
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
>>> out = model(input_tensor)

@@ -141,8 +180,12 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""

return _vit(
"vit",
"vit_s",
pretrained,
d_model=384,
num_layers=12,
num_heads=6,
ffd_ratio=4,
ignore_keys=["head.weight", "head.bias"],
**kwargs,
)
64 changes: 53 additions & 11 deletions doctr/models/classification/vit/tensorflow.py
@@ -17,11 +17,18 @@

from ...utils import load_pretrained_params

__all__ = ["vit_b"]
__all__ = ["vit_s", "vit_b"]


default_cfgs: Dict[str, Dict[str, Any]] = {
"vit": {
"vit_s": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
},
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
@@ -54,25 +61,25 @@ class VisionTransformer(Sequential):
<https://arxiv.org/pdf/2010.11929.pdf>`_.

Args:
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
ffd_ratio: multiplier for the hidden dimension of the feedforward layer
input_shape: size of the input image
patch_size: size of the patches to be extracted from the input
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
"""

def __init__(
self,
d_model: int,
num_layers: int,
num_heads: int,
ffd_ratio: int,
input_shape: Tuple[int, int, int] = (32, 32, 3),
patch_size: Tuple[int, int] = (4, 4),
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
@@ -115,14 +122,45 @@ def _vit(
return model


def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-S architecture
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

NOTE: unofficial config used in ViTSTR and ParSeq

>>> import tensorflow as tf
>>> from doctr.models import vit_s
>>> model = vit_s(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)

Args:
pretrained: boolean, True if model is pretrained

Returns:
A feature extractor model
"""

return _vit(
"vit_s",
pretrained,
d_model=384,
num_layers=12,
num_heads=6,
ffd_ratio=4,
**kwargs,
)


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

>>> import tensorflow as tf
>>> from doctr.models import vit
>>> model = vit(pretrained=False)
>>> from doctr.models import vit_b
>>> model = vit_b(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)

@@ -134,7 +172,11 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""

return _vit(
"vit",
"vit_b",
pretrained,
d_model=768,
num_layers=12,
num_heads=12,
ffd_ratio=4,
**kwargs,
)
1 change: 1 addition & 0 deletions doctr/models/classification/zoo.py
@@ -25,6 +25,7 @@
"resnet50",
"resnet34_wide",
"vgg16_bn_r",
"vit_s",
"vit_b",
]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
44 changes: 42 additions & 2 deletions doctr/models/modules/vision_transformer/pytorch.py
@@ -4,6 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


import math
from typing import Tuple

import torch
@@ -24,14 +25,53 @@ def __init__(

super().__init__()
channels, height, width = input_shape
self.patch_size = patch_size
# fix patch size if recognition task with 32x128 input
self.patch_size = (4, 8) if height != width else patch_size
Review thread on lines +28 to +29 (marked as resolved by felixdittrich92):

Collaborator:
Tricky condition: what are all the possible cases as input, and what do we want as patch_size for each?

@felixdittrich92 (Contributor Author, Sep 21, 2022):
It could be anything. Currently:
- classification case: 32x32 -> (4, 4) (handled)
- recognition case: 32x128 -> (4, 8) (handled)
- detection case: 1024x1024 (not handled)
- any other size (not handled)
It will not fail, but each size needs a different patch_size.

Collaborator:
OK, two questions then:
- How should the scale impact the patch size? (N, N) -> (H, W) implies that (2N, 2N) -> (?, ?)
- How should the aspect ratio impact the patch size? I see that (32, 32) -> (4, 4), but why doesn't (32, 128) give (4, 16)?

@felixdittrich92 (Contributor Author):
(4, 16) would also work, but I used the values from ParSeq for the PatchEmbedding of 32x128 samples.

@felixdittrich92 (Contributor Author):
Which is (4, 8).

@felixdittrich92 (Contributor Author, Sep 21, 2022):
https://github.com/baudm/parseq/blob/main/configs/model/parseq.yaml
https://github.com/baudm/parseq/blob/main/configs/experiment/vitstr.yaml
If we used a fixed ratio it would be easy to scale, but yes, I took the values from the ParSeq paper / implementation.
self.grid_size = (height // patch_size[0], width // patch_size[1])
self.num_patches = (height // patch_size[0]) * (width // patch_size[1])

self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) # type: ignore[attr-defined]
self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) # type: ignore[attr-defined]
self.proj = nn.Linear((channels * self.patch_size[0] * self.patch_size[1]), embed_dim)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
100 % borrowed from:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py

This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.positions.shape[1] - 1
if num_patches == num_positions and height == width:
return self.positions
class_pos_embed = self.positions[:, 0]
patch_pos_embed = self.positions[:, 1:]
dim = embeddings.shape[-1]
h0 = float(height // self.patch_size[0])
w0 = float(width // self.patch_size[1])
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
mode="bicubic",
align_corners=False,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
)
assert int(h0) == patch_pos_embed.shape[-2], "height of interpolated patch embedding doesn't match"
assert int(w0) == patch_pos_embed.shape[-1], "width of interpolated patch embedding doesn't match"

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
@@ -53,6 +93,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# concatenate cls_tokens with the patch embeddings
embeddings = torch.cat([cls_tokens, patches], dim=1) # (batch_size, num_patches + 1, d_model)
# add positions to embeddings
embeddings += self.positions # (batch_size, num_patches + 1, d_model)
embeddings += self.interpolate_pos_encoding(embeddings, H, W) # (batch_size, num_patches + 1, d_model)

return embeddings
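
For readers unfamiliar with the resizing trick applied in interpolate_pos_encoding above, here is a self-contained sketch of the same idea: it resizes an 8x8 grid of pretrained position embeddings (a 32x32 input with 4x4 patches) to a 16x16 grid (a 64x64 input). It uses a fixed target size instead of the scale_factor computation from the diff, and the names are illustrative rather than doctr APIs.

import math

import torch
import torch.nn as nn

# Standalone sketch (not doctr code): bicubic resizing of pretrained ViT
# position embeddings so the backbone can accept a larger input resolution.
d_model = 384
num_positions = 8 * 8                                    # pretrained 8x8 patch grid
positions = torch.randn(1, num_positions + 1, d_model)   # +1 for the cls token

new_h, new_w = 16, 16                # target grid for a 64x64 input with 4x4 patches
side = int(math.sqrt(num_positions))

cls_pos = positions[:, :1]           # keep the cls position unchanged
patch_pos = positions[:, 1:].reshape(1, side, side, d_model).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos, size=(new_h, new_w), mode="bicubic", align_corners=False
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_h * new_w, d_model)

resized = torch.cat([cls_pos, patch_pos], dim=1)
print(resized.shape)  # torch.Size([1, 257, 384])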