Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ImageDataset.split #846

Merged
merged 7 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/safeds/data/image/containers/_image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def __contains__(self, item: object) -> bool:
Returns
-------
has_item:
Weather the given item is in this image list
Whether the given item is in this image list
"""
return isinstance(item, Image) and self.has_image(item)

Expand Down Expand Up @@ -524,7 +524,7 @@ def has_image(self, image: Image) -> bool:
Returns
-------
has_image:
Weather the given image is in this image list
Whether the given image is in this image list
"""

# ------------------------------------------------------------------------------------------------------------------
Expand Down
126 changes: 99 additions & 27 deletions src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ImageDataset(Dataset[ImageList, Out_co]):
batch_size:
the batch size used for training
shuffle:
weather the data should be shuffled after each epoch of training
whether the data should be shuffled after each epoch of training
"""

def __init__(self, input_data: ImageList, output_data: Out_co, batch_size: int = 1, shuffle: bool = False) -> None:
Expand Down Expand Up @@ -108,13 +108,13 @@ def __iter__(self) -> ImageDataset:
return im_ds

def __next__(self) -> tuple[Tensor, Tensor]:
if self._next_batch_index * self._batch_size >= len(self._input):
if self._next_batch_index * self._batch_size >= len(self._shuffle_tensor_indices):
raise StopIteration
self._next_batch_index += 1
return self._get_batch(self._next_batch_index - 1)

def __len__(self) -> int:
return self._input.image_count
return len(self._shuffle_tensor_indices)

def __eq__(self, other: object) -> bool:
"""
Expand All @@ -138,6 +138,7 @@ def __eq__(self, other: object) -> bool:
and isinstance(other._output, type(self._output))
and (self._input == other._input)
and (self._output == other._output)
and (self._shuffle_tensor_indices.tolist() == other._shuffle_tensor_indices.tolist())
)

def __hash__(self) -> int:
Expand All @@ -149,7 +150,13 @@ def __hash__(self) -> int:
hash:
the hash value
"""
return _structural_hash(self._input, self._output, self._shuffle_after_epoch, self._batch_size)
return _structural_hash(
self._input,
self._output,
self._shuffle_after_epoch,
self._batch_size,
self._shuffle_tensor_indices.tolist(),
)

def __sizeof__(self) -> int:
"""
Expand Down Expand Up @@ -205,7 +212,7 @@ def get_input(self) -> ImageList:
input:
the input data of this dataset
"""
return self._sort_image_list_with_shuffle_tensor_indices(self._input)
return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._input)

def get_output(self) -> Out_co:
"""
Expand All @@ -222,19 +229,25 @@ def get_output(self) -> Out_co:
elif isinstance(output, _ColumnAsTensor):
return output._to_column(self._shuffle_tensor_indices) # type: ignore[return-value]
else:
return self._sort_image_list_with_shuffle_tensor_indices(self._output) # type: ignore[return-value]
return self._sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(self._output) # type: ignore[return-value]

