
[models] add ViTSTR TF and PT and update ViT to work as backbone #1055

Merged
19 commits merged into mindee:main on Sep 21, 2022

Conversation

@felixdittrich92 (Contributor) commented Sep 9, 2022

This PR:

  • adds ViTSTR in TF and PT
  • updates the ViT config
  • adds positional embedding interpolation (pos_interpolation) to ViT so it also works as a backbone on higher input resolutions (case: recognition); see the sketch below
  • refactors the TF ViT patchify step to reduce memory consumption (it is now similar to the PT implementation: vit_s uses 8.9 GB in TF vs. 7.5 GB in PT; this difference is normal, I faced the same percentage difference between TF and PT for SAR and MASTER)
  • applies some fixes
  • model configs added in this PR, in both frameworks: vit_s, vit_b (updated), vitstr_small, vitstr_base

Any feedback is very welcome 🤗
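
Since the positional embedding interpolation is what lets the ViT run as a backbone at other input resolutions, here is a minimal sketch of the usual trick (illustrative PyTorch, not the code from this PR; the function name, shapes and example sizes are assumptions):

import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed: torch.Tensor, old_grid: tuple, new_grid: tuple) -> torch.Tensor:
    """Resize learned patch position embeddings to a new patch grid.

    pos_embed: (1, 1 + H*W, D) -- class token embedding first, then patch tokens.
    old_grid / new_grid: (H, W) patch grids before / after resizing.
    """
    cls_embed, patch_embed = pos_embed[:, :1], pos_embed[:, 1:]
    d = patch_embed.shape[-1]
    # (1, H*W, D) -> (1, D, H, W) so we can use spatial interpolation
    patch_embed = patch_embed.reshape(1, *old_grid, d).permute(0, 3, 1, 2)
    patch_embed = F.interpolate(patch_embed, size=new_grid, mode="bilinear", align_corners=False)
    # back to (1, H'*W', D) and re-attach the class token
    patch_embed = patch_embed.permute(0, 2, 3, 1).reshape(1, -1, d)
    return torch.cat([cls_embed, patch_embed], dim=1)

# e.g. embeddings learned for an 8x8 patch grid (32x32 crops, 4x4 patches)
# reused on an 8x16 grid (32x128 recognition crops, (4, 8) patches):
pos = torch.randn(1, 1 + 8 * 8, 384)
resized = interpolate_pos_embed(pos, old_grid=(8, 8), new_grid=(8, 16))
print(resized.shape)  # torch.Size([1, 129, 384])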

NOTE:
Unlike for the SAR or MASTER architectures, I am not able to fully train the model, because ViT requires a lot of data and I cannot muster the computing power. So this time it is only a small test based on our word generator, to show that the model trains as expected.

slow tests: passed

PT:

(doctr-dev) felix@felix-GS66-Stealth-11UH:~/Desktop/doctr$ python3 /home/felix/Desktop/doctr/references/recognition/train_pytorch.py vitstr
2022-09-19 09:35:14.038958: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Namespace(amp=False, arch='vitstr', batch_size=64, device=None, epochs=10, find_lr=False, font='FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf', input_size=32, lr=0.001, max_chars=12, min_chars=1, name=None, pretrained=False, push_to_hub=False, resume=None, sched='cosine', show_samples=False, test_only=False, train_path=None, train_samples=1000, val_path=None, val_samples=20, vocab='french', wb=False, weight_decay=0, workers=None)
Validation set loaded in 1.727s (2520 samples in 40 batches)
WARNING:root:Invalid model URL, using default initialization.
Train set loaded in 0.002008s (126000 samples in 1968 batches)
Validation loss decreased inf --> 3.90789: saving state...                                                                                                      
Epoch 1/10 - Validation loss: 3.90789 (Exact: 1.35% | Partial: 2.26%)
Validation loss decreased 3.90789 --> 3.54396: saving state...                                                                                                  
Epoch 2/10 - Validation loss: 3.54396 (Exact: 3.93% | Partial: 5.00%)

TF:

(doctr-dev-tf) felix@felix-GS66-Stealth-11UH:~/Desktop/doctr$ python3 /home/felix/Desktop/doctr/references/recognition/train_tensorflow.py vitstr
Namespace(amp=False, arch='vitstr', batch_size=64, epochs=10, find_lr=False, font='FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf', input_size=32, lr=0.001, max_chars=12, min_chars=1, name=None, pretrained=False, push_to_hub=False, resume=None, show_samples=False, test_only=False, train_path=None, train_samples=1000, val_path=None, val_samples=20, vocab='french', wb=False, workers=None)
Validation set loaded in 0.002643s (2520 samples in 40 batches)
WARNING:root:Invalid model URL, using default initialization.
Train set loaded in 0.004127s (126000 samples in 1968 batches)
Validation loss decreased inf --> 3.9933: saving state...                                                                                                       
Epoch 1/10 - Validation loss: 3.9933 (Exact: 3.45% | Partial: 3.73%)
Validation loss decreased 3.9933 --> 3.74347: saving state...                                                                                                   
Epoch 2/10 - Validation loss: 3.74347 (Exact: 7.50% | Partial: 7.78%)

Additional:
Prediction also works (only tested with a model that reaches ~15% exact match after 9 epochs of training on WordGenerator samples):

Word(value='฿฿_', confidence=0.048),
Word(value='4฿^฿', confidence=0.038),
Word(value='|', confidence=0.05),
Word(value='ërwW', confidence=0.031),
Word(value='x¢฿MMo', confidence=0.018),
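
For reference, a rough sketch of how the new recognition architecture can be tried from the zoo (this assumes vitstr_small is registered under that name, as listed in the config bullet above; the dummy crop and the printed output are illustrative only):

import numpy as np
from doctr.models import recognition_predictor

# pretrained=False -> random initialization, analogous to the quick test above
predictor = recognition_predictor(arch="vitstr_small", pretrained=False)

# a dummy 32x128 word crop (H, W, C), uint8
crop = np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)

out = predictor([crop])
print(out)  # a list of (value, confidence) predictions, one per crop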

@felixdittrich92 added the module: models (Related to doctr.models), framework: pytorch (Related to PyTorch backend), framework: tensorflow (Related to TensorFlow backend) and topic: text recognition (Related to the task of text recognition) labels on Sep 9, 2022
@felixdittrich92 added this to the 0.6.0 milestone on Sep 9, 2022
@felixdittrich92 self-assigned this on Sep 9, 2022

codecov bot commented Sep 15, 2022

Codecov Report

Merging #1055 (cf8470e) into main (0aacb3c) will increase coverage by 0.11%.
The diff coverage is 98.29%.

❗ Current head cf8470e differs from pull request most recent head b84892f. Consider uploading reports for the commit b84892f to get more accurate results

@@            Coverage Diff             @@
##             main    #1055      +/-   ##
==========================================
+ Coverage   95.17%   95.28%   +0.11%     
==========================================
  Files         141      145       +4     
  Lines        5823     6046     +223     
==========================================
+ Hits         5542     5761     +219     
- Misses        281      285       +4     
Flag Coverage Δ
unittests 95.28% <98.29%> (+0.11%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
doctr/models/classification/zoo.py 100.00% <ø> (ø)
doctr/models/recognition/vitstr/pytorch.py 97.46% <97.46%> (ø)
doctr/models/recognition/vitstr/tensorflow.py 97.56% <97.56%> (ø)
doctr/models/classification/vit/pytorch.py 100.00% <100.00%> (ø)
doctr/models/classification/vit/tensorflow.py 100.00% <100.00%> (ø)
doctr/models/modules/vision_transformer/pytorch.py 100.00% <100.00%> (ø)
...tr/models/modules/vision_transformer/tensorflow.py 100.00% <100.00%> (ø)
doctr/models/recognition/__init__.py 100.00% <100.00%> (ø)
doctr/models/recognition/vitstr/__init__.py 100.00% <100.00%> (ø)
doctr/models/recognition/vitstr/base.py 100.00% <100.00%> (ø)
... and 1 more


@frgfm (Collaborator) left a comment

thanks Felix! I'd have to go through the paper more thoroughly for an accurate review, but I added some comments!

Review comment on doctr/models/recognition/vitstr/base.py (outdated, resolved)
@felixdittrich92 linked an issue on Sep 19, 2022 that may be closed by this pull request
@felixdittrich92 changed the title from "[DRAFT] [models] add ViTSTR TF and PT" to "[models] add ViTSTR TF and PT" on Sep 19, 2022
@felixdittrich92 marked this pull request as ready for review on September 19, 2022 12:17
@felixdittrich92 marked this pull request as draft on September 20, 2022 06:43
@felixdittrich92 (Contributor, Author)

done

@felixdittrich92 marked this pull request as ready for review on September 20, 2022 11:02
@odulcy-mindee previously approved these changes on Sep 20, 2022

@odulcy-mindee (Collaborator) left a comment

Looks good to me. Maybe I'll wait for a review from @frgfm for this model

@felixdittrich92 (Contributor, Author) commented Sep 21, 2022

@frgfm for TF we will switch to IntermediateLayerGetter later, once we have pretrained ViT models 👍

About the model: it's quite simple, a ViT as feature extractor with a custom head on top (nothing special inside); see the sketch below.
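
To make that concrete, a rough sketch of the ViTSTR idea (illustrative PyTorch, not the actual module added in this PR): the ViT backbone yields one embedding per token, and a linear head classifies the first max_length tokens over the vocabulary.

import torch
from torch import nn

class ViTSTRSketch(nn.Module):
    """Sketch only: ViT backbone + linear head over the first max_length tokens."""

    def __init__(self, backbone: nn.Module, embed_dim: int, vocab_size: int, max_length: int = 25):
        super().__init__()
        self.backbone = backbone  # assumed to return (B, num_tokens, embed_dim)
        self.max_length = max_length
        self.head = nn.Linear(embed_dim, vocab_size + 1)  # +1 for an EOS class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)             # (B, N, D) token embeddings
        feats = feats[:, : self.max_length]  # keep the first max_length tokens
        return self.head(feats)              # (B, max_length, vocab_size + 1) logits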

@felixdittrich92 changed the title from "[models] add ViTSTR TF and PT" to "[models] add ViTSTR TF and PT and update ViT to work as backbone" on Sep 21, 2022
@felixdittrich92 merged commit e538cc2 into mindee:main on Sep 21, 2022

@felixdittrich92 (Contributor, Author)

@frgfm I will open another PR to add pretrained weights for ViT. We can fold your requested changes into that one (if there is anything :) )

@frgfm (Collaborator) left a comment

Let's just set out all the arguments for/against the patch_size default value (cf. comment)

Comment on lines +28 to +29
# fix patch size if recognition task with 32x128 input
self.patch_size = (4, 8) if height != width else patch_size

Collaborator

tricky condition: what are all the possible cases as input, and what do we want as patch_size for each?

@felixdittrich92 (Contributor, Author) commented Sep 21, 2022

Oh, it could be anything ...
Currently:

  • classification case: 32x32 -> (4, 4) (check)
  • recognition case: 32x128 -> (4, 8) (check)
  • detection case: 1024x1024 (not handled)
  • any other size (not handled)

It will not fail, but each size needs a different patch_size.
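
For context, the input size and patch_size together fix the token grid the transformer sees; a quick check of the cases above (plain Python, purely illustrative):

def patch_grid(input_size, patch_size):
    h, w = input_size
    ph, pw = patch_size
    return h // ph, w // pw  # number of patches along each axis

print(patch_grid((32, 32), (4, 4)))      # (8, 8)     -> 64 tokens (classification)
print(patch_grid((32, 128), (4, 8)))     # (8, 16)    -> 128 tokens (recognition)
print(patch_grid((32, 128), (4, 16)))    # (8, 8)     -> 64 tokens (the alternative discussed below)
print(patch_grid((1024, 1024), (4, 4)))  # (256, 256) -> 65536 tokens (why detection is not handled)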

Collaborator

Ok, two questions then:

  • how should the scale impact the patch size? (N,N) --> (H,W) implies that (2N,2N) --> (?,?)
  • how should the aspect ratio impact the patch size? I see that (32,32) --> (4, 4), but why doesn't (32,128) use (4, 16)?

@felixdittrich92 (Contributor, Author)

(4, 16) would also work, but I used the values from ParSeq for the PatchEmbedding of 32x128 samples

@felixdittrich92 (Contributor, Author)

which is (4, 8)

@felixdittrich92 (Contributor, Author) commented Sep 21, 2022

https://github.com/baudm/parseq/blob/main/configs/model/parseq.yaml
https://github.com/baudm/parseq/blob/main/configs/experiment/vitstr.yaml

If we used a fixed ratio, this would be easy to scale ... but yeah, I took the values from the ParSeq paper / implementation.
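
To illustrate the fixed-ratio alternative (a hypothetical rule, not what this PR or ParSeq uses): keeping the patch grid fixed and deriving patch_size from the input size would scale trivially with resolution, but it yields (4, 16) rather than ParSeq's (4, 8) for 32x128 inputs:

def patch_size_for_fixed_grid(input_size, grid=(8, 8)):
    """Hypothetical rule: always split the input into a fixed grid of patches."""
    h, w = input_size
    gh, gw = grid
    return h // gh, w // gw

print(patch_size_for_fixed_grid((32, 32)))      # (4, 4)
print(patch_size_for_fixed_grid((32, 128)))     # (4, 16)
print(patch_size_for_fixed_grid((1024, 1024)))  # (128, 128)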

@felixdittrich92 mentioned this pull request on Sep 26, 2022

@ghassenBtb

@felixdittrich92 Did you add pretrained weights for the ViTSTR recognition model? I looked in the code and there was no URL link to pretrained weights.
Thanks

@felixdittrich92 (Contributor, Author)

Hi @ghassenBtb 👋,

No, I have only added pretrained weights for the ViT backbone.
I have no access to Mindee's internal dataset, which was used to train the recognition models, so ping @charlesmindee @odulcy-mindee.

@ghassenBtb

Thank you Felix for your quick response :)
@charlesmindee @odulcy-mindee did you pretrain ViTSTR on your internal dataset?
Thanks

@charlesmindee (Collaborator)

Hi @ghassenBtb, we didn't pretrain ViTSTR on our internal dataset

@felixdittrich92 (Contributor, Author)

Hi @charlesmindee, I think the question from @ghassenBtb was more about whether you could train the model internally and provide the checkpoints 😅

@ghassenBtb

Yes, it would be great if you could train the model on your internal French dataset :)
It would certainly have superior recognition performance to CRNN-based recognition models, as it learns semantic and syntactic properties.

Labels

  • framework: pytorch (Related to PyTorch backend)
  • framework: tensorflow (Related to TensorFlow backend)
  • module: models (Related to doctr.models)
  • topic: text recognition (Related to the task of text recognition)
  • type: new feature (New feature)

Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adding ViTSTR
5 participants