Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fixed PyTorch tensor cropping and extended script support #458

Merged
merged 9 commits into from
Sep 6, 2021
49 changes: 43 additions & 6 deletions .github/workflows/scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest]
python: [3.7, 3.8]
framework: [tensorflow, pytorch]
steps:
- if: matrix.os == 'macos-latest'
name: Install MacOS prerequisites
Expand All @@ -24,7 +25,8 @@ jobs:
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
Expand All @@ -34,10 +36,27 @@ jobs:
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- name: Install dependencies
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('**/*.py') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- if: matrix.framework == 'tensorflow'
name: Install package (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade

- name: Run analysis script
run: |
Expand All @@ -51,6 +70,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest]
python: [3.7, 3.8]
framework: [tensorflow, pytorch]
steps:
- if: matrix.os == 'macos-latest'
name: Install MacOS prerequisites
Expand All @@ -61,7 +81,8 @@ jobs:
with:
python-version: ${{ matrix.python }}
architecture: x64
- name: Cache python modules
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
Expand All @@ -71,13 +92,29 @@ jobs:
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- name: Install dependencies
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('**/*.py') }}
restore-keys: |
${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-
${{ runner.os }}-pkg-deps-${{ matrix.python }}-
${{ runner.os }}-pkg-deps-
${{ runner.os }}-
- if: matrix.framework == 'tensorflow'
name: Install package (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- name: Run evaluation script
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade
- name: Run evaluation script
run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10

test-collectenv:
runs-on: ${{ matrix.os }}
Expand Down
21 changes: 16 additions & 5 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
__all__ = ['estimate_orientation', 'extract_crops', 'extract_rcrops', 'get_bitmap_angle']


def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]:
"""Created cropped images from list of bounding boxes

Args:
img: input image
boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
coordinates (xmin, ymin, xmax, ymax)
channels_last: whether the channel dimensions is the last one instead of the last one

Returns:
list of cropped images
Expand All @@ -36,16 +37,26 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
_boxes = _boxes.round().astype(int)
# Add last index
_boxes[2:] += 1
return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
if channels_last:
return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
else:
return [img[:, box[1]: box[3], box[0]: box[2]] for box in _boxes]


def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List[np.ndarray]:
def extract_rcrops(
img: np.ndarray,
boxes: np.ndarray,
dtype=np.float32,
channels_last: bool = True
) -> List[np.ndarray]:
"""Created cropped images from list of rotated bounding boxes

Args:
img: input image
boxes: bounding boxes of shape (N, 5) where N is the number of boxes, and the relative
coordinates (x, y, w, h, alpha)
dtype: target data type of bounding boxes
channels_last: whether the channel dimensions is the last one instead of the last one

Returns:
list of cropped images
Expand Down Expand Up @@ -80,9 +91,9 @@ def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List
M = cv2.getAffineTransform(src_pts, dst_pts)
# Warp the rotated rectangle
if clockwise:
crop = cv2.warpAffine(img, M, (int(w), int(h)))
crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(w), int(h)))
else:
crop = cv2.warpAffine(img, M, (int(h), int(w)))
crop = cv2.warpAffine(img if channels_last else img.transpose(1, 2, 0), M, (int(h), int(w)))
crops.append(crop)

return crops
Expand Down
28 changes: 22 additions & 6 deletions doctr/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .detection import DetectionPredictor
from .recognition import RecognitionPredictor
from ._utils import extract_crops, extract_rcrops
from doctr.file_utils import is_torch_available
from doctr.io.elements import Word, Line, Block, Page, Document
from doctr.utils.repr import NestedObject
from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes, rotate_image
Expand Down Expand Up @@ -51,20 +52,35 @@ def __call__(

# Localize text elements
boxes = self.det_predictor(pages, **kwargs)
# Check whether crop mode should be switched to channels first
crop_kwargs = {}
if len(pages) > 0 and not isinstance(pages[0], np.ndarray) and is_torch_available():
crop_kwargs['channels_last'] = False
# Crop images, rotate page if necessary
if self.doc_builder.rotated_bbox:
crops = [crop for page, (_boxes, angle) in zip(pages, boxes) for crop in
self.extract_crops_fn(rotate_image(page, -angle, False), _boxes[:, :-1])] # type: ignore[operator]
crops = [
crop for page, (_boxes, angle) in zip(pages, boxes) for crop in
self.extract_crops_fn( # type: ignore[operator]
rotate_image(page, -angle, False),
_boxes[:, :-1],
**crop_kwargs
)
]
else:
crops = [crop for page, (_boxes, _) in zip(pages, boxes) for crop in
self.extract_crops_fn(page, _boxes[:, :-1])] # type: ignore[operator]
self.extract_crops_fn(page, _boxes[:, :-1], **crop_kwargs)] # type: ignore[operator]
# Avoid sending zero-sized crops
is_kept = [all(s > 0 for s in crop.shape) for crop in crops]
crops = [crop for crop, _kept in zip(crops, is_kept) if _kept]
boxes = [box for box, _kept in zip(boxes, is_kept) if _kept]
# Identify character sequences
word_preds = self.reco_predictor(crops, **kwargs)

