Skip to content

Commit

Permalink
fix: Recognition dataset merging (#376)
Browse files Browse the repository at this point in the history
* fix: dataset merging

* fix: memory

* fix: memory error

* fix: memory error

* fix: comprehension list

* fix: merge_datasets
  • Loading branch information
charlesmindee committed Jul 9, 2021
1 parent 43e8564 commit 011f934
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
10 changes: 10 additions & 0 deletions doctr/datasets/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import json
from typing import Tuple, List, Optional, Callable, Any
from pathlib import Path

from .datasets import AbstractDataset

Expand Down Expand Up @@ -46,3 +47,12 @@ def __init__(
if not isinstance(label, str):
raise KeyError("Image is not in referenced in label file")
self.data.append((img_path, label))

def merge_dataset(self, ds: AbstractDataset) -> None:
# Update data with new root for self
self.data = [(str(Path(self.root).joinpath(img_path)), label) for img_path, label in self.data]
# Define new root
self.root = Path("/")
# Merge with ds data
for img_path, label in ds.data:
self.data.append((str(Path(ds.root).joinpath(img_path)), label))
4 changes: 4 additions & 0 deletions test/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from copy import deepcopy

from doctr import datasets
from doctr.transforms import Resize
Expand Down Expand Up @@ -126,6 +127,9 @@ def test_recognition_dataset(mock_image_folder, mock_recognition_label):
ds = datasets.RecognitionDataset(img_folder=mock_image_folder, labels_path=mock_recognition_label, fp16=True)
image, label = ds[0]
assert image.dtype == torch.float16
ds2, ds3 = deepcopy(ds), deepcopy(ds)
ds2.merge_dataset(ds3)
assert len(ds2) == 2 * len(ds)


def test_ocrdataset(mock_ocrdataset):
Expand Down
4 changes: 4 additions & 0 deletions test/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np
import tensorflow as tf
from copy import deepcopy

from doctr import datasets
from doctr.transforms import Resize
Expand Down Expand Up @@ -112,6 +113,9 @@ def test_recognition_dataset(mock_image_folder, mock_recognition_label):
ds = datasets.RecognitionDataset(img_folder=mock_image_folder, labels_path=mock_recognition_label, fp16=True)
image, _ = ds[0]
assert image.dtype == tf.float16
ds2, ds3 = deepcopy(ds), deepcopy(ds)
ds2.merge_dataset(ds3)
assert len(ds2) == 2 * len(ds)


def test_ocrdataset(mock_ocrdataset):
Expand Down

0 comments on commit 011f934

Please sign in to comment.