def _remove_padding(
    self,
    pages: List[np.ndarray],
    loc_preds: List[np.ndarray],
) -> List[np.ndarray]:
    """Map detection coordinates from the padded (square) model input back onto the original pages.

    When the detection pre-processing preserves the aspect ratio, the shorter side of each page
    is padded. Relative coordinates predicted on the padded image are therefore compressed along
    that side and have to be stretched back by the long-side / short-side ratio.

    Args:
        pages: the pages that were fed to the detection model. numpy pages are read as
            (H, W, ...); any other object exposing `.shape` is assumed to be a channels-first
            tensor (..., H, W).
        loc_preds: one array of relative localization predictions per page. With
            `assume_straight_pages`, columns 0/2 are treated as x coordinates and columns 1/3
            as y coordinates (any further column is left untouched). Otherwise, the last axis
            of a 3-D array holds (x, y).

    Returns:
        the rectified predictions, one array per page. `loc_preds` itself is returned when
        `preserve_aspect_ratio` is disabled. The input arrays are never modified: pages that
        need rescaling get a fresh copy, square pages are passed through as-is.
    """
    if not self.preserve_aspect_ratio:
        return loc_preds

    # Symmetric padding stretches around the page centre, bottom-right padding around the origin
    offset = 0.5 if self.symmetric_pad else 0.0

    rectified_preds = []
    for page, loc_pred in zip(pages, loc_preds):
        # Non-ndarray pages are assumed to be channels-first tensors, so H and W come last
        h, w = page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:]
        if h == w:
            # No padding was added: nothing to rectify
            rectified_preds.append(loc_pred)
            continue

        if h > w:
            # Portrait page: y unchanged, dilate x coordinates
            axis, ratio = 0, h / w
        else:
            # Landscape page: x unchanged, dilate y coordinates
            axis, ratio = 1, w / h

        # Work on a copy so that the caller's predictions are not rescaled behind its back
        rectified = loc_pred.copy()
        if self.assume_straight_pages:
            cols = [axis, axis + 2]
            values = rectified[:, cols]
            # Clip in every mode: detections lying in the padded margin would otherwise leave [0, 1]
            rectified[:, cols] = np.clip((values - offset) * ratio + offset, 0, 1)
        else:
            values = rectified[:, :, axis]
            rectified[:, :, axis] = np.clip((values - offset) * ratio + offset, 0, 1)
        rectified_preds.append(rectified)

    return rectified_preds
a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -41,12 +41,16 @@ def __init__( reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, straighten_pages: bool = False, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, **kwargs: Any, ) -> None: self.det_predictor = det_predictor self.reco_predictor = reco_predictor - _OCRPredictor.__init__(self, assume_straight_pages, straighten_pages, **kwargs) + _OCRPredictor.__init__( + self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs + ) def __call__( self, @@ -68,6 +72,9 @@ def __call__( # Localize text elements loc_preds = self.det_predictor(pages, **kwargs) + # Rectify crops if aspect ratio + loc_preds = self._remove_padding(pages, loc_preds) + # Crop images crops, loc_preds = self._prepare_crops( pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index 7197edc7cc..2973315ad7 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -18,6 +18,7 @@ def _predictor( pretrained: bool, assume_straight_pages: bool = True, preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, det_bs: int = 2, reco_bs: int = 128, **kwargs, @@ -30,6 +31,7 @@ def _predictor( batch_size=det_bs, assume_straight_pages=assume_straight_pages, preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, ) # Recognition @@ -39,6 +41,8 @@ def _predictor( det_predictor, reco_predictor, assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, **kwargs ) @@ -49,6 +53,7 @@ def ocr_predictor( pretrained: bool = False, assume_straight_pages: bool = True, preserve_aspect_ratio: bool = False, + symmetric_pad: bool = True, export_as_straight_boxes: bool = False, **kwargs: Any ) -> OCRPredictor: @@ -69,6 +74,7 @@ def ocr_predictor( without rotated textual elements. 
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before running the detection model on it. + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right. export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions (potentially rotated) as straight bounding boxes. kwargs: keyword args of `OCRPredictor` @@ -83,6 +89,7 @@ def ocr_predictor( pretrained, assume_straight_pages=assume_straight_pages, preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, export_as_straight_boxes=export_as_straight_boxes, **kwargs, ) diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 7c4f02fd58..6f163b4e3b 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -91,6 +91,24 @@ def test_trained_ocr_predictor(mock_tilted_payslip): [0.51385817, 0.21002172]]) assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised) + det_predictor = detection_predictor( + 'db_resnet50', pretrained=True, batch_size=2, assume_straight_pages=True, + preserve_aspect_ratio=True, symmetric_pad=True + ) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == 'Mr.' + @pytest.mark.parametrize( "det_arch, reco_arch",