diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml
index 8eaadc1206..25d5cff4d8 100644
--- a/.github/workflows/scripts.yml
+++ b/.github/workflows/scripts.yml
@@ -14,6 +14,7 @@ jobs:
       matrix:
         os: [ubuntu-latest, macos-latest]
         python: [3.7, 3.8]
+        framework: [tensorflow, pytorch]
     steps:
       - if: matrix.os == 'macos-latest'
         name: Install MacOS prerequisites
@@ -24,7 +25,8 @@
         with:
           python-version: ${{ matrix.python }}
           architecture: x64
-      - name: Cache python modules
+      - if: matrix.framework == 'tensorflow'
+        name: Cache python modules (TF)
         uses: actions/cache@v2
         with:
           path: ~/.cache/pip
@@ -34,10 +36,27 @@
             ${{ runner.os }}-pkg-deps-${{ matrix.python }}-
             ${{ runner.os }}-pkg-deps-
             ${{ runner.os }}-
-      - name: Install dependencies
+      - if: matrix.framework == 'pytorch'
+        name: Cache python modules (PT)
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('**/*.py') }}
+          restore-keys: |
+            ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
+            ${{ runner.os }}-pkg-deps-${{ matrix.python }}-
+            ${{ runner.os }}-pkg-deps-
+            ${{ runner.os }}-
+      - if: matrix.framework == 'tensorflow'
+        name: Install package (TF)
         run: |
           python -m pip install --upgrade pip
           pip install -e .[tf] --upgrade
+      - if: matrix.framework == 'pytorch'
+        name: Install package (PT)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[torch] --upgrade
       - name: Run analysis script
         run: |
@@ -51,6 +70,7 @@ jobs:
       matrix:
         os: [ubuntu-latest, macos-latest]
         python: [3.7, 3.8]
+        framework: [tensorflow, pytorch]
     steps:
       - if: matrix.os == 'macos-latest'
         name: Install MacOS prerequisites
@@ -61,7 +81,8 @@
         with:
           python-version: ${{ matrix.python }}
           architecture: x64
-      - name: Cache python modules
+      - if: matrix.framework == 'tensorflow'
+        name: Cache python modules (TF)
         uses: actions/cache@v2
         with:
           path: ~/.cache/pip
@@ -71,13 +92,29 @@
             ${{ runner.os }}-pkg-deps-${{ matrix.python }}-
             ${{ runner.os }}-pkg-deps-
             ${{ runner.os }}-
-      - name: Install dependencies
+      - if: matrix.framework == 'pytorch'
+        name: Cache python modules (PT)
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('**/*.py') }}
+          restore-keys: |
+            ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
+            ${{ runner.os }}-pkg-deps-${{ matrix.python }}-
+            ${{ runner.os }}-pkg-deps-
+            ${{ runner.os }}-
+      - if: matrix.framework == 'tensorflow'
+        name: Install package (TF)
         run: |
           python -m pip install --upgrade pip
           pip install -e .[tf] --upgrade
-      - name: Run evaluation script
+      - if: matrix.framework == 'pytorch'
+        name: Install package (PT)
         run: |
-          python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
+          python -m pip install --upgrade pip
+          pip install -e .[torch] --upgrade
+      - name: Run evaluation script
+        run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
 
   test-collectenv:
     runs-on: ${{ matrix.os }}
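Each matrix entry installs a single framework extra, so the scripts have to dispatch on whatever is importable at runtime. As a reading aid, here is a minimal sketch of availability helpers in the spirit of `doctr.file_utils` (illustrative only; the actual module may add caching, logging, or environment-variable overrides):

```python
# Sketch of framework-availability helpers (assumed implementation,
# not a copy of doctr/file_utils.py).
import importlib.util


def is_tf_available() -> bool:
    # TensorFlow counts as available when its package is importable
    return importlib.util.find_spec("tensorflow") is not None


def is_torch_available() -> bool:
    # Same check for PyTorch
    return importlib.util.find_spec("torch") is not None
```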
diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py
index b4c698f868..602f95444f 100644
--- a/doctr/models/_utils.py
+++ b/doctr/models/_utils.py
@@ -12,13 +12,14 @@
 __all__ = ['estimate_orientation', 'extract_crops', 'extract_rcrops', 'get_bitmap_angle']
 
 
-def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
+def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]:
     """Created cropped images from list of bounding boxes
 
     Args:
         img: input image
         boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
             coordinates (xmin, ymin, xmax, ymax)
+        channels_last: whether the channel dimension is the last one instead of the first one
 
     Returns:
         list of cropped images
@@ -36,16 +37,26 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
     _boxes = _boxes.round().astype(int)
     # Add last index
     _boxes[2:] += 1
-    return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
+    if channels_last:
+        return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
+    else:
+        return [img[:, box[1]: box[3], box[0]: box[2]] for box in _boxes]
 
 
-def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List[np.ndarray]:
+def extract_rcrops(
+    img: np.ndarray,
+    boxes: np.ndarray,
+    dtype=np.float32,
+    channels_last: bool = True
+) -> List[np.ndarray]:
     """Created cropped images from list of rotated bounding boxes
 
     Args:
         img: input image
         boxes: bounding boxes of shape (N, 5) where N is the number of boxes, and the relative
             coordinates (x, y, w, h, alpha)
+        dtype: target data type of bounding boxes
+        channels_last: whether the channel dimension is the last one instead of the first one
 
     Returns:
         list of cropped images
@@ -80,9 +91,9 @@ def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List
         M = cv2.getAffineTransform(src_pts, dst_pts)
         # Warp the rotated rectangle
         if clockwise:
-            crop = cv2.warpAffine(img, M, (int(w), int(h)))
+            crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(w), int(h)))
         else:
-            crop = cv2.warpAffine(img, M, (int(h), int(w)))
+            crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(h), int(w)))
         crops.append(crop)
 
     return crops
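A quick illustration of the new `channels_last` flag with dummy data (a square image is used so the relative-coordinate projection behaves the same in both layouts; array contents are placeholders):

```python
# Illustrative usage of extract_crops with both memory layouts
import numpy as np
from doctr.models import extract_crops

img_hwc = np.zeros((64, 64, 3), dtype=np.uint8)       # H x W x C (TensorFlow-style)
img_chw = img_hwc.transpose(2, 0, 1)                  # C x H x W (PyTorch-style)
boxes = np.array([[0.1, 0.25, 0.5, 0.75]], dtype=np.float32)  # relative (xmin, ymin, xmax, ymax)

crops_hwc = extract_crops(img_hwc, boxes)                        # default: channels_last=True
crops_chw = extract_crops(img_chw, boxes, channels_last=False)   # slices only the spatial axes
```

The channels-first branch simply indexes `img[:, ymin:ymax, xmin:xmax]`, so the returned crops keep the C x H x W layout expected by the PyTorch recognition stage.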
diff --git a/doctr/models/core.py b/doctr/models/core.py
index 06e5752a1a..fdc5389d04 100644
--- a/doctr/models/core.py
+++ b/doctr/models/core.py
@@ -10,6 +10,7 @@
 from .detection import DetectionPredictor
 from .recognition import RecognitionPredictor
 from ._utils import extract_crops, extract_rcrops
+from doctr.file_utils import is_torch_available
 from doctr.io.elements import Word, Line, Block, Page, Document
 from doctr.utils.repr import NestedObject
 from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes, rotate_image
@@ -51,20 +52,35 @@ def __call__(
 
         # Localize text elements
         boxes = self.det_predictor(pages, **kwargs)
+        # Check whether crop mode should be switched to channels first
+        crop_kwargs = {}
+        if len(pages) > 0 and not isinstance(pages[0], np.ndarray) and is_torch_available():
+            crop_kwargs['channels_last'] = False
         # Crop images, rotate page if necessary
         if self.doc_builder.rotated_bbox:
-            crops = [crop for page, (_boxes, angle) in zip(pages, boxes) for crop in
-                     self.extract_crops_fn(rotate_image(page, -angle, False), _boxes[:, :-1])]  # type: ignore[operator]
+            crops = [
+                crop for page, (_boxes, angle) in zip(pages, boxes) for crop in
+                self.extract_crops_fn(  # type: ignore[operator]
+                    rotate_image(page, -angle, False),
+                    _boxes[:, :-1],
+                    **crop_kwargs
+                )
+            ]
         else:
             crops = [crop for page, (_boxes, _) in zip(pages, boxes) for crop in
-                     self.extract_crops_fn(page, _boxes[:, :-1])]  # type: ignore[operator]
+                     self.extract_crops_fn(page, _boxes[:, :-1], **crop_kwargs)]  # type: ignore[operator]
+        # Avoid sending zero-sized crops
+        is_kept = [all(s > 0 for s in crop.shape) for crop in crops]
+        crops = [crop for crop, _kept in zip(crops, is_kept) if _kept]
+        boxes = [box for box, _kept in zip(boxes, is_kept) if _kept]
 
         # Identify character sequences
         word_preds = self.reco_predictor(crops, **kwargs)
 
         # Rotate back boxes if necessary
-        boxes, angles = zip(*boxes)
-        if self.doc_builder.rotated_bbox:
-            boxes = [rotate_boxes(boxes_page, angle) for boxes_page, angle in zip(boxes, angles)]
+        if len(boxes) > 0:
+            boxes, angles = zip(*boxes)
+            if self.doc_builder.rotated_bbox:
+                boxes = [rotate_boxes(boxes_page, angle) for boxes_page, angle in zip(boxes, angles)]
 
         out = self.doc_builder(boxes, word_preds, [page.shape[:2] for page in pages])
         return out
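The zero-sized crop guard is easy to sanity-check in isolation: a detection box that rounds to an empty region yields a crop with a zero dimension, which the recognition stage cannot consume. A toy example with dummy arrays:

```python
# Toy check of the zero-sized crop filter added above (dummy shapes only)
import numpy as np

crops = [
    np.zeros((0, 12, 3), dtype=np.uint8),  # degenerate: zero height
    np.zeros((8, 12, 3), dtype=np.uint8),  # valid crop
]
is_kept = [all(s > 0 for s in crop.shape) for crop in crops]
assert is_kept == [False, True]
crops = [crop for crop, _kept in zip(crops, is_kept) if _kept]  # only the valid crop survives
```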
diff --git a/scripts/analyze.py b/scripts/analyze.py
index d67896c313..7ff94ca129 100644
--- a/scripts/analyze.py
+++ b/scripts/analyze.py
@@ -8,25 +8,40 @@
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 
-import tensorflow as tf
-
-gpu_devices = tf.config.experimental.list_physical_devices('GPU')
-if any(gpu_devices):
-    tf.config.experimental.set_memory_growth(gpu_devices[0], True)
-
 from doctr.models import ocr_predictor
 from doctr.io import DocumentFile
+from doctr.file_utils import is_tf_available
+
+# Enable GPU growth if using TF
+if is_tf_available():
+    import tensorflow as tf
+    gpu_devices = tf.config.experimental.list_physical_devices('GPU')
+    if any(gpu_devices):
+        tf.config.experimental.set_memory_growth(gpu_devices[0], True)
+else:
+    import torch
 
 
 def main(args):
 
     model = ocr_predictor(args.detection, args.recognition, pretrained=True)
+
+    if not is_tf_available():
+        model.det_predictor.pre_processor = model.det_predictor.pre_processor.eval()
+        model.det_predictor.model = model.det_predictor.model.eval()
+        model.reco_predictor.pre_processor = model.reco_predictor.pre_processor.eval()
+        model.reco_predictor.model = model.reco_predictor.model.eval()
+
     if args.path.endswith(".pdf"):
         doc = DocumentFile.from_pdf(args.path).as_images()
     else:
         doc = DocumentFile.from_images(args.path)
 
-    out = model(doc, training=False)
+    if is_tf_available():
+        out = model(doc, training=False)
+    else:
+        with torch.no_grad():
+            out = model(doc)
 
     for page, img in zip(out.pages, doc):
         page.show(img, block=not args.noblock, interactive=not args.static)
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index 1373021e7b..db9ceb3133 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -4,26 +4,41 @@
 # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
 
 import os
-import numpy as np
-from tqdm import tqdm
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
 
-import tensorflow as tf
-
-gpu_devices = tf.config.experimental.list_physical_devices('GPU')
-if any(gpu_devices):
-    tf.config.experimental.set_memory_growth(gpu_devices[0], True)
+import numpy as np
+from tqdm import tqdm
 
 from doctr.utils.metrics import LocalizationConfusion, TextMatch, OCRMetric
 from doctr import datasets
 from doctr.models import ocr_predictor, extract_crops
+from doctr.file_utils import is_tf_available
+
+# Enable GPU growth if using TF
+if is_tf_available():
+    import tensorflow as tf
+    gpu_devices = tf.config.experimental.list_physical_devices('GPU')
+    if any(gpu_devices):
+        tf.config.experimental.set_memory_growth(gpu_devices[0], True)
+else:
+    import torch
+
+
+def _pct(val):
+    return "N/A" if val is None else f"{val:.2%}"
 
 
 def main(args):
 
     predictor = ocr_predictor(args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size)
 
+    if not is_tf_available():
+        predictor.det_predictor.pre_processor = predictor.det_predictor.pre_processor.eval()
+        predictor.det_predictor.model = predictor.det_predictor.model.eval()
+        predictor.reco_predictor.pre_processor = predictor.reco_predictor.pre_processor.eval()
+        predictor.reco_predictor.model = predictor.reco_predictor.model.eval()
+
     if args.img_folder and args.label_file:
         testset = datasets.OCRDataset(
             img_folder=args.img_folder,
@@ -60,9 +75,17 @@ def main(args):
             gt_labels = target['labels']
 
             # Forward
-            out = predictor(page[None, ...], training=False)
-            crops = extract_crops(page, gt_boxes)
-            reco_out = predictor.reco_predictor(crops, training=False)
+            if is_tf_available():
+                out = predictor(page[None, ...], training=False)
+                crops = extract_crops(page, gt_boxes)
+                reco_out = predictor.reco_predictor(crops, training=False)
+            else:
+                with torch.no_grad():
+                    out = predictor(page[None, ...])
+                    # We directly crop on PyTorch tensors, which are in channels_first
+                    crops = extract_crops(page, gt_boxes, channels_last=False)
+                    reco_out = predictor.reco_predictor(crops)
+
             if len(reco_out):
                 reco_words, _ = zip(*reco_out)
             else:
@@ -111,12 +134,12 @@
     print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
           f"dataset={'OCRDataset' if args.img_folder else args.dataset})")
     recall, precision, mean_iou = det_metric.summary()
-    print(f"Text Detection - Recall: {recall:.2%}, Precision: {precision:.2%}, Mean IoU: {mean_iou:.2%}")
+    print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
     acc = reco_metric.summary()
-    print(f"Text Recognition - Accuracy: {acc['raw']:.2%} (unicase: {acc['unicase']:.2%})")
+    print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
     recall, precision, mean_iou = e2e_metric.summary()
-    print(f"OCR - Recall: {recall['raw']:.2%} (unicase: {recall['unicase']:.2%}), "
-          f"Precision: {precision['raw']:.2%} (unicase: {precision['unicase']:.2%}), Mean IoU: {mean_iou:.2%}")
+    print(f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
+          f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}")
 
 
 def parse_args():
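For reference, the `_pct` helper introduced above keeps the summary printable even when a metric returns `None` (e.g. no samples were evaluated), instead of crashing on the `:.2%` format spec:

```python
# _pct in action: None-safe percentage formatting
def _pct(val):
    return "N/A" if val is None else f"{val:.2%}"

print(_pct(None))     # -> N/A
print(_pct(0.87654))  # -> 87.65%
```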