Skip to content

Commit

Permalink
Merge branch 'main' into hub
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Apr 5, 2022
2 parents b23399e + 1b4b687 commit 23f95ee
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 1 deletion.
32 changes: 31 additions & 1 deletion doctr/models/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from doctr.utils.data import download_from_url

__all__ = ['load_pretrained_params', 'conv_sequence_pt']
__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_classification_model_to_onnx']


def load_pretrained_params(
Expand Down Expand Up @@ -87,3 +87,33 @@ def conv_sequence_pt(
conv_seq.append(nn.ReLU(inplace=True))

return conv_seq


def export_classification_model_to_onnx(model: nn.Module, exp_name: str, dummy_input: torch.Tensor) -> str:
"""Export classification model to ONNX format.
>>> import torch
>>> from doctr.models.classification import resnet18
>>> from doctr.models.utils import export_classification_model_to_onnx
>>> model = resnet18(pretrained=True)
>>> export_classification_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
Args:
model: the PyTorch model to be exported
exp_name: the name for the exported model
dummy_input: the dummy input to the model
Returns:
the path to the exported model
"""
torch.onnx.export(
model,
dummy_input,
f"{exp_name}.onnx",
input_names=['input'],
output_names=['logits'],
dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
export_params=True, opset_version=13, verbose=False
)
logging.info(f"Model exported to {exp_name}.onnx")
return f"{exp_name}.onnx"
10 changes: 10 additions & 0 deletions references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from doctr import transforms as T
from doctr.datasets import VOCABS, CharacterGenerator
from doctr.models import classification, login_to_hub, push_to_hf_hub
from doctr.models.utils import export_classification_model_to_onnx
from utils import plot_recorder, plot_samples


Expand Down Expand Up @@ -339,6 +340,13 @@ def main(args):

if args.push_to_hub:
push_to_hf_hub(model, exp_name, task='classification', run_config=args)

if args.export_onnx:
print("Exporting model to ONNX...")
dummy_batch = next(iter(val_loader))
dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
model_path = export_classification_model_to_onnx(model, exp_name, dummy_input)
print(f"Exported model saved in {model_path}")


def parse_args():
Expand Down Expand Up @@ -384,6 +392,8 @@ def parse_args():
parser.add_argument('--push-to-hub', dest='push_to_hub', action='store_true', help='Push to Huggingface Hub')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--export-onnx', dest='export_onnx', action='store_true',
help='Export the model to ONNX')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR')
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"pytest>=5.3.2",
"coverage>=4.5.4",
"hdf5storage>=0.1.18",
"onnxruntime>=1.11.0",
"requests>=2.20.0",
"requirements-parser==0.2.0",
# Quality
Expand Down Expand Up @@ -137,6 +138,7 @@ def deps_list(*pkgs):
"coverage",
"requests",
"hdf5storage",
"onnxruntime",
"requirements-parser",
)

Expand Down
39 changes: 39 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import os
import tempfile

import cv2
import numpy as np
import onnxruntime
import pytest
import torch

from doctr.models import classification
from doctr.models.classification.predictor import CropOrientationPredictor
from doctr.models.utils import export_classification_model_to_onnx


def _test_classification(model, input_shape, output_size, batch_size=2):
Expand Down Expand Up @@ -98,3 +103,37 @@ def test_crop_orientation_model(mock_text_box):
text_box_270 = np.rot90(text_box_0, 3)
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3]


@pytest.mark.parametrize(
"arch_name, input_shape, output_size",
[
["vgg16_bn_r", (3, 32, 32), (126,)],
["resnet18", (3, 32, 32), (126,)],
["resnet31", (3, 32, 32), (126,)],
["resnet34", (3, 32, 32), (126,)],
["resnet34_wide", (3, 32, 32), (126,)],
["resnet50", (3, 32, 32), (126,)],
["magc_resnet31", (3, 32, 32), (126,)],
["mobilenet_v3_small", (3, 32, 32), (126,)],
["mobilenet_v3_large", (3, 32, 32), (126,)],
["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
],
)
def test_models_onnx_export(arch_name, input_shape, output_size):
# Model
batch_size = 2
model = classification.__dict__[arch_name](pretrained=True).eval()
dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
with tempfile.TemporaryDirectory() as tmpdir:
# Export
model_path = export_classification_model_to_onnx(model,
exp_name=os.path.join(tmpdir, "model"),
dummy_input=dummy_input)
assert os.path.exists(model_path)
# Inference
ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
providers=["CPUExecutionProvider"])
ort_outs = ort_session.run(['logits'], {'input': dummy_input.numpy()})
assert isinstance(ort_outs, list) and len(ort_outs) == 1
assert ort_outs[0].shape == (batch_size, *output_size)
1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pytest>=5.3.2
requests>=2.20.0
hdf5storage>=0.1.18
coverage>=4.5.4
onnxruntime>=1.11.0
requirements-parser==0.2.0

0 comments on commit 23f95ee

Please sign in to comment.