[DRAFT] [models] add ViTSTR in TF and PT #1048

Closed · wants to merge 17 commits
1 change: 1 addition & 0 deletions README.md
@@ -164,6 +164,7 @@ Credits where it's due: this repository is implementing, among others, architect
- CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf).
- SAR: [Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf).
- MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf).
- ViTSTR: [Vision Transformer for Fast and Efficient Scene Text Recognition](https://arxiv.org/pdf/2105.08582.pdf).


## More goodies
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -48,6 +48,7 @@ Text recognition models
* SAR from `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition" <https://arxiv.org/pdf/1811.00751.pdf>`_
* CRNN from `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_
* MASTER from `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition" <https://arxiv.org/pdf/1910.02562.pdf>`_
* ViTSTR from `"Vision Transformer for Fast and Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_


Supported datasets
2 changes: 2 additions & 0 deletions docs/source/modules/models.rst
@@ -65,6 +65,8 @@ doctr.models.recognition

.. autofunction:: doctr.models.recognition.master

.. autofunction:: doctr.models.recognition.vitstr

.. autofunction:: doctr.models.recognition.recognition_predictor


1 change: 1 addition & 0 deletions doctr/models/modules/__init__.py
@@ -1 +1,2 @@
from .transformer import *
from .vision_transformer import *
6 changes: 3 additions & 3 deletions doctr/models/modules/transformer/pytorch.py
@@ -11,7 +11,7 @@
import torch
from torch import nn

__all__ = ["Decoder", "PositionalEncoding"]
__all__ = ["Decoder", "PositionalEncoding", "MultiHeadAttention", "PositionwiseFeedForward"]


class PositionalEncoding(nn.Module):
@@ -57,10 +57,10 @@ def scaled_dot_product_attention(
class PositionwiseFeedForward(nn.Sequential):
    """Position-wise Feed-Forward Network"""

    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1) -> None:
    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1, use_gelu: bool = False) -> None:
Review comment (Collaborator): a boolean flag for activation selection is rather limiting: are we positive that only ReLU and GELU will ever be used for this kind of architecture? (A sketch of a more flexible signature follows this hunk.)

        super().__init__(
            nn.Linear(d_model, ffd),
            nn.ReLU(),
            nn.ReLU() if not use_gelu else nn.GELU(),  # GELU for ViT
            nn.Dropout(p=dropout),
            nn.Linear(ffd, d_model),
        )
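A minimal sketch of what the reviewer's more flexible signature could look like on the PyTorch side, passing the activation module itself rather than a boolean; the activation_fct parameter name is illustrative only and not something this PR defines:

from torch import nn


class PositionwiseFeedForward(nn.Sequential):
    """Position-wise Feed-Forward Network (sketch with a configurable activation)"""

    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1, activation_fct: nn.Module = nn.ReLU()) -> None:
        super().__init__(
            nn.Linear(d_model, ffd),
            activation_fct,  # e.g. nn.ReLU(), or nn.GELU() for the ViT blocks
            nn.Dropout(p=dropout),
            nn.Linear(ffd, d_model),
        )

With such a signature, the ViT modules would simply pass activation_fct=nn.GELU() instead of use_gelu=True.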
11 changes: 8 additions & 3 deletions doctr/models/modules/transformer/tensorflow.py
@@ -11,7 +11,7 @@

from doctr.utils.repr import NestedObject

__all__ = ["Decoder", "PositionalEncoding"]
__all__ = ["Decoder", "PositionalEncoding", "MultiHeadAttention", "PositionwiseFeedForward"]

tf.config.run_functions_eagerly(True)

@@ -74,14 +74,19 @@ def scaled_dot_product_attention(
class PositionwiseFeedForward(layers.Layer, NestedObject):
    """Position-wise Feed-Forward Network"""

    def __init__(self, d_model: int, ffd: int, dropout=0.1) -> None:
    def __init__(self, d_model: int, ffd: int, dropout=0.1, use_gelu: bool = False) -> None:
        super(PositionwiseFeedForward, self).__init__()
        self.use_gelu = use_gelu
Review comment (Collaborator): instantiate a self.activation_fn in the constructor to avoid the conditional execution in call 👍 (A sketch of this follows the hunk.)


        self.first_linear = layers.Dense(ffd, kernel_initializer=tf.initializers.he_uniform())
        self.sec_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform())
        self.dropout = layers.Dropout(rate=dropout)

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        x = tf.nn.relu(self.first_linear(x, **kwargs))
        if self.use_gelu:  # used for ViT
            x = tf.nn.gelu(self.first_linear(x, **kwargs))
        else:
            x = tf.nn.relu(self.first_linear(x, **kwargs))
        x = self.dropout(x, **kwargs)
        x = self.sec_linear(x, **kwargs)
        return x
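A minimal sketch of the reviewer's suggestion for the TensorFlow side, resolving the activation once in the constructor so that call stays branch-free; the NestedObject mixin is dropped here only to keep the sketch self-contained, and this is an illustration of the idea rather than code from this PR:

import tensorflow as tf
from tensorflow.keras import layers


class PositionwiseFeedForward(layers.Layer):
    """Position-wise Feed-Forward Network (activation chosen at construction time)"""

    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1, use_gelu: bool = False) -> None:
        super().__init__()
        # pick the activation function once instead of branching on every forward pass
        self.activation_fn = tf.nn.gelu if use_gelu else tf.nn.relu
        self.first_linear = layers.Dense(ffd, kernel_initializer=tf.initializers.he_uniform())
        self.sec_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform())
        self.dropout = layers.Dropout(rate=dropout)

    def call(self, x: tf.Tensor, **kwargs) -> tf.Tensor:
        x = self.activation_fn(self.first_linear(x))
        x = self.dropout(x, **kwargs)
        x = self.sec_linear(x)
        return x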
6 changes: 6 additions & 0 deletions doctr/models/modules/vision_transformer/__init__.py
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[misc]
100 changes: 100 additions & 0 deletions doctr/models/modules/vision_transformer/pytorch.py
@@ -0,0 +1,100 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Tuple

import torch
from torch import nn

from ..transformer.pytorch import MultiHeadAttention, PositionwiseFeedForward

__all__ = ["VisionTransformer"]


class PatchEmbedding(nn.Module):
    """Compute 2D patch embedding"""

    # Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
Review comment (Collaborator): just FYI, can you confirm that you made a lot of modifications? "Inspired by" is rather light; "borrowed from" is more significant.

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        channels: int,
        embed_dim: int = 768,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
        assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

        return self.proj(x)  # BCHW


class VisionTransformer(nn.Module):
"""VisionTransformer architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_."""
Review comment (Collaborator, on lines +45 to +47): let's specify the constructor args. (A sketch of this follows the file.)

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        channels: int = 3,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        dropout: float = 0.1,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_layers = num_layers

        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size, channels, d_model)
        self.num_patches = self.patch_embedding.num_patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))  # type: ignore[attr-defined]
        self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model))  # type: ignore[attr-defined]

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-5)
        self.dropout = nn.Dropout(dropout)

        self.attention = nn.ModuleList(
            [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
        )
        self.position_feed_forward = nn.ModuleList(
            [PositionwiseFeedForward(d_model, d_model, dropout, use_gelu=True) for _ in range(self.num_layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.patch_embedding(x)

        B, C, H, W = patches.shape
        patches = patches.view(B, C, -1).permute(0, 2, 1)  # (batch_size, num_patches, d_model)
        cls_tokens = self.cls_token.repeat(B, 1, 1)  # (batch_size, 1, d_model)
        # concatenate cls_tokens to patches
        embeddings = torch.cat([cls_tokens, patches], dim=1)  # (batch_size, num_patches + 1, d_model)
        # add positions to embeddings
        embeddings += self.positions  # (batch_size, num_patches + 1, d_model)

        output = embeddings

        for i in range(self.num_layers):
            normed_output = self.layer_norm(output)
            output = output + self.dropout(self.attention[i](normed_output, normed_output, normed_output))
            normed_output = self.layer_norm(output)
            output = output + self.dropout(self.position_feed_forward[i](normed_output))

        # (batch_size, seq_len + cls token, d_model)
        return self.layer_norm(output)
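Responding to the review comment above ("let's specify the constructor args"), one possible way to document the constructor in the class docstring could be the following; the wording is a suggestion only, not part of the PR:

from torch import nn


class VisionTransformer(nn.Module):
    """VisionTransformer architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.

    Args:
        img_size: size of the input image as (height, width)
        patch_size: size of each patch as (height, width)
        channels: number of input channels
        d_model: dimension of the patch/token embeddings
        num_layers: number of transformer encoder blocks
        num_heads: number of attention heads per block
        dropout: dropout rate applied after the attention and feed-forward sub-layers
    """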
110 changes: 110 additions & 0 deletions doctr/models/modules/vision_transformer/tensorflow.py
@@ -0,0 +1,110 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Tuple

import tensorflow as tf
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

from ..transformer.tensorflow import MultiHeadAttention, PositionwiseFeedForward

__all__ = ["VisionTransformer"]

tf.config.run_functions_eagerly(True)


class PatchEmbedding(layers.Layer, NestedObject):
"""Compute 2D patch embedding"""

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        embed_dim: int = 768,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.proj = layers.Conv2D(
            embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            kernel_initializer="he_normal",
        )

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        B, W, H, C = x.shape

        assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
        assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

        return self.proj(x, **kwargs)  # BHWC


class VisionTransformer(layers.Layer, NestedObject):
"""VisionTransformer architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_."""
Review comment (Collaborator, on lines +52 to +54): same here.

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        dropout: float = 0.1,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_layers = num_layers

        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size, d_model)
        self.num_patches = self.patch_embedding.num_patches
        self.cls_token = self.add_weight(shape=(1, 1, d_model), initializer="zeros", trainable=True, name="cls_token")
        self.positions = self.add_weight(
            shape=(1, self.num_patches + 1, d_model), initializer="zeros", trainable=True, name="positions"
        )

        self.layer_norm = layers.LayerNormalization(epsilon=1e-5)
        self.dropout = layers.Dropout(rate=dropout)

        self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
        self.position_feed_forward = [
            PositionwiseFeedForward(d_model, d_model, dropout, use_gelu=True) for _ in range(self.num_layers)
        ]

    def __call__(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        patches = self.patch_embedding(x)

        B, C, H, W = patches.shape
        patches = tf.reshape(patches, (B, (C * H), W))  # (batch_size, num_patches, d_model)

        cls_tokens = tf.repeat(self.cls_token, B, axis=0)  # (batch_size, 1, d_model)
        # concatenate cls_tokens to patches
        embeddings = tf.concat([cls_tokens, patches], axis=1)  # (batch_size, num_patches + 1, d_model)
        # add positions to embeddings
        embeddings += self.positions  # (batch_size, num_patches + 1, d_model)

        output = embeddings

        for i in range(self.num_layers):
            normed_output = self.layer_norm(output, **kwargs)
            output = output + self.dropout(
                self.attention[i](normed_output, normed_output, normed_output, **kwargs),
                **kwargs,
            )
            normed_output = self.layer_norm(output, **kwargs)
            output = output + self.dropout(self.position_feed_forward[i](normed_output, **kwargs), **kwargs)

        # (batch_size, seq_len + cls token, d_model)
        return self.layer_norm(output, **kwargs)
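To make the tensor bookkeeping in the two forward passes concrete, here is a small self-contained TensorFlow shape walk-through of the patchify-plus-class-token step; the sizes (a 32x128 RGB input with 4x8 patches) are illustrative and not fixed by this PR:

import tensorflow as tf
from tensorflow.keras import layers

batch_size, height, width, channels, d_model = 2, 32, 128, 3, 768
patch_h, patch_w = 4, 8

images = tf.zeros((batch_size, height, width, channels))

# non-overlapping patches via a strided convolution: (2, 32, 128, 3) -> (2, 8, 16, 768)
proj = layers.Conv2D(d_model, kernel_size=(patch_h, patch_w), strides=(patch_h, patch_w), padding="valid")
feat = proj(images)

# flatten the spatial grid into a token sequence: (2, 8, 16, 768) -> (2, 128, 768)
num_patches = (height // patch_h) * (width // patch_w)
patches = tf.reshape(feat, (batch_size, num_patches, d_model))

# prepend a (here zero-initialized) class token slot: (2, 128, 768) -> (2, 129, 768)
cls_token = tf.zeros((1, 1, d_model))
cls_tokens = tf.repeat(cls_token, batch_size, axis=0)
embeddings = tf.concat([cls_tokens, patches], axis=1)

print(feat.shape, patches.shape, embeddings.shape)  # (2, 8, 16, 768) (2, 128, 768) (2, 129, 768)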
1 change: 1 addition & 0 deletions doctr/models/recognition/__init__.py
@@ -1,4 +1,5 @@
from .crnn import *
from .master import *
from .sar import *
from .vitstr import *
from .zoo import *
6 changes: 6 additions & 0 deletions doctr/models/recognition/vitstr/__init__.py
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[misc]