Skip to content

Commit

Permalink
feat: integration of the classifier in the ocr predictor (#723)
Browse files Browse the repository at this point in the history
* feat: add draft classifier

* feat: add classification module + utils fn

* fix: flake8

* fix: mypy

* fix: typing

* tests: add tests

* feat: integrate classifier to predictor

* fix: add flag distinction

* refactor: backbone module

* fix: typo in cfg

* fix: classes

* fix: typo

* fix: tests

* fix: cfg

* refactor: predictor name + tests

* fix: input shape tf

* fix: cfg

* fix: cfg

* fix: cfg

* fix: cfg

* fix: typing

* fix: isort

* fix: sort

* fix: naming

* fix: naming ocr

* fix: edge case 0 rotation

* tests: add mock text box

* fix: sorting tests
  • Loading branch information
charlesmindee committed Dec 21, 2021
1 parent 1c8b5cc commit caa363c
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 4 deletions.
23 changes: 21 additions & 2 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,24 @@ def rectify_crops(
3: 270 ccw, rotate 1 time ccw
"""
# Inverse predictions (if angle of +90 is detected, rotate by -90)
orientations = [4 - pred for pred in orientations if pred != 0]
return [np.rot90(crop, orientation) for orientation, crop in zip(orientations, crops)]
orientations = [4 - pred if pred != 0 else 0 for pred in orientations]
return [
crop if orientation == 0 else np.rot90(crop, orientation)
for orientation, crop in zip(orientations, crops)
]


def rectify_loc_preds(
    page_loc_preds: np.ndarray,
    orientations: List[int],
) -> np.ndarray:
    """Orient the quadrangle (Polygon4P) according to the predicted orientation,
    so that the points are in this order: top L, top R, bot R, bot L if the crop is readable

    Args:
        page_loc_preds: localization predictions of a single page, iterated along the first axis
            (one entry per crop). Presumably of shape (N, 4, 2) given the Polygon4P mention;
            this is not enforced here.
        orientations: predicted orientation index of each crop, paired with `page_loc_preds` in order

    Returns:
        the stacked predictions, where each entry with a non-zero orientation has its rows
        cyclically shifted by `orientation` positions along its first axis. When there is nothing
        to rectify, an empty array with the same trailing shape and dtype as the input is returned.
    """
    rectified = [
        page_loc_pred if orientation == 0 else np.roll(page_loc_pred, orientation, axis=0)
        for orientation, page_loc_pred in zip(orientations, page_loc_preds)
    ]
    if not rectified:
        # np.stack raises ValueError on an empty sequence, which would crash on pages
        # without any detected box: hand back an empty array shaped like the input instead.
        page_loc_preds = np.asarray(page_loc_preds)
        return np.empty((0,) + page_loc_preds.shape[1:], dtype=page_loc_preds.dtype)
    return np.stack(rectified, axis=0)
2 changes: 1 addition & 1 deletion doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _crop_orientation_predictor(


def crop_orientation_predictor(
arch: str = 'classif_mobilenet_v3_small',
arch: str = 'mobilenet_v3_small_orientation',
pretrained: bool = False,
**kwargs: Any
) -> CropOrientationPredictor:
Expand Down
20 changes: 19 additions & 1 deletion doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from doctr.models.builder import DocumentBuilder

from .._utils import extract_crops, extract_rcrops
from .._utils import extract_crops, extract_rcrops, rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor

__all__ = ['_OCRPredictor']

Expand All @@ -24,6 +25,9 @@ class _OCRPredictor:

doc_builder: DocumentBuilder

def __init__(self) -> None:
    # Build the crop orientation classifier once at construction time; `_rectify_crops`
    # calls it on every page's crops to decide how each crop must be rotated.
    # NOTE(review): `pretrained=True` presumably triggers loading of pretrained weights
    # on instantiation — confirm against `crop_orientation_predictor`.
    self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True)

@staticmethod
def _generate_crops(
pages: List[np.ndarray],
Expand Down Expand Up @@ -60,6 +64,20 @@ def _prepare_crops(

return crops, loc_preds

def _rectify_crops(
    self,
    crops: List[List[np.ndarray]],
    loc_preds: List[np.ndarray],
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
    """Re-orient the crops of every page together with their localization predictions.

    Args:
        crops: for each page, the list of crops extracted from it
        loc_preds: for each page, the localization predictions paired with those crops

    Returns:
        the rectified crops and the rectified localization predictions, both grouped by page
    """
    # Run the orientation predictor on each page's crops before rectifying anything
    page_orientations = []
    for page_crops in crops:
        page_orientations.append(self.crop_orientation_predictor(page_crops))

    # Rotate the crops of each page according to the predicted orientations
    straightened_crops = []
    for page_crops, page_orientation in zip(crops, page_orientations):
        straightened_crops.append(rectify_crops(page_crops, page_orientation))

    # Apply the matching re-ordering to the localization predictions of each page
    straightened_locs = []
    for page_locs, page_orientation in zip(loc_preds, page_orientations):
        straightened_locs.append(rectify_loc_preds(page_locs, page_orientation))

    return straightened_crops, straightened_locs

@staticmethod
def _process_predictions(
loc_preds: List[np.ndarray],
Expand Down
3 changes: 3 additions & 0 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def forward(
crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=channels_last, assume_straight_pages=self.assume_straight_pages
)
# Rectify crop orientation
if not self.assume_straight_pages:
crops, loc_preds = self._rectify_crops(crops, loc_preds)
# Identify character sequences
word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)

Expand Down
3 changes: 3 additions & 0 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __call__(
crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
)
# Rectify crop orientation
if not self.assume_straight_pages:
crops, loc_preds = self._rectify_crops(crops, loc_preds)
# Identify character sequences
word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)

Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ def mock_pdf(mock_pdf_stream, tmpdir_factory):
return str(fn)


@pytest.fixture(scope="session")
def mock_text_box_stream():
    """Download the sample text-box image once per test session and return its raw bytes.

    Raises:
        requests.HTTPError: if the server answers with an error status code
        requests.Timeout: if the server does not respond in time
    """
    url = 'https://www.pngitem.com/pimgs/m/357-3579845_love-neon-loveislove-word-text-typography-freetoedit-picsart.png'
    # Without a timeout, an unresponsive host would hang the whole test session
    response = requests.get(url, timeout=30)
    # Fail here rather than handing an HTML error page to the tests as if it were image bytes
    response.raise_for_status()
    return response.content


@pytest.fixture(scope="session")
def mock_text_box(mock_text_box_stream, tmpdir_factory):
    """Write the downloaded text-box image to a session temporary file and return its path as a string."""
    payload = BytesIO(mock_text_box_stream).getbuffer()
    target = tmpdir_factory.mktemp("data").join("mock_text_box_file.png")
    with open(target, 'wb') as handle:
        handle.write(payload)
    return str(target)


@pytest.fixture(scope="session")
def mock_image_stream():
url = "https://miro.medium.com/max/3349/1*mk1-6aYaf_Bes1E3Imhc0A.jpeg"
Expand Down
11 changes: 11 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import cv2
import numpy as np
import pytest
import torch

Expand Down Expand Up @@ -77,3 +79,12 @@ def test_classification_zoo(arch_name):
out = predictor(input_tensor)
assert isinstance(out, list) and len(out) == batch_size
assert all(isinstance(pred, int) for pred in out)


def test_crop_orientation_model(mock_text_box):
    """Check that the pretrained orientation classifier maps k quarter-turn rotations to class k."""
    upright = cv2.imread(mock_text_box)
    # One sample per class: the untouched image, then 1, 2 and 3 successive np.rot90 quarter turns
    samples = [upright] + [np.rot90(upright, turns) for turns in (1, 2, 3)]
    predictor = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
    expected = [0, 1, 2, 3]
    assert predictor(samples) == expected
11 changes: 11 additions & 0 deletions tests/tensorflow/test_models_classification_tf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import cv2
import numpy as np
import pytest
import tensorflow as tf

Expand Down Expand Up @@ -60,3 +62,12 @@ def test_classification_zoo(arch_name):
out = predictor(input_tensor)
assert isinstance(out, list) and len(out) == batch_size
assert all(isinstance(pred, int) for pred in out)


def test_crop_orientation_model(mock_text_box):
    """Check that the pretrained orientation classifier maps k quarter-turn rotations to class k."""
    upright = cv2.imread(mock_text_box)
    # One sample per class: the untouched image, then 1, 2 and 3 successive np.rot90 quarter turns
    samples = [upright] + [np.rot90(upright, turns) for turns in (1, 2, 3)]
    predictor = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
    expected = [0, 1, 2, 3]
    assert predictor(samples) == expected

0 comments on commit caa363c

Please sign in to comment.