feat: Added SVHN dataset (#634)
* start synth

* cleanup

* start synth

* add synthtext

* add docu and tests

* apply code factor suggestions

* apply changes

* clean

* svhn start

* apply some changes and add h5py

* solve h5py keras conflict

* test dep change

* apply changes

* apply changes

* change desc synth and svhn to Reading

* apply changes

* fix style

* start svhn mock

* fix conftest and setup

* fix mock dataset path

* fmt fix

* apply changes

* apply changes

* apply changes
felixdittrich92 committed Dec 14, 2021
1 parent 618b0fa commit c32c1ed
Showing 12 changed files with 215 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/datasets.rst
@@ -21,6 +21,7 @@ Here are all datasets that are available through docTR:
.. autoclass:: DocArtefacts
.. autoclass:: IIIT5K
.. autoclass:: SVT
.. autoclass:: SVHN
.. autoclass:: SynthText
.. autoclass:: IC03
.. autoclass:: IC13
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
@@ -11,6 +11,7 @@
from .ocr import *
from .recognition import *
from .sroie import *
from .svhn import *
from .svt import *
from .synthtext import *
from .utils import *
108 changes: 108 additions & 0 deletions doctr/datasets/svhn.py
@@ -0,0 +1,108 @@
# Copyright (C) 2021, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import h5py
import numpy as np
from tqdm import tqdm

from .datasets import VisionDataset

__all__ = ['SVHN']


class SVHN(VisionDataset):
"""SVHN dataset from `"The Street View House Numbers (SVHN) Dataset"
<http://ufldl.stanford.edu/housenumbers/>`_.
Example::
>>> from doctr.datasets import SVHN
>>> train_set = SVHN(train=True, download=True)
>>> img, target = train_set[0]
Args:
train: whether the subset should be the training one
sample_transforms: composable transformations that will be applied to each image
rotated_bbox: whether polygons should be considered as rotated bounding boxes (instead of straight ones)
**kwargs: keyword arguments from `VisionDataset`.
"""
TRAIN = ('http://ufldl.stanford.edu/housenumbers/train.tar.gz',
'4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898',
'svhn_train.tar')

TEST = ('http://ufldl.stanford.edu/housenumbers/test.tar.gz',
'57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5',
'svhn_test.tar')

def __init__(
self,
train: bool = True,
sample_transforms: Optional[Callable[[Any], Any]] = None,
rotated_bbox: bool = False,
**kwargs: Any,
) -> None:

url, sha256, name = self.TRAIN if train else self.TEST
super().__init__(url=url, file_name=name, file_hash=sha256, extract_archive=True, **kwargs)
self.sample_transforms = sample_transforms
self.train = train
self.data: List[Tuple[str, Dict[str, Any]]] = []
np_dtype = np.float32

tmp_root = os.path.join(self.root, 'train' if train else 'test')

# Load mat data (MATLAB v7.3 format, which scipy.io.loadmat cannot read)
with h5py.File(os.path.join(tmp_root, 'digitStruct.mat'), 'r') as f:
img_refs = f['digitStruct/name']
box_refs = f['digitStruct/bbox']
for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc='Unpacking SVHN', total=len(img_refs)):
# convert ascii matrix to string
img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))

# File existence check
if not os.path.exists(os.path.join(tmp_root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")

# Unpack the information
box = f[box_ref[0]]
if box['left'].shape[0] == 1:
box_dict = {k: [int(vals[0][0])] for k, vals in box.items()}
else:
box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()}

# Convert it to the right format
coords = np.array([
box_dict['left'],
box_dict['top'],
box_dict['width'],
box_dict['height']
], dtype=np_dtype).transpose()
label_targets = list(map(str, box_dict['label']))

if rotated_bbox:
# x_center, y_center, w, h, alpha = 0
box_targets = np.stack([
coords[:, 0] + coords[:, 2] / 2,
coords[:, 1] + coords[:, 3] / 2,
coords[:, 2],
coords[:, 3],
np.zeros(coords.shape[0], dtype=np_dtype),
], axis=-1)
else:
# x, y, width, height -> xmin, ymin, xmax, ymax
box_targets = np.stack([
coords[:, 0],
coords[:, 1],
coords[:, 0] + coords[:, 2],
coords[:, 1] + coords[:, 3],
], axis=-1)
self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))

self.root = tmp_root

def extra_repr(self) -> str:
return f"train={self.train}"
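
The trickiest part of the loader above is the annotation format: digitStruct.mat is a MATLAB v7.3 file, i.e. an HDF5 container whose name and bbox entries are object references that must be dereferenced through the open file handle. Below is a minimal standalone sketch of that access pattern, mirroring the loop in svhn.py (the xywh-to-corner conversion at the end matches the straight-box branch; the helper name is illustrative):

import h5py
import numpy as np

def iter_svhn_annotations(mat_path: str):
    """Yield (image_name, boxes) from a MATLAB v7.3 digitStruct.mat via h5py."""
    with h5py.File(mat_path, 'r') as f:
        for name_ref, bbox_ref in zip(f['digitStruct/name'], f['digitStruct/bbox']):
            # each entry is an HDF5 object reference, dereferenced through the file handle
            img_name = "".join(map(chr, f[name_ref[0]][()].flatten()))
            box = f[bbox_ref[0]]
            if box['left'].shape[0] == 1:
                # single digit: the values are stored inline
                box_dict = {k: [int(v[0][0])] for k, v in box.items()}
            else:
                # several digits: each value is itself a reference
                box_dict = {k: [int(f[r[0]][()].item()) for r in v] for k, v in box.items()}
            # (left, top, width, height) -> (xmin, ymin, xmax, ymax)
            coords = np.array([box_dict['left'], box_dict['top'],
                               box_dict['width'], box_dict['height']], dtype=np.float32).T
            yield img_name, np.stack([coords[:, 0], coords[:, 1],
                                      coords[:, 0] + coords[:, 2],
                                      coords[:, 1] + coords[:, 3]], axis=-1)
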
2 changes: 1 addition & 1 deletion doctr/datasets/synthtext.py
@@ -59,7 +59,7 @@ def __init__(
np_dtype = np.float32

for img_path, word_boxes, txt in tqdm(iterable=zip(paths, boxes, labels),
-                                     desc='Loading SynthText...', total=len(paths)):
+                                     desc='Unpacking SynthText', total=len(paths)):

# File existence check
if not os.path.exists(os.path.join(tmp_root, img_path[0])):
4 changes: 4 additions & 0 deletions mypy.ini
@@ -71,3 +71,7 @@ ignore_missing_imports = True
[mypy-defusedxml.*]

ignore_missing_imports = True

[mypy-h5py.*]

ignore_missing_imports = True
1 change: 1 addition & 0 deletions requirements-pt.txt
@@ -1,5 +1,6 @@
numpy>=1.16.0
scipy>=1.4.0
h5py>=3.1.0
opencv-python>=3.4.5.20
PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12
pyclipper>=1.2.0
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
numpy>=1.16.0
scipy>=1.4.0
h5py>=3.1.0
opencv-python>=3.4.5.20
PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12
pyclipper>=1.2.0
4 changes: 4 additions & 0 deletions setup.py
@@ -42,6 +42,7 @@
"importlib_metadata",
"numpy>=1.16.0",
"scipy>=1.4.0",
"h5py>=3.1.0",
"opencv-python>=3.4.5.20",
"tensorflow>=2.4.0",
"PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12", # 18.11 and 18.12 fail (issue #222)
@@ -62,6 +63,7 @@
# Testing
"pytest>=5.3.2",
"coverage>=4.5.4",
"hdf5storage>=0.1.18",
"requests>=2.20.0",
"requirements-parser==0.2.0",
# Quality
@@ -90,6 +92,7 @@ def deps_list(*pkgs):
deps["importlib_metadata"] + ";python_version<'3.8'", # importlib_metadata for Python versions that don't have it
deps["numpy"],
deps["scipy"],
deps["h5py"],
deps["opencv-python"],
deps["PyMuPDF"],
deps["pyclipper"],
@@ -130,6 +133,7 @@ def deps_list(*pkgs):
"pytest",
"coverage",
"requests",
"hdf5storage",
"requirements-parser",
)

30 changes: 30 additions & 0 deletions tests/conftest.py
@@ -1,6 +1,9 @@
import json
import shutil
from io import BytesIO

import hdf5storage
import numpy as np
import pytest
import requests

@@ -144,3 +147,30 @@ def mock_ic13(tmpdir_factory, mock_image_stream):
with open(fn_i, 'wb') as f:
f.write(file.getbuffer())
return str(image_folder), str(label_folder)


@pytest.fixture(scope="session")
def mock_svhn_dataset(tmpdir_factory, mock_image_stream):
root = tmpdir_factory.mktemp('datasets')
svhn_root = root.mkdir('svhn')
file = BytesIO(mock_image_stream)
# image file names stored as ascii character codes, as in the real digitStruct.mat
first = np.array([[49], [46], [112], [110], [103]], dtype=np.int16) # 1.png
second = np.array([[50], [46], [112], [110], [103]], dtype=np.int16) # 2.png
third = np.array([[51], [46], [112], [110], [103]], dtype=np.int16) # 3.png
# shared annotation: geometry and digit labels for four boxes, reused for every mock image
label = {'height': [35, 35, 35, 35], 'label': [1, 1, 3, 7],
'left': [116, 128, 137, 151], 'top': [27, 29, 29, 26],
'width': [15, 10, 17, 17]}

matcontent = {'digitStruct': {'name': [first, second, third], 'bbox': [label, label, label]}}
# Mock train data
train_root = svhn_root.mkdir('train')
hdf5storage.write(matcontent, filename=train_root.join('digitStruct.mat'))
for i in range(3):
fn = train_root.join(f'{i+1}.png')
with open(fn, 'wb') as f:
f.write(file.getbuffer())
# Packing the data into an archive to simulate the real dataset and bypass archive extraction
shutil.make_archive(svhn_root.join('svhn_train'), 'tar', str(svhn_root))
return str(root)
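
A quick sanity check on this fixture is to read the hdf5storage-written file back through the same h5py access pattern the SVHN loader uses. The sketch below (the helper name is illustrative) assumes hdf5storage lays out the nested dict with the reference structure of a genuine v7.3 file, which is what the tests relying on this fixture depend on:

import h5py

def check_mock_digitstruct(mat_path: str):
    # open with h5py exactly as doctr's SVHN loader does
    with h5py.File(mat_path, 'r') as f:
        names, bboxes = f['digitStruct/name'], f['digitStruct/bbox']
        assert len(names) == len(bboxes) == 3
        # dereference the first name entry and decode its ascii codes -> '1.png'
        first = "".join(map(chr, f[names[0][0]][()].flatten()))
        assert first == '1.png'
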
33 changes: 33 additions & 0 deletions tests/pytorch/test_datasets_pt.py
@@ -232,3 +232,36 @@ def test_ic13_dataset(mock_ic13, size, rotate):
images, targets = next(iter(loader))
assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size)
assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)


@pytest.mark.parametrize(
"input_size, size, rotate",
[
[[32, 128], 3, True], # Actual set has 33402 training samples and 13068 test samples
[[32, 128], 3, False],
],
)
def test_svhn(input_size, size, rotate, mock_svhn_dataset):
# monkeypatch the path to temporary dataset
datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar")

ds = datasets.SVHN(
train=True, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
cache_dir=mock_svhn_dataset, cache_subdir="svhn",
)

assert len(ds) == size
assert repr(ds) == f"SVHN(train={True})"
img, target = ds[0]
assert isinstance(img, torch.Tensor)
assert img.shape == (3, *input_size)
assert img.dtype == torch.float32
assert isinstance(target, dict)

loader = DataLoader(
ds, batch_size=2, drop_last=True, sampler=RandomSampler(ds), num_workers=0, pin_memory=True,
collate_fn=ds.collate_fn)

images, targets = next(iter(loader))
assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size)
assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)
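
The monkeypatch above works because TRAIN is read in __init__ as a plain (url, sha256, archive_name) tuple: pointing the first entry at a local directory and passing None as the hash lets the dataset pick up the pre-built svhn_train.tar from the cache directory instead of downloading. A minimal sketch of the same override outside pytest (paths are illustrative, and the expected length matches the three-image mock from conftest.py):

from doctr import datasets

# (url_or_local_source, sha256_or_None, archive_name) -- layout from svhn.py above
datasets.SVHN.TRAIN = ('/tmp/datasets/svhn', None, 'svhn_train.tar')
ds = datasets.SVHN(train=True, download=True,
                   cache_dir='/tmp/datasets/svhn', cache_subdir='svhn')
print(len(ds))  # 3 with the mock archive
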
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -1,4 +1,5 @@
pytest>=5.3.2
requests>=2.20.0
hdf5storage>=0.1.18
coverage>=4.5.4
requirements-parser==0.2.0
30 changes: 30 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
@@ -218,3 +218,33 @@ def test_ic13_dataset(mock_ic13, size, rotate):
images, targets = next(iter(loader))
assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3)
assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)


@pytest.mark.parametrize(
"input_size, size, rotate",
[
[[32, 128], 3, True], # Actual set has 33402 training samples and 13068 test samples
[[32, 128], 3, False],
],
)
def test_svhn(input_size, size, rotate, mock_svhn_dataset):
# monkeypatch the path to temporary dataset
datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar")

ds = datasets.SVHN(
train=True, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
cache_dir=mock_svhn_dataset, cache_subdir="svhn",
)

assert len(ds) == size
assert repr(ds) == f"SVHN(train={True})"
img, target = ds[0]
assert isinstance(img, tf.Tensor)
assert img.shape == (*input_size, 3)
assert img.dtype == tf.float32
assert isinstance(target, dict)

loader = datasets.DataLoader(ds, batch_size=2)
images, targets = next(iter(loader))
assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3)
assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets)
