Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[documentation/Fix] add documentation for huggingface feature and fix ocr_predictor model loading #896

Merged
merged 38 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
81c313e
backup
felixdittrich92 Jan 11, 2022
50574b5
Merge branch 'mindee:main' into main
felixdittrich92 Jan 11, 2022
5a6ed54
Merge branch 'mindee:main' into main
felixdittrich92 Jan 18, 2022
b9958a7
Merge branch 'mindee:main' into main
felixdittrich92 Jan 20, 2022
14c4651
Merge branch 'mindee:main' into main
felixdittrich92 Feb 16, 2022
779731f
Merge branch 'mindee:main' into main
felixdittrich92 Feb 18, 2022
ce2cdda
Merge branch 'mindee:main' into main
felixdittrich92 Feb 22, 2022
d13dc43
Merge branch 'mindee:main' into main
felixdittrich92 Feb 23, 2022
9a07d73
Merge branch 'mindee:main' into main
felixdittrich92 Feb 24, 2022
a002a70
Merge branch 'mindee:main' into main
felixdittrich92 Feb 24, 2022
6ad096e
Merge branch 'mindee:main' into main
felixdittrich92 Feb 25, 2022
1e77fd4
Merge branch 'mindee:main' into main
felixdittrich92 Mar 8, 2022
2be762c
Merge branch 'mindee:main' into main
felixdittrich92 Mar 10, 2022
e2f2055
Merge branch 'mindee:main' into main
felixdittrich92 Mar 11, 2022
bdc4e67
Merge branch 'mindee:main' into main
felixdittrich92 Mar 16, 2022
b525021
Merge branch 'mindee:main' into main
felixdittrich92 Mar 16, 2022
417a27b
Merge branch 'mindee:main' into main
felixdittrich92 Mar 16, 2022
9b3f5a1
Merge branch 'mindee:main' into main
felixdittrich92 Mar 18, 2022
93074a8
Merge branch 'mindee:main' into main
felixdittrich92 Mar 21, 2022
c64e209
Merge branch 'mindee:main' into main
felixdittrich92 Mar 22, 2022
fdc8381
Merge branch 'mindee:main' into main
felixdittrich92 Mar 25, 2022
bd68b07
Merge branch 'mindee:main' into main
felixdittrich92 Apr 5, 2022
7ac6ee2
Merge branch 'mindee:main' into main
felixdittrich92 Apr 5, 2022
1c79f32
Merge branch 'mindee:main' into main
felixdittrich92 Apr 7, 2022
45e43ac
Merge branch 'mindee:main' into main
felixdittrich92 Apr 13, 2022
53ba4b9
Merge branch 'mindee:main' into main
felixdittrich92 Apr 22, 2022
807a731
start documentation
felixdittrich92 Apr 23, 2022
01dfefa
update
felixdittrich92 Apr 23, 2022
ed71590
add tf
felixdittrich92 Apr 23, 2022
6b8f438
update
felixdittrich92 Apr 23, 2022
4f99a3d
pass model directly to predictor
felixdittrich92 Apr 24, 2022
cfac31f
update test
felixdittrich92 Apr 24, 2022
cd3a154
minor
felixdittrich92 Apr 24, 2022
a1a97fc
flake
felixdittrich92 Apr 24, 2022
08662b2
minor
felixdittrich92 Apr 24, 2022
4f90ff8
add tests
felixdittrich92 Apr 24, 2022
7f4afff
update snippets
felixdittrich92 Apr 24, 2022
07f4132
Update installing.rst
felixdittrich92 Apr 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Supported datasets
:hidden:

using_doctr/using_models
using_doctr/sharing_models
using_doctr/using_model_export


Expand Down
10 changes: 10 additions & 0 deletions docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,13 @@ doctr.models.zoo
----------------

.. autofunction:: doctr.models.ocr_predictor


doctr.models.factory
--------------------

.. autofunction:: doctr.models.factory.login_to_hub

.. autofunction:: doctr.models.factory.from_hub

.. autofunction:: doctr.models.factory.push_to_hf_hub
118 changes: 118 additions & 0 deletions docs/source/using_doctr/sharing_models.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
Share your model with the community
===================================

