Skip to content

Commit

Permalink
apply changes
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Dec 14, 2021
1 parent 591b343 commit 273598f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 31 deletions.
29 changes: 12 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import shutil
from io import BytesIO

import hdf5storage
Expand Down Expand Up @@ -154,28 +155,22 @@ def mock_svhn_dataset(tmpdir_factory, mock_image_stream):
svhn_root = root.mkdir('svhn')
file = BytesIO(mock_image_stream)
# ascii image names
first = np.array([[49], [46], [112], [110], [103]], dtype=np.int16) # 0.png
second = np.array([[50], [46], [112], [110], [103]], dtype=np.int16) # 1.png
third = np.array([[51], [46], [112], [110], [103]], dtype=np.int16) # 2.png
first = np.array([[49], [46], [112], [110], [103]], dtype=np.int16) # 1.png
second = np.array([[50], [46], [112], [110], [103]], dtype=np.int16) # 2.png
third = np.array([[51], [46], [112], [110], [103]], dtype=np.int16) # 3.png
# labels: label is also ascii
label = {'height': [35, 35, 35, 35], 'label': [1, 1, 3, 7],
'left': [116, 128, 137, 151], 'top': [27, 29, 29, 26],
'width': [15, 10, 17, 17]}

matcontent = {'digitStruct': {'name': [first, second, third], 'bbox': [label, label, label]}}
# mock train data
root = svhn_root.mkdir('train')
hdf5storage.write(matcontent, filename=root.join('digitStruct.mat'))
for i in range(4):
fn = root.join(f'{i}.png')
with open(fn, 'wb') as f:
f.write(file.getbuffer())
# mock test data
root = svhn_root.mkdir('test')
hdf5storage.write(matcontent, filename=root.join('digitStruct.mat'))
for i in range(4):
fn = root.join(f'{i}.png')
# Mock train data
train_root = svhn_root.mkdir('train')
hdf5storage.write(matcontent, filename=train_root.join('digitStruct.mat'))
for i in range(3):
fn = train_root.join(f'{i+1}.png')
with open(fn, 'wb') as f:
f.write(file.getbuffer())

return str(svhn_root)
# Packing data into an archive to simulate the real data set and bypass archive extraction
shutil.make_archive(svhn_root.join('svhn_train'), 'tar', str(svhn_root))
return str(root)
14 changes: 7 additions & 7 deletions tests/pytorch/test_datasets_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,23 +235,23 @@ def test_ic13_dataset(mock_ic13, size, rotate):


@pytest.mark.parametrize(
"train, input_size, size, rotate",
"input_size, size, rotate",
[
[True, [32, 128], 3, True], # Actual set has 33402 samples
[False, [32, 128], 3, False], # Actual set has 13068 samples
[[32, 128], 3, True], # Actual set has 33402 training samples and 13068 test samples
[[32, 128], 3, False],
],
)
def test_svhn(train, input_size, size, rotate, mock_svhn_dataset):
def test_svhn(input_size, size, rotate, mock_svhn_dataset):
# monkeypatch the path to temporary dataset
datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar")
datasets.SVHN.TEST = (mock_svhn_dataset, None, "svhn_test.tar")

ds = datasets.SVHN(
train=train, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
train=True, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
cache_dir=mock_svhn_dataset, cache_subdir="svhn",
)

assert len(ds) == size
assert repr(ds) == f"SVHN(train={train})"
assert repr(ds) == f"SVHN(train={True})"
img, target = ds[0]
assert isinstance(img, torch.Tensor)
assert img.shape == (3, *input_size)
Expand Down
14 changes: 7 additions & 7 deletions tests/tensorflow/test_datasets_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,23 +221,23 @@ def test_ic13_dataset(mock_ic13, size, rotate):


@pytest.mark.parametrize(
"train, input_size, size, rotate",
"input_size, size, rotate",
[
[True, [32, 128], 3, True], # Actual set has 33402 samples
[False, [32, 128], 3, False], # Actual set has 13068 samples
[[32, 128], 3, True], # Actual set has 33402 training samples and 13068 test samples
[[32, 128], 3, False],
],
)
def test_svhn(train, input_size, size, rotate, mock_svhn_dataset):
def test_svhn(input_size, size, rotate, mock_svhn_dataset):
# monkeypatch the path to temporary dataset
datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar")
datasets.SVHN.TEST = (mock_svhn_dataset, None, "svhn_test.tar")

ds = datasets.SVHN(
train=train, download=False, sample_transforms=Resize(input_size), rotated_bbox=rotate,
train=True, download=True, sample_transforms=Resize(input_size), rotated_bbox=rotate,
cache_dir=mock_svhn_dataset, cache_subdir="svhn",
)

assert len(ds) == size
assert repr(ds) == f"SVHN(train={train})"
assert repr(ds) == f"SVHN(train={True})"
img, target = ds[0]
assert isinstance(img, tf.Tensor)
assert img.shape == (*input_size, 3)
Expand Down

0 comments on commit 273598f

Please sign in to comment.