Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reopening #597 SVT dataset integration #620

Merged
merged 5 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Here are all datasets that are available through docTR:
.. autoclass:: CharacterGenerator
.. autoclass:: DocArtefacts
.. autoclass:: IIIT5K
.. autoclass:: SVT


Data Loading
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .ocr import *
from .recognition import *
from .sroie import *
from .svt import *
from .utils import *
from .vocabs import *

Expand Down
91 changes: 91 additions & 0 deletions doctr/datasets/svt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import defusedxml.ElementTree as ET
import numpy as np

from .datasets import VisionDataset

__all__ = ['SVT']


class SVT(VisionDataset):
    """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision"
    <http://vision.ucsd.edu/~kai/svt/>`_.

    Example::
        >>> from doctr.datasets import SVT
        >>> train_set = SVT(train=True, download=True)
        >>> img, target = train_set[0]

    Args:
        train: whether the subset should be the training one
        sample_transforms: composable transformations that will be applied to each image
        rotated_bbox: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `VisionDataset`.
    """

    URL = 'http://vision.ucsd.edu/~kai/svt/svt.zip'
    SHA256 = '63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf'

    def __init__(
        self,
        train: bool = True,
        sample_transforms: Optional[Callable[[Any], Any]] = None,
        rotated_bbox: bool = False,
        **kwargs: Any,
    ) -> None:

        # Base class handles download + checksum verification of the archive
        super().__init__(self.URL, None, self.SHA256, True, **kwargs)
        self.sample_transforms = sample_transforms
        self.train = train
        self.data: List[Tuple[str, Dict[str, Any]]] = []
        np_dtype = np.float16 if self.fp16 else np.float32

        # The archive extracts into an 'svt1' folder holding the images and
        # one XML annotation file per split
        tmp_root = os.path.join(tmp_root_base := self.root, 'svt1')
        xml_path = os.path.join(tmp_root, 'train.xml' if self.train else 'test.xml')
        xml_root = ET.parse(xml_path).getroot()

        for image in xml_root:
            # Each <image> node carries: name, address, lexicon, resolution, rectangles
            name, _, _, resolution, rectangles = image

            # Fail fast when the referenced image file is absent from the archive
            img_path = os.path.join(tmp_root, name.text)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            raw_boxes = []
            for rect in rectangles:
                left, top = float(rect.attrib['x']), float(rect.attrib['y'])
                box_w, box_h = float(rect.attrib['width']), float(rect.attrib['height'])
                if rotated_bbox:
                    # (center_x, center_y, width, height)
                    raw_boxes.append([left + box_w / 2, top + box_h / 2, box_w, box_h])
                else:
                    # (xmin, ymin, xmax, ymax)
                    raw_boxes.append([left, top, left + box_w, top + box_h])

            # Normalize coordinates relative to the image resolution
            img_w, img_h = int(resolution.attrib['x']), int(resolution.attrib['y'])
            boxes = np.asarray(raw_boxes, dtype=np_dtype)
            boxes[:, [0, 2]] /= img_w
            boxes[:, [1, 3]] /= img_h

            # One text label per rectangle (the rectangle's single <tag> child)
            labels = [child.text for rect in rectangles for child in rect]

            self.data.append((name.text, dict(boxes=boxes, labels=labels)))

        self.root = tmp_root

    def extra_repr(self) -> str:
        return f"train={self.train}"
2 changes: 2 additions & 0 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def test_visiondataset():
['DocArtefacts', False, [512, 512], 300, True],
['IIIT5K', True, [32, 128], 2000, True],
['IIIT5K', False, [32, 128], 3000, False],
['SVT', True, [512, 512], 100, True],
['SVT', False, [512, 512], 249, False],
],
)
def test_dataset(dataset_name, train, input_size, size, rotate):
Expand Down
2 changes: 2 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
['DocArtefacts', False, [512, 512], 300, True],
['IIIT5K', True, [32, 128], 2000, True],
['IIIT5K', False, [32, 128], 3000, False],
['SVT', True, [512, 512], 100, True],
['SVT', False, [512, 512], 249, False],
],
)
def test_dataset(dataset_name, train, input_size, size, rotate):
Expand Down