update with latest changes
felixdittrich92 committed Sep 15, 2022
1 parent e468b22 commit 16a5b0f
Showing 3 changed files with 94 additions and 31 deletions.
57 changes: 57 additions & 0 deletions doctr/models/recognition/vitstr/base.py
@@ -0,0 +1,57 @@
# Copyright (C) 2022, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Tuple

import numpy as np

from ....datasets import encode_sequences
from ..core import RecognitionPostProcessor


class _ViTSTR:

vocab: str
max_length: int

def build_target(
self,
gts: List[str],
) -> Tuple[np.ndarray, List[int]]:
"""Encode a list of gts sequences into a np array and gives the corresponding*
sequence lengths.
Args:
gts: list of ground-truth labels
Returns:
A tuple of 2 elements: the encoded labels (np array) and the sequence lengths (for each entry of the batch)
"""
encoded = encode_sequences(
sequences=gts,
vocab=self.vocab,
target_size=self.max_length,
eos=len(self.vocab),
sos=len(self.vocab) + 1,
pad=len(self.vocab) + 2,
)
seq_len = [len(word) for word in gts]
return encoded, seq_len


class _ViTSTRPostProcessor(RecognitionPostProcessor):
"""Abstract class to postprocess the raw output of the model
Args:
vocab: string containing the ordered sequence of supported characters
"""

def __init__(
self,
vocab: str,
) -> None:

super().__init__(vocab)
self._embedding = list(vocab) + ["<eos>"] + ["<sos>"] + ["<pad>"]
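
For context, the three special tokens added above sit directly after the vocab: index len(vocab) is <eos>, len(vocab) + 1 is <sos>, and len(vocab) + 2 is <pad>. A minimal sketch of the resulting target layout, assuming a toy vocab "abc" and a hypothetical encode helper (illustrative only, not doctr's actual encode_sequences):

# Minimal sketch (not doctr's implementation) of the token layout built above,
# assuming a toy vocab "abc" and a target size of 6.
vocab = "abc"
eos, sos, pad = len(vocab), len(vocab) + 1, len(vocab) + 2  # 3, 4, 5

def encode(word: str, target_size: int) -> list:
    # <sos>, the character indices, <eos>, then <pad> up to target_size
    seq = [sos] + [vocab.index(char) for char in word] + [eos]
    return seq + [pad] * (target_size - len(seq))

print(encode("ab", 6))  # [4, 0, 1, 3, 5, 5]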
36 changes: 19 additions & 17 deletions doctr/models/recognition/vitstr/pytorch.py
@@ -11,10 +11,10 @@
from torch.nn import functional as F

from doctr.datasets import VOCABS
-from doctr.models.classification import vit
+from doctr.models.classification import vit_b

from ...utils.pytorch import load_pretrained_params
-from ..core import RecognitionModel, RecognitionPostProcessor
+from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr"]

@@ -29,7 +29,7 @@
}


-class ViTSTR(nn.Module, RecognitionModel):
+class ViTSTR(_ViTSTR, nn.Module):
"""Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and
Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
@@ -63,10 +63,10 @@ def __init__(
self.exportable = exportable
self.cfg = cfg

-self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word
+self.max_length = max_length + 3  # Add 1 timestep for EOS, 1 for SOS, 1 for PAD

self.feat_extractor = feature_extractor
-self.head = nn.Linear(embedding_units, len(self.vocab) + 1)
+self.head = nn.Linear(embedding_units, len(self.vocab) + 3)

self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)

@@ -93,7 +93,7 @@ def forward(
# batch, seqlen, embedding_size
B, N, E = features.size()
features = features.reshape(B * N, E)
-logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch, seqlen, vocab + 1)
+logits = self.head(features).view(B, N, len(self.vocab) + 3)  # (batch, seqlen, vocab + 3)
decoded_features = logits[:, 1:] # remove cls_token

out: Dict[str, Any] = {}
@@ -121,29 +121,32 @@ def compute_loss(
) -> torch.Tensor:
"""Compute categorical cross-entropy loss for the model.
Sequences are masked after the EOS character.
Args:
-model_output: predicted logits of the model
gt: the encoded tensor with gt labels
+model_output: predicted logits of the model
seq_len: lengths of each gt word inside the batch
Returns:
The loss of the model on the batch
"""
# Input length: number of timesteps
input_len = model_output.shape[1]
-# Add one for additional <eos> token
+# Add one for additional <eos> token (the <sos> disappears in the shift!)
seq_len = seq_len + 1
-# Compute loss
-# (N, L, vocab_size + 1)
-cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
-mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
+# Compute loss: don't forget to shift gt! Otherwise the model learns to output gt[t - 1]!
+# The "masked" first gt char is <sos>. Drop the last logit of the model output.
+cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
+# Compute the mask, removing 1 timestep here as well
+mask_2d = torch.arange(input_len - 1, device=model_output.device)[None, :] >= seq_len[:, None]
cce[mask_2d] = 0

ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
return ce_loss.mean()


-class ViTSTRPostProcessor(RecognitionPostProcessor):
-"""Post processor for ViTSTR architectures
+class ViTSTRPostProcessor(_ViTSTRPostProcessor):
+"""Post processor for ViTSTR architecture
Args:
vocab: string containing the ordered sequence of supported characters
@@ -163,7 +166,7 @@ def __call__(
# Manual decoding
word_values = [
"".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
-for encoded_seq in out_idxs.detach().cpu().numpy()
+for encoded_seq in out_idxs.cpu().numpy()
]

return list(zip(word_values, probs.numpy().tolist()))
@@ -188,7 +191,7 @@ def _vitstr(
kwargs["input_shape"] = _cfg["input_shape"]

# Feature extractor
-feat_extractor = vit(
+feat_extractor = vit_b(  # type: ignore[operator]
pretrained=pretrained_backbone,
input_shape=_cfg["input_shape"],
patch_size=(4, 8),
@@ -198,7 +201,6 @@
dropout=0.0,
include_top=False,
)
-# TODO: update also Tensorflow all !!!

# Build the model
model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
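
The shift in compute_loss is easiest to see on toy tensors: logit timestep t is scored against gt[t + 1], so the leading <sos> never becomes a target and the last logit is dropped. A minimal PyTorch sketch with dummy values (the names and shapes here are assumptions for illustration, not doctr's API):

import torch
import torch.nn.functional as F

# Toy check of the shifted loss, assuming the toy encoding above:
# gt = [<sos>, 'a', 'b', <eos>, <pad>, <pad>] with 3 vocab chars + 3 specials.
gt = torch.tensor([[4, 0, 1, 3, 5, 5]])
seq_len = torch.tensor([2]) + 1    # word length + 1 for <eos>
logits = torch.randn(1, 6, 6)      # (batch, timesteps, vocab + 3)

# Timestep t is scored against gt[t + 1]: drop the last logit and the
# first gt token, so the model cannot simply echo gt[t - 1].
cce = F.cross_entropy(logits[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
mask = torch.arange(logits.shape[1] - 1)[None, :] >= seq_len[:, None]
cce[mask] = 0                      # zero out everything after <eos>
loss = (cce.sum(1) / seq_len.to(cce.dtype)).mean()
print(cce.shape, loss.item())      # torch.Size([1, 5]) and a scalar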
32 changes: 18 additions & 14 deletions doctr/models/recognition/vitstr/tensorflow.py
@@ -10,10 +10,10 @@
from tensorflow.keras import Model, layers

from doctr.datasets import VOCABS
-from doctr.models.classification import vit
+from doctr.models.classification import vit_b

from ...utils.tensorflow import load_pretrained_params
-from ..core import RecognitionModel, RecognitionPostProcessor
+from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr"]

@@ -28,7 +28,7 @@
}


-class ViTSTR(Model, RecognitionModel):
+class ViTSTR(_ViTSTR, Model):
"""Implements a ViTSTR architecture as described in `"Vision Transformer for Fast and
Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
@@ -63,41 +63,45 @@ def __init__(
self.vocab = vocab
self.exportable = exportable
self.cfg = cfg
-self.max_length = max_length + 1  # Add 1 timestep for EOS after the longest word
+self.max_length = max_length + 3  # Add 1 timestep for EOS, 1 for SOS, 1 for PAD

self.feat_extractor = feature_extractor
-self.head = layers.Dense(len(self.vocab) + 1)
+self.head = layers.Dense(len(self.vocab) + 3)

self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)

@staticmethod
def compute_loss(
model_output: tf.Tensor,
gt: tf.Tensor,
-seq_len: tf.Tensor,
+seq_len: List[int],
) -> tf.Tensor:
"""Compute categorical cross-entropy loss for the model.
Sequences are masked after the EOS character.
Args:
gt: the encoded tensor with gt labels
model_output: predicted logits of the model
seq_len: lengths of each gt word inside the batch
Returns:
The loss of the model on the batch
"""
# Input length: number of timesteps
input_len = tf.shape(model_output)[1]
-# Add one for additional <eos> token
-seq_len = seq_len + 1
+# Add one for additional <eos> token (the <sos> disappears in the shift!)
+seq_len = tf.cast(seq_len, tf.int32) + 1
# One-hot gt labels
oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
-# Compute loss
-cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output)
+# Compute loss: don't forget to shift gt! Otherwise the model learns to output gt[t - 1]!
+# The "masked" first gt char is <sos>. Drop the last logit of the model output.
+cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output[:, :-1, :])
# Compute mask
mask_values = tf.zeros_like(cce)
-mask_2d = tf.sequence_mask(seq_len, input_len)
+mask_2d = tf.sequence_mask(seq_len, input_len - 1)  # delete the last mask timestep as well
masked_loss = tf.where(mask_2d, cce, mask_values)
ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))

return tf.expand_dims(ce_loss, axis=1)

def call(
@@ -122,7 +126,7 @@ def call(
# batch, seqlen, embedding_size
B, N, E = features.shape
features = tf.reshape(features, (B * N, E))
-logits = tf.reshape(self.head(features), (B, N, len(self.vocab) + 1))  # (batch, seqlen, vocab + 1)
+logits = tf.reshape(self.head(features), (B, N, len(self.vocab) + 3))  # (batch, seqlen, vocab + 3)
decoded_features = logits[:, 1:] # remove cls_token

out: Dict[str, tf.Tensor] = {}
Expand All @@ -143,7 +147,7 @@ def call(
return out


-class ViTSTRPostProcessor(RecognitionPostProcessor):
+class ViTSTRPostProcessor(_ViTSTRPostProcessor):
"""Post processor for ViTSTR architecture
Args:
@@ -190,7 +194,7 @@ def _vitstr(
kwargs["vocab"] = _cfg["vocab"]

# Feature extractor
-feat_extractor = vit(
+feat_extractor = vit_b(
pretrained=pretrained_backbone,
input_shape=_cfg["input_shape"],
patch_size=(4, 8),
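
At inference time, both post-processors decode greedily: argmax per timestep, map indices through _embedding, and cut the string at the first <eos>. A minimal NumPy sketch of that decode, reusing the toy vocab from the first example (illustrative only, not doctr's API):

import numpy as np

# Minimal sketch of the greedy decode, assuming the toy vocab "abc":
# indices 0-2 are characters, 3-5 are <eos>, <sos>, <pad>.
vocab = "abc"
embedding = list(vocab) + ["<eos>", "<sos>", "<pad>"]

def decode(logits: np.ndarray) -> str:
    # logits: (timesteps, vocab + 3); argmax per step, cut at the first <eos>
    idxs = logits.argmax(-1)
    return "".join(embedding[i] for i in idxs).split("<eos>")[0]

logits = np.zeros((4, 6))
for t, i in enumerate([0, 1, 3, 5]):  # 'a', 'b', <eos>, <pad>
    logits[t, i] = 1.0
print(decode(logits))  # "ab"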