docTR's focus is on open source, so if you also feel in love with than we appreciate sharing your trained model with the community.
To make it easy for you, we have integrated a interface to the huggingface hub.

.. currentmodule:: doctr.models.factory


Loading from Huggingface Hub
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This section shows you how you can easily load a pretrained model from the Huggingface Hub.

.. tabs::

.. tab:: TensorFlow

.. code:: python3

from doctr.io import DocumentFile
from doctr.models import ocr_predictor, from_hub
image = DocumentFile.from_images(['data/example.jpg'])
# Load a custom detection model from huggingface hub
det_model = from_hub('Felix92/doctr-tf-db-resnet50')
# Load a custom recognition model from huggingface hub
reco_model = from_hub('Felix92/doctr-tf-crnn-vgg16-bn-french')
# You can easily plug in this models to the OCR predictor
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
result = predictor(image)

.. tab:: PyTorch

.. code:: python3

from doctr.io import DocumentFile
from doctr.models import ocr_predictor, from_hub
image = DocumentFile.from_images(['data/example.jpg'])
# Load a custom detection model from huggingface hub
det_model = from_hub('Felix92/doctr-torch-db-mobilenet-v3-large')
# Load a custom recognition model from huggingface hub
reco_model = from_hub('Felix92/doctr-torch-crnn-mobilenet-v3-large-french')
# You can easily plug in this models to the OCR predictor
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
result = predictor(image)


Pushing to the Huggingface Hub
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can also push your trained model to the Huggingface Hub.
You need only to provide the task type (classification, detection, recognition or obj_detection), a name for your trained model (NOTE:
existing repositories will not be overwritten) and the model name itself.

- Prerequisites:
- Huggingface account (you can easy create one at https://huggingface.co/)
- installed Git LFS (check installation at: https://git-lfs.github.com/) in the repository

.. code:: python3

from doctr.models import recognition, login_to_hub, push_to_hf_hub
login_to_hub()
my_awesome_model = recognition.crnn_mobilenet_v3_large(pretrained=True)
push_to_hf_hub(my_awesome_model, model_name='doctr-crnn-mobilenet-v3-large-french-v1', task='recognition', arch='crnn_mobilenet_v3_large')

It is also possible to push your model directly after training.

.. tabs::

.. tab:: TensorFlow

python3 ~/doctr/references/recognition/train_tensorflow.py crnn_mobilenet_v3_large --name doctr-crnn-mobilenet-v3-large --push-to-hub

.. tab:: PyTorch

python3 ~/doctr/references/recognition/train_pytorch.py crnn_mobilenet_v3_large --name doctr-crnn-mobilenet-v3-large --push-to-hub


Pretrained community models
---------------------------

This section is to provide some tables for pretrained community models.
Feel free to open a pull request or issue to add your model to this list.

Classification
^^^^^^^^^^^^^^

+---------------------------------+-------------------------------------+-----------------------+------------------------+
| **Architecture** | **Repo_ID** | **Vocabulary** | **Framework** |
+=================================+=====================================+=======================+========================+
| resnet18 (dummy) | Felix92/doctr-dummy-torch-resnet18 | french | PyTorch |
+---------------------------------+-------------------------------------+-----------------------+------------------------+
| resnet18 (dummy) | Felix92/doctr-dummy-tf-resnet18 | french | TensorFlow |
+---------------------------------+-------------------------------------+-----------------------+------------------------+


Detection
^^^^^^^^^

+---------------------------------+-------------------------------------------------+------------------------+
| **Architecture** | **Repo_ID** | **Framework** |
+=================================+=================================================+========================+
| db_mobilenet_v3_large (dummy) | Felix92/doctr-torch-db-mobilenet-v3-large | PyTorch |
+---------------------------------+-------------------------------------------------+------------------------+
| db_resnet50 (dummy) | Felix92/doctr-tf-db-resnet50 | TensorFlow |
+---------------------------------+-------------------------------------------------+------------------------+


Recognition
^^^^^^^^^^^

+---------------------------------+---------------------------------------------------+---------------------+------------------------+
| **Architecture** | **Repo_ID** | **Language** | **Framework** |
+=================================+===================================================+=====================+========================+
| crnn_mobilenet_v3_large (dummy) | Felix92/doctr-torch-crnn-mobilenet-v3-large | french | PyTorch |
+---------------------------------+---------------------------------------------------+---------------------+------------------------+
| crnn_vgg16_bn (dummy) | Felix92/doctr-tf-crnn-vgg16-bn-french | french | TensorFlow |
+---------------------------------+---------------------------------------------------+---------------------+------------------------+
3 changes: 2 additions & 1 deletion doctr/models/detection/linknet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(

super().__init__()
self.cfg = cfg
self.assume_straight_pages = assume_straight_pages

self.feat_extractor = feat_extractor
# Identify the number of channels for the FPN initialization
Expand All @@ -124,7 +125,7 @@ def __init__(
nn.ConvTranspose2d(head_chans, num_classes, kernel_size=2, stride=2),
)

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages)
self.postprocessor = LinkNetPostProcessor(assume_straight_pages=self.assume_straight_pages)

for n, m in self.named_modules():
# Don't override the initialization of the backbone
Expand Down
30 changes: 19 additions & 11 deletions doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,30 @@


def _predictor(
arch: str,
arch: Any,
pretrained: bool,
assume_straight_pages: bool = True,
**kwargs: Any
) -> DetectionPredictor:

if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
if isinstance(arch, str):
if arch not in ARCHS + ROT_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

if arch not in ROT_ARCHS and not assume_straight_pages:
raise AssertionError("You are trying to use a model trained on straight pages while not assuming"
" your pages are straight. If you have only straight documents, don't pass"
f" assume_straight_pages=False, otherwise you should use one of these archs: {ROT_ARCHS}")
if arch not in ROT_ARCHS and not assume_straight_pages:
raise AssertionError("You are trying to use a model trained on straight pages while not assuming"
" your pages are straight. If you have only straight documents, don't pass"
" assume_straight_pages=False, otherwise you should use one of these archs:"
f"{ROT_ARCHS}")

_model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet)):
raise ValueError(f"unknown architecture: {type(arch)}")

