Skip to content

Commit

Permalink
feat: Adds support for Imgur5k dataset (#785)
Browse files Browse the repository at this point in the history
* start synth

* cleanup

* start synth

* add synthtext

* add docu and tests

* apply code factor suggestions

* apply changes

* clean

* start imgur5k

* up

* update box computation

* make flake happy

* filter images without boxes

* aqpply changes

* change desc
  • Loading branch information
felixdittrich92 committed Jan 10, 2022
1 parent ea1c351 commit 0da7ce0
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Public datasets
.. autoclass:: SynthText
.. autoclass:: IC03
.. autoclass:: IC13
.. autoclass:: IMGUR5K

docTR synthetic datasets
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions doctr/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .ic03 import *
from .ic13 import *
from .iiit5k import *
from .imgur5k import *
from .ocr import *
from .recognition import *
from .sroie import *
Expand Down
99 changes: 99 additions & 0 deletions doctr/datasets/imgur5k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (C) 2021-2022, Mindee.

# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np

from .datasets import AbstractDataset
from .utils import convert_target_to_relative

__all__ = ["IMGUR5K"]


class IMGUR5K(AbstractDataset):
"""IMGUR5K dataset from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example"
<https://arxiv.org/abs/2106.08385>`_ |
`"repository" <https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset>`_.
Example::
>>> # NOTE: You need to download/generate the dataset from the repository.
>>> from doctr.datasets import IMGUR5K
>>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
>>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
>>> img, target = train_set[0]
>>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images",
>>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json")
>>> img, target = test_set[0]
Args:
img_folder: folder with all the images of the dataset
label_path: path to the annotations file of the dataset
train: whether the subset should be the training one
use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
**kwargs: keyword arguments from `AbstractDataset`.
"""

def __init__(
self,
img_folder: str,
label_path: str,
train: bool = True,
use_polygons: bool = False,
**kwargs: Any,
) -> None:
super().__init__(img_folder, pre_transforms=convert_target_to_relative, **kwargs)

# File existence check
if not os.path.exists(label_path) or not os.path.exists(img_folder):
raise FileNotFoundError(
f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}")

self.data: List[Tuple[Path, Dict[str, Any]]] = []
self.train = train
np_dtype = np.float32

img_names = os.listdir(img_folder)
train_samples = int(len(img_names) * 0.9)
set_slice = slice(train_samples) if self.train else slice(train_samples, None)

with open(label_path) as f:
annotation_file = json.load(f)

for img_name in img_names[set_slice]:
img_path = Path(img_folder, img_name)
img_id = img_name.split(".")[0]

# File existence check
if not os.path.exists(os.path.join(self.root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}")

# some files have no annotations which are marked with only a dot in the 'word' key
# ref: https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset/blob/main/README.md
if img_id not in annotation_file['index_to_ann_map'].keys():
continue
ann_ids = annotation_file['index_to_ann_map'][img_id]
annotations = [annotation_file['ann_id'][a_id] for a_id in ann_ids]

labels = [ann['word'] for ann in annotations if ann['word'] != '.']
# x_center, y_center, width, height, angle
_boxes = [list(map(float, ann['bounding_box'].strip('[ ]').split(', ')))
for ann in annotations if ann['word'] != '.']
# (x, y) coordinates of top left, top right, bottom right, bottom left corners
box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes]

if not use_polygons:
# xmin, ymin, xmax, ymax
box_targets = [np.concatenate((points.min(0), points.max(0)), axis=-1) for points in box_targets]

# filter images without boxes
if len(box_targets) > 0:
self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels)))

def extra_repr(self) -> str:
return f"train={self.train}"
71 changes: 71 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,77 @@ def mock_ic13(tmpdir_factory, mock_image_stream):
return str(image_folder), str(label_folder)


