diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 848157951c..502d597fc9 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -75,6 +75,7 @@ def forward( loc_preds = self.det_predictor(pages, **kwargs) # Check whether crop mode should be switched to channels first channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) + # Crop images crops, loc_preds = self._prepare_crops( pages, loc_preds, channels_last=channels_last, assume_straight_pages=self.assume_straight_pages diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 0244a4b118..eb7d48178f 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -74,6 +74,7 @@ def __call__( # Localize text elements loc_preds = self.det_predictor(pages, **kwargs) + # Crop images crops, loc_preds = self._prepare_crops( pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages diff --git a/scripts/evaluate.py b/scripts/evaluate.py index dcf4da1e0f..08f9d15fde 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -13,7 +13,7 @@ from doctr import datasets from doctr.file_utils import is_tf_available from doctr.models import ocr_predictor -from doctr.models._utils import extract_crops +from doctr.models._utils import extract_crops, extract_rcrops from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch # Enable GPU growth if using TF @@ -32,7 +32,16 @@ def _pct(val): def main(args): - predictor = ocr_predictor(args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size) + if not args.rotation: + args.eval_straight = True + + predictor = ocr_predictor( + args.detection, + args.recognition, + pretrained=True, + reco_bs=args.batch_size, + assume_straight_pages=not args.rotation + ) if args.img_folder and args.label_file: testset = datasets.OCRDataset( @@ -41,27 +50,29 @@ def main(args): ) sets = [testset] else: - train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=args.rotation) - val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=args.rotation) + train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight) + val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight) sets = [train_set, val_set] reco_metric = TextMatch() - if args.rotation and args.mask_shape: + if args.mask_shape: det_metric = LocalizationConfusion( iou_thresh=args.iou, - use_polygons=args.rotation, + use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape) ) e2e_metric = OCRMetric( iou_thresh=args.iou, - use_polygons=args.rotation, + use_polygons=not args.eval_straight, mask_shape=(args.mask_shape, args.mask_shape) ) else: - det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=args.rotation) - e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=args.rotation) + det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight) + e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight) sample_idx = 0 + extraction_fn = extract_crops if args.eval_straight else extract_rcrops + for dataset in sets: for page, target in tqdm(dataset): # GT @@ -77,13 +88,13 @@ def main(args): # Forward if is_tf_available(): out = predictor(page[None, ...]) - crops = extract_crops(page, gt_boxes) + crops = extraction_fn(page, gt_boxes) reco_out = predictor.reco_predictor(crops) else: with torch.no_grad(): out = predictor(page[None, ...]) # We directly crop on PyTorch tensors, which are in channels_first - crops = extract_crops(page, gt_boxes, channels_last=False) + crops = extraction_fn(page, gt_boxes, channels_last=False) reco_out = predictor.reco_predictor(crops) if len(reco_out): @@ -108,19 +119,39 @@ def main(args): pred_boxes.append([int(a * width), int(b * height), int(c * width), int(d * height)]) else: - pred_boxes.append( - [ - [int(x1 * width), int(y1 * height)], - [int(x2 * width), int(y2 * height)], - [int(x3 * width), int(y3 * height)], - [int(x4 * width), int(y4 * height)], - ] - ) + if args.eval_straight: + pred_boxes.append( + [ + int(width * min(x1, x2, x3, x4)), + int(height * min(y1, y2, y3, y4)), + int(width * max(x1, x2, x3, x4)), + int(height * max(y1, y2, y3, y4)), + ] + ) + else: + pred_boxes.append( + [ + [int(x1 * width), int(y1 * height)], + [int(x2 * width), int(y2 * height)], + [int(x3 * width), int(y3 * height)], + [int(x4 * width), int(y4 * height)], + ] + ) else: if not args.rotation: pred_boxes.append([a, b, c, d]) else: - pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) + if args.eval_straight: + pred_boxes.append( + [ + min(x1, x2, x3, x4), + min(y1, y2, y3, y4), + max(x1, x2, x3, x4), + max(y1, y2, y3, y4), + ] + ) + else: + pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) pred_labels.append(word.value) # Update the metric @@ -158,10 +189,12 @@ def parse_args(): parser.add_argument('--dataset', type=str, default='FUNSD', help='choose a dataset: FUNSD, CORD') parser.add_argument('--img_folder', type=str, default=None, help='Only for local sets, path to images') parser.add_argument('--label_file', type=str, default=None, help='Only for local sets, path to labels') - parser.add_argument('--rotation', dest='rotation', action='store_true', help='evaluate with rotated bbox') + parser.add_argument('--rotation', dest='rotation', action='store_true', help='run rotated OCR + postprocessing') parser.add_argument('-b', '--batch_size', type=int, default=32, help='batch size for recognition') parser.add_argument('--mask_shape', type=int, default=None, help='mask shape for mask iou (only for rotation)') parser.add_argument('--samples', type=int, default=None, help='evaluate only on the N first samples') + parser.add_argument('--eval-straight', action='store_true', + help='evaluate on straight pages with straight bbox (to use the quick and light metric)') args = parser.parse_args() return args