# Rotate back boxes if necessary
boxes, angles = zip(*boxes)
if self.doc_builder.rotated_bbox:
boxes = [rotate_boxes(boxes_page, angle) for boxes_page, angle in zip(boxes, angles)]
if len(boxes) > 0:
boxes, angles = zip(*boxes)
if self.doc_builder.rotated_bbox:
boxes = [rotate_boxes(boxes_page, angle) for boxes_page, angle in zip(boxes, angles)]
out = self.doc_builder(boxes, word_preds, [page.shape[:2] for page in pages])
return out

Expand Down
29 changes: 22 additions & 7 deletions scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,40 @@

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

from doctr.models import ocr_predictor
from doctr.io import DocumentFile
from doctr.file_utils import is_tf_available

# Enable GPU growth if using TF
if is_tf_available():
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
import torch


def main(args):

model = ocr_predictor(args.detection, args.recognition, pretrained=True)

if not is_tf_available():
model.det_predictor.pre_processor = model.det_predictor.pre_processor.eval()
model.det_predictor.model = model.det_predictor.model.eval()
model.reco_predictor.pre_processor = model.reco_predictor.pre_processor.eval()
model.reco_predictor.model = model.reco_predictor.model.eval()

if args.path.endswith(".pdf"):
doc = DocumentFile.from_pdf(args.path).as_images()
else:
doc = DocumentFile.from_images(args.path)

out = model(doc, training=False)
if is_tf_available():
out = model(doc, training=False)
else:
with torch.no_grad():
out = model(doc)

for page, img in zip(out.pages, doc):
page.show(img, block=not args.noblock, interactive=not args.static)
Expand Down
51 changes: 37 additions & 14 deletions scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,41 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os
import numpy as np
from tqdm import tqdm

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
import numpy as np
from tqdm import tqdm

from doctr.utils.metrics import LocalizationConfusion, TextMatch, OCRMetric
from doctr import datasets
from doctr.models import ocr_predictor, extract_crops
from doctr.file_utils import is_tf_available

# Enable GPU growth if using TF
if is_tf_available():
import tensorflow as tf
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)
else:
import torch


def _pct(val):
return "N/A" if val is None else f"{val:.2%}"


def main(args):

predictor = ocr_predictor(args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size)

if not is_tf_available():
predictor.det_predictor.pre_processor = predictor.det_predictor.pre_processor.eval()
predictor.det_predictor.model = predictor.det_predictor.model.eval()
predictor.reco_predictor.pre_processor = predictor.reco_predictor.pre_processor.eval()
predictor.reco_predictor.model = predictor.reco_predictor.model.eval()

if args.img_folder and args.label_file:
testset = datasets.OCRDataset(
img_folder=args.img_folder,
Expand Down Expand Up @@ -60,9 +75,17 @@ def main(args):
gt_labels = target['labels']

# Forward
out = predictor(page[None, ...], training=False)
crops = extract_crops(page, gt_boxes)
reco_out = predictor.reco_predictor(crops, training=False)
if is_tf_available():
out = predictor(page[None, ...], training=False)
crops = extract_crops(page, gt_boxes)
reco_out = predictor.reco_predictor(crops, training=False)
else:
with torch.no_grad():
out = predictor(page[None, ...])
# We directly crop on PyTorch tensors, which are in channels_first
crops = extract_crops(page, gt_boxes, channels_last=False)
reco_out = predictor.reco_predictor(crops)

if len(reco_out):
reco_words, _ = zip(*reco_out)
else:
Expand Down Expand Up @@ -111,12 +134,12 @@ def main(args):
print(f"Model Evaluation (model= {args.detection} + {args.recognition}, "
f"dataset={'OCRDataset' if args.img_folder else args.dataset})")
recall, precision, mean_iou = det_metric.summary()
print(f"Text Detection - Recall: {recall:.2%}, Precision: {precision:.2%}, Mean IoU: {mean_iou:.2%}")
print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}")
acc = reco_metric.summary()
print(f"Text Recognition - Accuracy: {acc['raw']:.2%} (unicase: {acc['unicase']:.2%})")
print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})")
recall, precision, mean_iou = e2e_metric.summary()
print(f"OCR - Recall: {recall['raw']:.2%} (unicase: {recall['unicase']:.2%}), "
f"Precision: {precision['raw']:.2%} (unicase: {precision['unicase']:.2%}), Mean IoU: {mean_iou:.2%}")
print(f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), "
f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}")


def parse_args():
Expand Down