[DRAFT] [models] add ViTSTR in TF and PT #1048
Modified file:

```diff
@@ -1 +1,2 @@
 from .transformer import *
+from .vision_transformer import *
```
Modified file:

```diff
@@ -11,7 +11,7 @@
 from doctr.utils.repr import NestedObject
 
-__all__ = ["Decoder", "PositionalEncoding"]
+__all__ = ["Decoder", "PositionalEncoding", "MultiHeadAttention", "PositionwiseFeedForward"]
 
 tf.config.run_functions_eagerly(True)
 
@@ -74,14 +74,19 @@ def scaled_dot_product_attention(
 class PositionwiseFeedForward(layers.Layer, NestedObject):
     """Position-wise Feed-Forward Network"""
 
-    def __init__(self, d_model: int, ffd: int, dropout=0.1) -> None:
+    def __init__(self, d_model: int, ffd: int, dropout=0.1, use_gelu: bool = False) -> None:
         super(PositionwiseFeedForward, self).__init__()
+        self.use_gelu = use_gelu
 
         self.first_linear = layers.Dense(ffd, kernel_initializer=tf.initializers.he_uniform())
         self.sec_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform())
         self.dropout = layers.Dropout(rate=dropout)
 
     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
-        x = tf.nn.relu(self.first_linear(x, **kwargs))
+        if self.use_gelu:  # used for ViT
+            x = tf.nn.gelu(self.first_linear(x, **kwargs))
+        else:
+            x = tf.nn.relu(self.first_linear(x, **kwargs))
         x = self.dropout(x, **kwargs)
         x = self.sec_linear(x, **kwargs)
         return x
```

Reviewer comment (on `self.use_gelu = use_gelu`): instantiate a […]
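For context, a quick usage sketch of the modified layer with the new flag; the import path below is an assumption for illustration only, since the diff does not show where this file lives in the package:

```python
import tensorflow as tf

# Assumed import path - not confirmed by the diff.
from doctr.models.modules.transformer import PositionwiseFeedForward

ffn_relu = PositionwiseFeedForward(d_model=768, ffd=3072)                 # default: ReLU activation
ffn_gelu = PositionwiseFeedForward(d_model=768, ffd=3072, use_gelu=True)  # ViT-style: GELU activation

x = tf.random.normal((2, 10, 768))   # (batch, sequence, d_model)
print(ffn_relu(x).shape)             # (2, 10, 768)
print(ffn_gelu(x).shape)             # (2, 10, 768)
```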
New file (`@@ -0,0 +1,6 @@`):

```python
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[misc]
```
New file (`@@ -0,0 +1,100 @@`), the PyTorch implementation:

```python
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Tuple

import torch
from torch import nn

from ..transformer.pytorch import MultiHeadAttention, PositionwiseFeedForward

__all__ = ["VisionTransformer"]


class PatchEmbedding(nn.Module):
    """Compute 2D patch embedding"""

    # Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        channels: int,
        embed_dim: int = 768,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
        assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

        return self.proj(x)  # BCHW


class VisionTransformer(nn.Module):
    """VisionTransformer architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_."""

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        channels: int = 3,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        dropout: float = 0.1,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_layers = num_layers

        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size, channels, d_model)
        self.num_patches = self.patch_embedding.num_patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))  # type: ignore[attr-defined]
        self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model))  # type: ignore[attr-defined]

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-5)
        self.dropout = nn.Dropout(dropout)

        self.attention = nn.ModuleList(
            [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
        )
        self.position_feed_forward = nn.ModuleList(
            [PositionwiseFeedForward(d_model, d_model, dropout, use_gelu=True) for _ in range(self.num_layers)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.patch_embedding(x)

        B, C, H, W = patches.shape
        patches = patches.view(B, C, -1).permute(0, 2, 1)  # (batch_size, num_patches, d_model)
        cls_tokens = self.cls_token.repeat(B, 1, 1)  # (batch_size, 1, d_model)
        # concatenate cls_tokens to patches
        embeddings = torch.cat([cls_tokens, patches], dim=1)  # (batch_size, num_patches + 1, d_model)
        # add positions to embeddings
        embeddings += self.positions  # (batch_size, num_patches + 1, d_model)

        output = embeddings

        for i in range(self.num_layers):
            normed_output = self.layer_norm(output)
            output = output + self.dropout(self.attention[i](normed_output, normed_output, normed_output))
            normed_output = self.layer_norm(output)
            output = output + self.dropout(self.position_feed_forward[i](normed_output))

        # (batch_size, seq_len + cls token, d_model)
        return self.layer_norm(output)
```

Reviewer comments on this file:
- On the `# Inspired by: ...` line: Just FYI: can you confirm that you made a lot of modifications?
- On lines 45-47 (the `VisionTransformer` docstring): let's specify the constructor args
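To make the expected shapes concrete, here is a minimal usage sketch of the PyTorch class above. The import path and the concrete sizes are illustrative assumptions, not values taken from the PR, and the output shape assumes the attention and feed-forward blocks preserve the `(batch, seq, d_model)` shape:

```python
import torch

# Assumed import path - the diff does not show where the module lives in the package.
from doctr.models.modules.vision_transformer.pytorch import VisionTransformer

# A 32x128 input split into 4x8 patches gives (32 // 4) * (128 // 8) = 8 * 16 = 128 patches.
model = VisionTransformer(img_size=(32, 128), patch_size=(4, 8), channels=3, d_model=768)

x = torch.rand(2, 3, 32, 128)  # (batch, channels, H, W)
out = model(x)
print(out.shape)               # torch.Size([2, 129, 768]): 128 patches + 1 cls token
```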
New file (`@@ -0,0 +1,110 @@`), the TensorFlow implementation:

```python
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Tuple

import tensorflow as tf
from tensorflow.keras import layers

from doctr.utils.repr import NestedObject

from ..transformer.tensorflow import MultiHeadAttention, PositionwiseFeedForward

__all__ = ["VisionTransformer"]

tf.config.run_functions_eagerly(True)


class PatchEmbedding(layers.Layer, NestedObject):
    """Compute 2D patch embedding"""

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        embed_dim: int = 768,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.proj = layers.Conv2D(
            embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            kernel_initializer="he_normal",
        )

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        B, W, H, C = x.shape

        assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height"
        assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

        return self.proj(x, **kwargs)  # BHWC


class VisionTransformer(layers.Layer, NestedObject):
    """VisionTransformer architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_."""

    def __init__(
        self,
        img_size: Tuple[int, int],
        patch_size: Tuple[int, int],
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        dropout: float = 0.1,
    ) -> None:

        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_layers = num_layers

        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size, d_model)
        self.num_patches = self.patch_embedding.num_patches
        self.cls_token = self.add_weight(shape=(1, 1, d_model), initializer="zeros", trainable=True, name="cls_token")
        self.positions = self.add_weight(
            shape=(1, self.num_patches + 1, d_model), initializer="zeros", trainable=True, name="positions"
        )

        self.layer_norm = layers.LayerNormalization(epsilon=1e-5)
        self.dropout = layers.Dropout(rate=dropout)

        self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
        self.position_feed_forward = [
            PositionwiseFeedForward(d_model, d_model, dropout, use_gelu=True) for _ in range(self.num_layers)
        ]

    def __call__(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        patches = self.patch_embedding(x)

        B, C, H, W = patches.shape
        patches = tf.reshape(patches, (B, (C * H), W))  # (batch_size, num_patches, d_model)

        cls_tokens = tf.repeat(self.cls_token, B, axis=0)  # (batch_size, 1, d_model)
        # concatenate cls_tokens to patches
        embeddings = tf.concat([cls_tokens, patches], axis=1)  # (batch_size, num_patches + 1, d_model)
        # add positions to embeddings
        embeddings += self.positions  # (batch_size, num_patches + 1, d_model)

        output = embeddings

        for i in range(self.num_layers):
            normed_output = self.layer_norm(output, **kwargs)
            output = output + self.dropout(
                self.attention[i](normed_output, normed_output, normed_output, **kwargs),
                **kwargs,
            )
            normed_output = self.layer_norm(output, **kwargs)
            output = output + self.dropout(self.position_feed_forward[i](normed_output, **kwargs), **kwargs)

        # (batch_size, seq_len + cls token, d_model)
        return self.layer_norm(output, **kwargs)
```

Reviewer comment on lines 52-54 (the `VisionTransformer` docstring): same here (i.e. specify the constructor args, as for the PyTorch version).
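The TensorFlow version works on channels-last tensors; a small sketch of the patch-embedding shapes under the same illustrative sizes (the import path is assumed, not taken from the PR, and `PatchEmbedding` is pulled from the submodule directly since it is not re-exported through `__all__`):

```python
import tensorflow as tf

# Assumed import path - not confirmed by the diff.
from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding

embedder = PatchEmbedding(img_size=(32, 128), patch_size=(4, 8), embed_dim=768)

x = tf.random.normal((2, 32, 128, 3))  # (batch, H, W, C), channels-last
patches = embedder(x)
print(patches.shape)                   # (2, 8, 16, 768): an 8x16 grid of 768-dim patch embeddings
# VisionTransformer then reshapes this to (batch, 128, 768), prepends the cls token,
# and feeds the resulting (batch, 129, 768) sequence through the encoder blocks.
```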
Modified file:

```diff
@@ -1,4 +1,5 @@
 from .crnn import *
 from .master import *
 from .sar import *
+from .vitstr import *
 from .zoo import *
```
New file (`@@ -0,0 +1,6 @@`):

```python
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[misc]
```
Reviewer comment (on the `use_gelu` flag): a boolean value for activation selection is rather limited: are we positive that only ReLU & GELU can be used for such architecture types?
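One way this concern could be addressed (a sketch only, not part of the PR): pass the activation as a Keras-style identifier or callable instead of a boolean, so the layer is not tied to ReLU/GELU. The class below is a hypothetical variant of the PR's `PositionwiseFeedForward`, simplified (it drops `NestedObject`) to stay self-contained:

```python
from typing import Any, Callable, Union

import tensorflow as tf
from tensorflow.keras import layers


class PositionwiseFeedForward(layers.Layer):
    """Position-wise FFN with a configurable activation (hypothetical variant)."""

    def __init__(
        self,
        d_model: int,
        ffd: int,
        dropout: float = 0.1,
        activation: Union[str, Callable[[tf.Tensor], tf.Tensor]] = "relu",
    ) -> None:
        super().__init__()
        # Resolves strings like "relu" or "gelu" to callables, and passes callables through.
        self.activation = tf.keras.activations.get(activation)
        self.first_linear = layers.Dense(ffd, kernel_initializer=tf.initializers.he_uniform())
        self.sec_linear = layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform())
        self.dropout = layers.Dropout(rate=dropout)

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        x = self.activation(self.first_linear(x, **kwargs))
        x = self.dropout(x, **kwargs)
        x = self.sec_linear(x, **kwargs)
        return x


# ViT blocks would then request GELU explicitly, and other activations remain possible:
# PositionwiseFeedForward(d_model, d_model, dropout, activation="gelu")
```

This keeps the current default (ReLU) unchanged while leaving the choice of activation open.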