def _sort_image_list_with_shuffle_tensor_indices(self, image_list: _SingleSizeImageList) -> _SingleSizeImageList:
def _sort_image_list_with_shuffle_tensor_indices_reduce_if_necessary(
self,
image_list: _SingleSizeImageList,
) -> _SingleSizeImageList:
shuffled_image_list = _SingleSizeImageList()
shuffled_image_list._tensor = image_list._tensor
shuffled_image_list._indices_to_tensor_positions = {
index: self._shuffle_tensor_indices[tensor_position].item()
for index, tensor_position in image_list._indices_to_tensor_positions.items()
tensor_pos = [
image_list._indices_to_tensor_positions[shuffled_index]
for shuffled_index in sorted(self._shuffle_tensor_indices.tolist())
]
temp_pos = {
shuffled_index: new_index for new_index, shuffled_index in enumerate(self._shuffle_tensor_indices.tolist())
}
shuffled_image_list._tensor = image_list._tensor[tensor_pos]
shuffled_image_list._tensor_positions_to_indices = [
index
for index, _ in sorted(shuffled_image_list._indices_to_tensor_positions.items(), key=lambda item: item[1])
new_index for _, new_index in sorted(temp_pos.items(), key=lambda item: item[0])
]
shuffled_image_list._indices_to_tensor_positions = shuffled_image_list._calc_new_indices_to_tensor_positions()
return shuffled_image_list

def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[Tensor, Tensor]:
Expand All @@ -247,18 +260,18 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[

_check_bounds("batch_size", batch_size, lower_bound=_ClosedBound(1))

if batch_number < 0 or batch_size * batch_number >= len(self._input):
if batch_number < 0 or batch_size * batch_number >= len(self._shuffle_tensor_indices):
raise IndexOutOfBoundsError(batch_size * batch_number)
max_index = (
batch_size * (batch_number + 1) if batch_size * (batch_number + 1) < len(self._input) else len(self._input)
batch_size * (batch_number + 1)
if batch_size * (batch_number + 1) < len(self._shuffle_tensor_indices)
else len(self._shuffle_tensor_indices)
)
input_tensor = (
self._input._tensor[
self._shuffle_tensor_indices[
[
self._input._indices_to_tensor_positions[index]
for index in range(batch_size * batch_number, max_index)
]
[
self._input._indices_to_tensor_positions[index]
for index in self._shuffle_tensor_indices[batch_size * batch_number : max_index].tolist()
]
].to(torch.float32)
/ 255
Expand All @@ -267,11 +280,9 @@ def _get_batch(self, batch_number: int, batch_size: int | None = None) -> tuple[
if isinstance(self._output, _SingleSizeImageList):
output_tensor = (
self._output._tensor[
self._shuffle_tensor_indices[
[
self._output._indices_to_tensor_positions[index]
for index in range(batch_size * batch_number, max_index)
]
[
self._input._indices_to_tensor_positions[index]
for index in self._shuffle_tensor_indices[batch_size * batch_number : max_index].tolist()
]
].to(torch.float32)
/ 255
Expand All @@ -284,7 +295,7 @@ def shuffle(self) -> ImageDataset[Out_co]:
"""
Return a new `ImageDataset` with shuffled data.

The original dataset list is not modified.
The original dataset is not modified.

Returns
-------
Expand All @@ -296,10 +307,71 @@ def shuffle(self) -> ImageDataset[Out_co]:
_init_default_device()

im_dataset: ImageDataset[Out_co] = copy.copy(self)
im_dataset._shuffle_tensor_indices = torch.randperm(len(self))
im_dataset._shuffle_tensor_indices = self._shuffle_tensor_indices[
torch.randperm(len(self._shuffle_tensor_indices))
]
im_dataset._next_batch_index = 0
return im_dataset

def split(
self,
percentage_in_first: float,
*,
shuffle: bool = True,
) -> tuple[ImageDataset[Out_co], ImageDataset[Out_co]]:
"""
Create two image datasets by splitting the data of the current dataset.

The first dataset contains a percentage of the data specified by `percentage_in_first`, and the second dataset
contains the remaining data.

The original dataset is not modified.
By default, the data is shuffled before splitting. You can disable this by setting `shuffle` to False.

Parameters
----------
percentage_in_first:
The percentage of data to include in the first dataset. Must be between 0 and 1.
shuffle:
Whether to shuffle the data before splitting.

Returns
-------
first_dataset:
The first dataset.
second_dataset:
The second dataset.

Raises
------
OutOfBoundsError
If `percentage_in_first` is not between 0 and 1.
"""
import torch

_check_bounds(
"percentage_in_first",
percentage_in_first,
lower_bound=_ClosedBound(0),
upper_bound=_ClosedBound(1),
)

first_dataset: ImageDataset[Out_co] = copy.copy(self)
second_dataset: ImageDataset[Out_co] = copy.copy(self)

if shuffle:
shuffled_indices = torch.randperm(len(self._shuffle_tensor_indices))
else:
shuffled_indices = torch.arange(len(self._shuffle_tensor_indices))

first_dataset._shuffle_tensor_indices, second_dataset._shuffle_tensor_indices = shuffled_indices.split(
[
round(percentage_in_first * len(self)),
len(self) - round(percentage_in_first * len(self)),
],
)
return first_dataset, second_dataset


class _TableAsTensor:
def __init__(self, table: Table) -> None:
Expand Down
103 changes: 103 additions & 0 deletions tests/safeds/data/labeled/containers/test_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,109 @@ def test_get_batch_device(self, device: Device) -> None:
assert batch[1].device == _get_device()


@pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids())
@pytest.mark.parametrize("shuffle", [True, False])
class TestSplit:

@pytest.mark.parametrize(
"output",
[
Column("images", images_all()[:4] + images_all()[5:]),
Table(
{
"0": [1, 0, 0, 0, 0, 0],
"1": [0, 1, 0, 0, 0, 0],
"2": [0, 0, 1, 0, 0, 0],
"3": [0, 0, 0, 1, 0, 0],
"4": [0, 0, 0, 0, 1, 0],
"5": [0, 0, 0, 0, 0, 1],
},
),
_EmptyImageList(),
],
ids=["Column", "Table", "ImageList"],
)
def test_should_split(self, device: Device, shuffle: bool, output: Column | Table | ImageList) -> None:
configure_test_with_device(device)
image_list = ImageList.from_files(resolve_resource_path(images_all())).remove_duplicate_images().resize(10, 10)
if isinstance(output, _EmptyImageList):
output = image_list
image_dataset = ImageDataset(image_list, output) # type: ignore[type-var]
image_dataset1, image_dataset2 = image_dataset.split(0.4, shuffle=shuffle)
offset = len(image_dataset1)
assert len(image_dataset1) == round(0.4 * len(image_dataset))
assert len(image_dataset2) == len(image_dataset) - offset
assert len(image_dataset1.get_input()) == round(0.4 * len(image_dataset))
assert len(image_dataset2.get_input()) == len(image_dataset) - offset
im1_output = image_dataset1.get_output()
im2_output = image_dataset2.get_output()
if isinstance(im1_output, Table):
assert im1_output.row_count == round(0.4 * len(image_dataset))
else:
assert len(im1_output) == round(0.4 * len(image_dataset))
if isinstance(im2_output, Table):
assert im2_output.row_count == len(image_dataset) - offset
else:
assert len(im2_output) == len(image_dataset) - offset

assert image_dataset != image_dataset1
assert image_dataset != image_dataset2
assert image_dataset1 != image_dataset2

for i, image in enumerate(image_dataset1.get_input().to_images()):
index = image_list.index(image)[0]
if not shuffle:
assert index == i
out = image_dataset1.get_output()
if isinstance(out, ImageList):
assert image_list.index(out.get_image(i))[0] == index
elif isinstance(out, Column) and isinstance(output, Column):
assert output.to_list().index(out.to_list()[i]) == index
elif isinstance(out, Table) and isinstance(output, Table):
assert output.get_column(str(index)).to_list()[index] == 1

for i, image in enumerate(image_dataset2.get_input().to_images()):
index = image_list.index(image)[0]
if not shuffle:
assert index == i + offset
out = image_dataset2.get_output()
if isinstance(out, ImageList):
assert image_list.index(out.get_image(i))[0] == index
elif isinstance(out, Column) and isinstance(output, Column):
assert output.to_list().index(out.to_list()[i]) == index
elif isinstance(out, Table) and isinstance(output, Table):
assert output.get_column(str(index)).to_list()[index] == 1

image_dataset._batch_size = len(image_dataset)
image_dataset1._batch_size = 1
image_dataset2._batch_size = 1
image_dataset_batch = next(iter(image_dataset))

for i, b in enumerate(iter(image_dataset1)):
assert b[0] in image_dataset_batch[0]
index = (b[0] == image_dataset_batch[0]).all(dim=[1, 2, 3]).nonzero()[0][0]
if not shuffle:
assert index == i
assert torch.all(torch.eq(b[0], image_dataset_batch[0][index]))
assert torch.all(torch.eq(b[1], image_dataset_batch[1][index]))

for i, b in enumerate(iter(image_dataset2)):
assert b[0] in image_dataset_batch[0]
index = (b[0] == image_dataset_batch[0]).all(dim=[1, 2, 3]).nonzero()[0][0]
if not shuffle:
assert index == i + offset
assert torch.all(torch.eq(b[0], image_dataset_batch[0][index]))
assert torch.all(torch.eq(b[1], image_dataset_batch[1][index]))

@pytest.mark.parametrize("percentage", [-1, -0.1, 1.1, 2])
def test_should_raise(self, device: Device, shuffle: bool, percentage: float) -> None:
configure_test_with_device(device)
image_list = ImageList.from_files(resolve_resource_path(images_all())).resize(10, 10)
image_dataset = ImageDataset(image_list, Column("images", images_all()))
with pytest.raises(OutOfBoundsError):
image_dataset.split(percentage, shuffle=shuffle)


@pytest.mark.parametrize("device", get_devices(), ids=get_devices_ids())
class TestTableAsTensor:
def test_should_raise_if_not_one_hot_encoded(self, device: Device) -> None:
Expand Down