_model = arch
_model.assume_straight_pages = assume_straight_pages

# Detection
_model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages)
kwargs['mean'] = kwargs.get('mean', _model.cfg['mean'])
kwargs['std'] = kwargs.get('std', _model.cfg['std'])
kwargs['batch_size'] = kwargs.get('batch_size', 1)
Expand All @@ -54,7 +62,7 @@ def _predictor(


def detection_predictor(
arch: str = 'db_resnet50',
arch: Any = 'db_resnet50',
pretrained: bool = False,
assume_straight_pages: bool = True,
**kwargs: Any
Expand All @@ -68,7 +76,7 @@ def detection_predictor(
>>> out = model([input_page])

Args:
arch: name of the architecture to use (e.g. 'db_resnet50')
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
pretrained: If True, returns a model pre-trained on our text detection dataset
assume_straight_pages: If True, fit straight boxes to the page

Expand Down
22 changes: 16 additions & 6 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
ARCHS: List[str] = ['crnn_vgg16_bn', 'crnn_mobilenet_v3_small', 'crnn_mobilenet_v3_large', 'sar_resnet31', 'master']


def _predictor(arch: str, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:

if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
if isinstance(arch, str):
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

_model = recognition.__dict__[arch](pretrained=pretrained)
else:
if not isinstance(arch, (recognition.CRNN, recognition.SAR, recognition.MASTER)):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch

_model = recognition.__dict__[arch](pretrained=pretrained)
kwargs['mean'] = kwargs.get('mean', _model.cfg['mean'])
kwargs['std'] = kwargs.get('std', _model.cfg['std'])
kwargs['batch_size'] = kwargs.get('batch_size', 32)
Expand All @@ -35,7 +41,11 @@ def _predictor(arch: str, pretrained: bool, **kwargs: Any) -> RecognitionPredict
return predictor


def recognition_predictor(arch: str = 'crnn_vgg16_bn', pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
def recognition_predictor(
arch: Any = 'crnn_vgg16_bn',
pretrained: bool = False,
**kwargs: Any
) -> RecognitionPredictor:
"""Text recognition architecture.

Example::
Expand All @@ -46,7 +56,7 @@ def recognition_predictor(arch: str = 'crnn_vgg16_bn', pretrained: bool = False,
>>> out = model([input_page])

Args:
arch: name of the architecture to use (e.g. 'crnn_vgg16_bn')
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
pretrained: If True, returns a model pre-trained on our text recognition dataset

Returns:
Expand Down
14 changes: 8 additions & 6 deletions doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


def _predictor(
det_arch: str,
reco_arch: str,
det_arch: Any,
reco_arch: Any,
pretrained: bool,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
Expand Down Expand Up @@ -48,8 +48,8 @@ def _predictor(


def ocr_predictor(
det_arch: str = 'db_resnet50',
reco_arch: str = 'crnn_vgg16_bn',
det_arch: Any = 'db_resnet50',
reco_arch: Any = 'crnn_vgg16_bn',
pretrained: bool = False,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
Expand All @@ -66,8 +66,10 @@ def ocr_predictor(
>>> out = model([input_page])

Args:
det_arch: name of the detection architecture to use (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
reco_arch: name of the recognition architecture to use (e.g. 'crnn_vgg16_bn', 'sar_resnet31')
det_arch: name of the detection architecture or the model itself to use
(e.g. 'db_resnet50', 'db_mobilenet_v3_large')
reco_arch: name of the recognition architecture or the model itself to use
(e.g. 'crnn_vgg16_bn', 'sar_resnet31')
pretrained: If True, returns a model pre-trained on our OCR dataset
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
Expand Down
36 changes: 27 additions & 9 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,7 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
_ = predictor([input_page])


@pytest.mark.parametrize(
"det_arch, reco_arch",
[
["db_mobilenet_v3_large", "crnn_mobilenet_v3_large"],
],
)
def test_zoo_models(det_arch, reco_arch):
# Model
predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
def _test_predictor(predictor):
# Output checks
assert isinstance(predictor, OCRPredictor)

Expand All @@ -81,3 +73,29 @@ def test_zoo_models(det_arch, reco_arch):
with pytest.raises(ValueError):
input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
_ = predictor([input_page])


@pytest.mark.parametrize(
"det_arch, reco_arch",
[
["db_mobilenet_v3_large", "crnn_mobilenet_v3_large"],
],
)
def test_zoo_models(det_arch, reco_arch):
# Model
predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
_test_predictor(predictor)

# passing model instance directly
det_model = detection.__dict__[det_arch](pretrained=True)
reco_model = recognition.__dict__[reco_arch](pretrained=True)
predictor = models.ocr_predictor(det_model, reco_model)
_test_predictor(predictor)

# passing recognition model as detection model
with pytest.raises(ValueError):
models.ocr_predictor(det_arch=reco_model, pretrained=True)

# passing detection model as recognition model
with pytest.raises(ValueError):
models.ocr_predictor(reco_arch=det_model, pretrained=True)
36 changes: 27 additions & 9 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,7 @@ def test_trained_ocr_predictor(mock_tilted_payslip):
assert out.pages[0].blocks[0].lines[0].words[0].value == 'Mr.'


@pytest.mark.parametrize(
"det_arch, reco_arch",
[
["db_mobilenet_v3_large", "crnn_vgg16_bn"],
],
)
def test_zoo_models(det_arch, reco_arch):
# Model
predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
def _test_predictor(predictor):
# Output checks
assert isinstance(predictor, OCRPredictor)

Expand All @@ -133,3 +125,29 @@ def test_zoo_models(det_arch, reco_arch):
with pytest.raises(ValueError):
input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8)
_ = predictor([input_page])


@pytest.mark.parametrize(
"det_arch, reco_arch",
[
["db_mobilenet_v3_large", "crnn_vgg16_bn"],
],
)
def test_zoo_models(det_arch, reco_arch):
# Model
predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True)
_test_predictor(predictor)

# passing model instance directly
det_model = detection.__dict__[det_arch](pretrained=True)
reco_model = recognition.__dict__[reco_arch](pretrained=True)
predictor = models.ocr_predictor(det_model, reco_model)
_test_predictor(predictor)

# passing recognition model as detection model
with pytest.raises(ValueError):
models.ocr_predictor(det_arch=reco_model, pretrained=True)

# passing detection model as recognition model
with pytest.raises(ValueError):
models.ocr_predictor(reco_arch=det_model, pretrained=True)