@pytest.fixture(scope="session")
def mock_imgur5k(tmpdir_factory, mock_image_stream):
file = BytesIO(mock_image_stream)
image_folder = tmpdir_factory.mktemp("images")
label_folder = tmpdir_factory.mktemp("dataset_info")
labels = {
"index_id": {
"YsaVkzl": {
"image_url": "https://i.imgur.com/YsaVkzl.jpg",
"image_path": "/path/to/IMGUR5K-Handwriting-Dataset/images/YsaVkzl.jpg",
"image_hash": "993a7cbb04a7c854d1d841b065948369"
},
"wz3wHhN": {
"image_url": "https://i.imgur.com/wz3wHhN.jpg",
"image_path": "/path/to/IMGUR5K-Handwriting-Dataset/images/wz3wHhN.jpg",
"image_hash": "9157426a98ee52f3e1e8d41fa3a99175"
},
"BRHSP23": {
"image_url": "https://i.imgur.com/BRHSP23.jpg",
"image_path": "/path/to/IMGUR5K-Handwriting-Dataset/images/BRHSP23.jpg",
"image_hash": "aab01f7ac82ae53845b01674e9e34167"
}
},
"index_to_ann_map": {
"YsaVkzl": [
"YsaVkzl_0",
"YsaVkzl_1",
"YsaVkzl_2"],
"wz3wHhN": [
"wz3wHhN_0",
"wz3wHhN_1"],
"BRHSP23": [
"BRHSP23_0"]
},
"ann_id": {
"YsaVkzl_0": {
"word": "I",
"bounding_box": "[605.33, 1150.67, 614.33, 226.33, 81.0]"
},
"YsaVkzl_1": {
"word": "am",
"bounding_box": "[783.67, 654.67, 521.0, 222.33, 56.67]"
},
"YsaVkzl_2": {
"word": "a",
"bounding_box": "[959.0, 437.0, 76.67, 201.0, 38.33]"
},
"wz3wHhN_0": {
"word": "jedi",
"bounding_box": "[783.67, 654.67, 521.0, 222.33, 56.67]"
},
"wz3wHhN_1": {
"word": "!",
"bounding_box": "[959.0, 437.0, 76.67, 201.0, 38.33]"
},
"BRHSP23_0": {
"word": "jedi",
"bounding_box": "[783.67, 654.67, 521.0, 222.33, 56.67]"
}
}
}
label_file = label_folder.join("imgur5k_annotations.json")
with open(label_file, 'w') as f:
json.dump(labels, f)
for index_id in ['YsaVkzl', 'wz3wHhN', 'BRHSP23']:
fn_i = image_folder.join(f"{index_id}.jpg")
with open(fn_i, 'wb') as f:
f.write(file.getbuffer())
return str(image_folder), str(label_file)


@pytest.fixture(scope="session")
def mock_svhn_dataset(tmpdir_factory, mock_image_stream):
root = tmpdir_factory.mktemp('datasets')
Expand Down
22 changes: 22 additions & 0 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,28 @@ def test_ic13_dataset(num_samples, rotate, mock_ic13):
_validate_dataset(ds, input_size, is_polygons=rotate)


@pytest.mark.parametrize(
"num_samples, rotate",
[
[3, True], # Actual set has 7149 train and 796 test samples
[3, False]
],
)
def test_imgur5k_dataset(num_samples, rotate, mock_imgur5k):
input_size = (512, 512)
ds = datasets.IMGUR5K(
*mock_imgur5k,
train=True,
img_transforms=Resize(input_size),
use_polygons=rotate,
)

assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split
assert repr(ds) == f"IMGUR5K(train={True})"
_validate_dataset(ds, input_size, is_polygons=rotate)


@pytest.mark.parametrize(
"input_size, num_samples, rotate",
[
Expand Down
22 changes: 22 additions & 0 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,28 @@ def test_ic13_dataset(mock_ic13, num_samples, rotate):
_validate_dataset(ds, input_size, is_polygons=rotate)


@pytest.mark.parametrize(
"num_samples, rotate",
[
[3, True], # Actual set has 7149 train and 796 test samples
[3, False]
],
)
def test_imgur5k_dataset(num_samples, rotate, mock_imgur5k):
input_size = (512, 512)
ds = datasets.IMGUR5K(
*mock_imgur5k,
train=True,
img_transforms=Resize(input_size),
use_polygons=rotate,
)

assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split
assert repr(ds) == f"IMGUR5K(train={True})"
_validate_dataset(ds, input_size, is_polygons=rotate)


@pytest.mark.parametrize(
"input_size, num_samples, rotate",
[
Expand Down

0 comments on commit 0da7ce0

Please sign in to comment.