Skip to content

Commit

Permalink
reopening #597
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 13, 2021
1 parent 400aec0 commit e719034
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Here are all datasets that are available through docTR:
.. autoclass:: CharacterGenerator
.. autoclass:: DocArtefacts
.. autoclass:: IIIT5K
.. autoclass:: SVT


Data Loading
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .ocr import *
from .recognition import *
from .sroie import *
from .svt import *
from .utils import *
from .vocabs import *

Expand Down
90 changes: 90 additions & 0 deletions doctr/datasets/svt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import defusedxml.ElementTree as ET
import numpy as np

from .datasets import VisionDataset

__all__ = ['SVT']


class SVT(VisionDataset):
    """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision"
    <http://vision.ucsd.edu/~kai/svt/>`_.

    Example::
        >>> from doctr.datasets import SVT
        >>> train_set = SVT(train=True, download=True)
        >>> img, target = train_set[0]

    Args:
        train: whether the subset should be the training one
        sample_transforms: composable transformations that will be applied to each image
        rotated_bbox: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `VisionDataset`.

    Raises:
        ValueError: if an annotated image entry carries no ``imageName``.
        FileNotFoundError: if an image referenced by the annotation file is missing on disk.
    """

    URL = 'http://vision.ucsd.edu/~kai/svt/svt.zip'
    SHA256 = '63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf'

    def __init__(
        self,
        train: bool = True,
        sample_transforms: Optional[Callable[[Any], Any]] = None,
        rotated_bbox: bool = False,
        **kwargs: Any,
    ) -> None:

        # NOTE(review): the 4th positional argument (True) is forwarded as-is to
        # VisionDataset; presumably it requests archive extraction -- confirm there.
        super().__init__(self.URL, None, self.SHA256, True, **kwargs)
        self.sample_transforms = sample_transforms
        self.train = train
        # One (relative image path, target dict) entry per annotated image
        self.data: List[Tuple[Path, Dict[str, Any]]] = []
        np_dtype = np.float16 if self.fp16 else np.float32

        # Load xml data: the annotation file depends on the requested subset
        tmp_root = os.path.join(self.root, 'svt1')
        xml_file = 'train.xml' if self.train else 'test.xml'
        xml_root = ET.parse(os.path.join(tmp_root, xml_file)).getroot()

        for image in xml_root:
            # Resolve the image name once per image, independently of child order
            name_tag = image.find('imageName')
            if name_tag is None or not name_tag.text:
                raise ValueError(f"missing imageName entry in {os.path.join(tmp_root, xml_file)}")
            _raw_path = name_tag.text

            # File existence check (once per image)
            img_path = os.path.join(tmp_root, _raw_path)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            # Rectangles are the grand-children of the image element. Boxes and labels
            # are both derived from this single list so they always stay aligned.
            rect_tags = [rect_tag for image_attributes in image for rect_tag in image_attributes]

            _tmp_box_targets: List[Tuple[float, ...]]
            if rotated_bbox:
                # x_center, y_center, w, h, 0
                _tmp_box_targets = [
                    (int(rect_tag.attrib['x']) + int(rect_tag.attrib['width']) / 2,
                     int(rect_tag.attrib['y']) + int(rect_tag.attrib['height']) / 2,
                     int(rect_tag.attrib['width']), int(rect_tag.attrib['height']), 0)
                    for rect_tag in rect_tags
                ]
            else:
                # xmin, ymin, xmax, ymax
                _tmp_box_targets = [
                    (int(rect_tag.attrib['x']), int(rect_tag.attrib['y']),
                     int(rect_tag.attrib['x']) + int(rect_tag.attrib['width']),
                     int(rect_tag.attrib['y']) + int(rect_tag.attrib['height']))
                    for rect_tag in rect_tags
                ]

            # Every child of a rectangle contributes its text as a label
            _tmp_labels = [lab.text for rect_tag in rect_tags for lab in rect_tag]

            self.data.append((
                Path(_raw_path),
                dict(boxes=np.asarray(_tmp_box_targets, dtype=np_dtype), labels=_tmp_labels),
            ))

        # Stored image paths are relative to the extracted 'svt1' folder
        self.root = tmp_root

    def extra_repr(self) -> str:
        """Return the subset flag shown in the dataset's string representation."""
        return f"train={self.train}"
2 changes: 2 additions & 0 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def test_visiondataset():
['DocArtefacts', False, [512, 512], 300, True],
['IIIT5K', True, [32, 128], 2000, True],
['IIIT5K', False, [32, 128], 3000, False],
['SVT', True, [512, 512], 100, True],
['SVT', False, [512, 512], 249, False],
],
)
def test_dataset(dataset_name, train, input_size, size, rotate):
Expand Down
2 changes: 2 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
['DocArtefacts', False, [512, 512], 300, True],
['IIIT5K', True, [32, 128], 2000, True],
['IIIT5K', False, [32, 128], 3000, False],
['SVT', True, [512, 512], 100, True],
['SVT', False, [512, 512], 249, False],
],
)
def test_dataset(dataset_name, train, input_size, size, rotate):
Expand Down

0 comments on commit e719034

Please sign in to comment.