feat: add straight-eval arg in evaluate script #793

Merged
merged 5 commits into from Jan 19, 2022
Changes from 3 commits
2 changes: 1 addition & 1 deletion doctr/models/predictor/pytorch.py
@@ -75,7 +75,7 @@ def forward(
loc_preds = self.det_predictor(pages, **kwargs)
# Check whether crop mode should be switched to channels first
channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
# Crop images

crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=channels_last, assume_straight_pages=self.assume_straight_pages
)
4 changes: 4 additions & 0 deletions doctr/models/predictor/tensorflow.py
@@ -74,7 +74,11 @@ def __call__(

# Localize text elements
loc_preds = self.det_predictor(pages, **kwargs)

# Crop images
if not isinstance(pages[0], np.ndarray):
pages = [page.numpy() for page in pages]

crops, loc_preds = self._prepare_crops(
pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
)
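For context, a minimal standalone sketch of the tensor-to-array normalisation introduced above (the dummy page shape and variable values are illustrative, not from the PR); converting up front means the subsequent crop extraction always receives plain NumPy arrays, whichever framework produced the pages:

import numpy as np
import tensorflow as tf

# Pages may arrive as TF tensors (dummy shape for illustration only)
pages = [tf.zeros((512, 512, 3), dtype=tf.uint8)]

# Same normalisation as in the predictor change above: fall back to NumPy before cropping
if not isinstance(pages[0], np.ndarray):
    pages = [page.numpy() for page in pages]

assert all(isinstance(page, np.ndarray) for page in pages)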
74 changes: 53 additions & 21 deletions scripts/evaluate.py
@@ -13,7 +13,7 @@
from doctr import datasets
from doctr.file_utils import is_tf_available
from doctr.models import ocr_predictor
from doctr.models._utils import extract_crops
from doctr.models._utils import extract_crops, extract_rcrops
from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch

# Enable GPU growth if using TF
@@ -32,7 +32,16 @@ def _pct(val):

def main(args):

predictor = ocr_predictor(args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size)
if not args.rotation:
args.eval_straight = True

predictor = ocr_predictor(
args.detection,
args.recognition,
pretrained=True,
reco_bs=args.batch_size,
assume_straight_pages=not args.rotation
)

if args.img_folder and args.label_file:
testset = datasets.OCRDataset(
@@ -41,25 +50,25 @@
)
sets = [testset]
else:
train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=args.rotation)
val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=args.rotation)
train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=not args.eval_straight)
val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=not args.eval_straight)
sets = [train_set, val_set]

reco_metric = TextMatch()
if args.rotation and args.mask_shape:
if args.mask_shape:
det_metric = LocalizationConfusion(
iou_thresh=args.iou,
use_polygons=args.rotation,
use_polygons=not args.eval_straight,
mask_shape=(args.mask_shape, args.mask_shape)
)
e2e_metric = OCRMetric(
iou_thresh=args.iou,
use_polygons=args.rotation,
use_polygons=not args.eval_straight,
mask_shape=(args.mask_shape, args.mask_shape)
)
else:
det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=args.rotation)
e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=args.rotation)
det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=not args.eval_straight)
e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=not args.eval_straight)

sample_idx = 0
for dataset in sets:
@@ -75,15 +84,16 @@
gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1)

# Forward
extraction_fn = extract_crops if args.eval_straight else extract_rcrops
if is_tf_available():
out = predictor(page[None, ...])
crops = extract_crops(page, gt_boxes)
crops = extraction_fn(page, gt_boxes)
reco_out = predictor.reco_predictor(crops)
else:
with torch.no_grad():
out = predictor(page[None, ...])
# We directly crop on PyTorch tensors, which are in channels_first
crops = extract_crops(page, gt_boxes, channels_last=False)
crops = extraction_fn(page, gt_boxes, channels_last=False)
reco_out = predictor.reco_predictor(crops)

if len(reco_out):
@@ -108,19 +118,39 @@
pred_boxes.append([int(a * width), int(b * height),
int(c * width), int(d * height)])
else:
pred_boxes.append(
[
[int(x1 * width), int(y1 * height)],
[int(x2 * width), int(y2 * height)],
[int(x3 * width), int(y3 * height)],
[int(x4 * width), int(y4 * height)],
]
)
if args.eval_straight:
pred_boxes.append(
[
int(width * min(x1, x2, x3, x4)),
int(height * min(y1, y2, y3, y4)),
int(width * max(x1, x2, x3, x4)),
int(height * max(y1, y2, y3, y4)),
]
)
else:
pred_boxes.append(
[
[int(x1 * width), int(y1 * height)],
[int(x2 * width), int(y2 * height)],
[int(x3 * width), int(y3 * height)],
[int(x4 * width), int(y4 * height)],
]
)
else:
if not args.rotation:
pred_boxes.append([a, b, c, d])
else:
pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
if args.eval_straight:
pred_boxes.append(
[
min(x1, x2, x3, x4),
min(y1, y2, y3, y4),
max(x1, x2, x3, x4),
max(y1, y2, y3, y4),
]
)
else:
pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
pred_labels.append(word.value)

# Update the metric
@@ -158,10 +188,12 @@ def parse_args():
parser.add_argument('--dataset', type=str, default='FUNSD', help='choose a dataset: FUNSD, CORD')
parser.add_argument('--img_folder', type=str, default=None, help='Only for local sets, path to images')
parser.add_argument('--label_file', type=str, default=None, help='Only for local sets, path to labels')
parser.add_argument('--rotation', dest='rotation', action='store_true', help='evaluate with rotated bbox')
parser.add_argument('--rotation', dest='rotation', action='store_true', help='run rotated OCR + postprocessing')
parser.add_argument('-b', '--batch_size', type=int, default=32, help='batch size for recognition')
parser.add_argument('--mask_shape', type=int, default=None, help='mask shape for mask iou (only for rotation)')
parser.add_argument('--samples', type=int, default=None, help='evaluate only on the N first samples')
parser.add_argument('--eval-straight', action='store_true',
help='evaluate on straight pages with straight bbox (to use the quick and light metric)')
args = parser.parse_args()

return args
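For reference, the straight-box fallback added to evaluate.py reduces each predicted 4-point polygon to its axis-aligned bounding box before scoring. A minimal standalone sketch of that reduction (the helper name and sample coordinates are ours, not part of the PR):

def polygon_to_straight_box(polygon, width, height):
    # polygon: four (x, y) points with relative coordinates in [0, 1]
    xs = [x for x, _ in polygon]
    ys = [y for _, y in polygon]
    return [
        int(width * min(xs)),
        int(height * min(ys)),
        int(width * max(xs)),
        int(height * max(ys)),
    ]

# Example: a slightly rotated word box on a 1000 x 800 page
print(polygon_to_straight_box([(0.10, 0.20), (0.30, 0.22), (0.29, 0.30), (0.09, 0.28)], 1000, 800))
# [90, 160, 300, 240]

Note that the parsing logic above forces eval_straight to True whenever --rotation is omitted, so the straight (quick and light) metric path is the default; passing --rotation together with --eval-straight runs the rotated predictor but still scores it against straight boxes.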