Skip to content
This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Commit

Permalink
Create cvat_sdk.datasets, a framework-agnostic version of cvat_sdk.py…
Browse files Browse the repository at this point in the history
…torch (cvat-ai#6428)

The new `TaskDataset` class provides conveniences like per-frame
annotations, bulk data downloading, and caching without forcing a
dependency on PyTorch (and somewhat awkwardly conforming to the PyTorch
dataset interface). It also provides a few extra niceties, like easy
access to labels and original frame numbers.

Note that it's called `TaskDataset` rather than `TaskVisionDataset`, as
my plan is to keep it domain-agnostic. The `MediaElement` class is
extensible, and we can add, for example, support for point clouds, by
adding another `load_*` method.

There is currently no `ProjectDataset` equivalent, although one could
(and probably should) be added later. If we add one, we should probably
also add a `task_id` field to `Sample`.
  • Loading branch information
SpecLad authored and mikhail-treskin committed Oct 25, 2023
1 parent 82ccb2d commit 3f7e852
Show file tree
Hide file tree
Showing 13 changed files with 484 additions and 130 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[Unreleased]
### Added
- Multi-line text attributes supported (<https://github.com/opencv/cvat/pull/6458>)
- \[SDK\] `cvat_sdk.datasets`, a framework-agnostic equivalent of `cvat_sdk.pytorch`
(<https://github.com/opencv/cvat/pull/6428>)

### Changed
- TBD
Expand Down
7 changes: 7 additions & 0 deletions cvat-sdk/cvat_sdk/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from .caching import UpdatePolicy
from .common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError
from .task_dataset import TaskDataset
57 changes: 57 additions & 0 deletions cvat-sdk/cvat_sdk/datasets/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import abc
from typing import List

import attrs
import attrs.validators
import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
    """
    Signals that a CVAT resource cannot be represented as a dataset.
    """


@attrs.frozen
class FrameAnnotations:
    """
    Groups the annotations that are attached to one particular frame.

    Both lists start out empty; every instance gets its own fresh lists.
    """

    tags: List[models.LabeledImage] = attrs.field(factory=list)
    shapes: List[models.LabeledShape] = attrs.field(factory=list)

class MediaElement(metaclass=abc.ABCMeta):
    """
    Abstract interface to the media payload of a dataset sample.
    """

    @abc.abstractmethod
    def load_image(self) -> PIL.Image.Image:
        """
        Reads the underlying media data and hands it back as a PIL Image object.
        """


@attrs.frozen
class Sample:
    """
    A single dataset element: one frame together with its annotations and media.
    """

    frame_index: int
    """Position of this sample's frame within the task it comes from."""

    annotations: FrameAnnotations
    """The annotations attached to the frame."""

    media: MediaElement
    """The frame's media data."""
164 changes: 164 additions & 0 deletions cvat-sdk/cvat_sdk/datasets/task_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import zipfile
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence

import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.common import FrameAnnotations, MediaElement, Sample, UnsupportedDatasetError

# Size of the thread pool used to fetch a task's data chunks in parallel
# while a TaskDataset is being constructed.
_NUM_DOWNLOAD_THREADS = 4


class TaskDataset:
    """
    Represents a task on a CVAT server as a collection of samples.

    Each sample corresponds to one frame in the task, and provides access to
    the corresponding annotations and media data. Deleted frames are omitted.

    This class caches all data and annotations for the task on the local file system
    during construction.

    Limitations:

    * Only tasks with image (not video) data are supported at the moment.
    * Track annotations are currently not accessible.
    """

    class _TaskMediaElement(MediaElement):
        """
        Media element bound to one frame of a `TaskDataset`.

        The image is only read from the chunk cache when `load_image` is called.
        """

        def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
            self._dataset = dataset
            self._frame_index = frame_index

        def load_image(self) -> PIL.Image.Image:
            return self._dataset._load_frame_image(self._frame_index)

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `update_policy` determines when and if the local cache will be updated.

        Raises `UnsupportedDatasetError` if the task has no data, or if its
        original chunk type is anything other than "imageset".
        """

        self._logger = client.logger

        cache_manager = make_cache_manager(client, update_policy)
        self._task = cache_manager.retrieve_task(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                f"{self.__class__.__name__} only supports tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        self._logger.info("Fetching labels...")
        self._labels = tuple(self._task.get_labels())

        data_meta = cache_manager.ensure_task_model(
            self._task.id,
            "data_meta.json",
            models.DataMetaRead,
            self._task.get_meta,
            "data metadata",
        )

        # Frames marked as deleted in the task metadata are excluded everywhere below.
        active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)

        self._logger.info("Downloading chunks...")

        self._chunk_dir = cache_manager.chunk_dir(task_id)
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

        # Only chunks that contain at least one non-deleted frame are requested.
        needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:

            def ensure_chunk(chunk_index):
                cache_manager.ensure_chunk(self._task, chunk_index)

            for _ in pool.map(ensure_chunk, sorted(needed_chunks)):
                # just need to loop through all results so that any exceptions are propagated
                pass

        self._logger.info("All chunks downloaded")

        annotations = cache_manager.ensure_task_model(
            self._task.id,
            "annotations.json",
            models.LabeledData,
            self._task.get_annotations,
            "annotations",
        )

        # Insertion order is ascending frame index, which `samples` relies on.
        self._frame_annotations = {
            frame_index: FrameAnnotations() for frame_index in sorted(active_frame_indexes)
        }

        for tag in annotations.tags:
            # Some annotations may belong to deleted frames; skip those.
            if tag.frame in self._frame_annotations:
                self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            # Same as for tags: shapes on deleted frames are skipped.
            if shape.frame in self._frame_annotations:
                self._frame_annotations[shape.frame].shapes.append(shape)

        # TODO: tracks?

        self._samples = [
            Sample(frame_index=k, annotations=v, media=self._TaskMediaElement(self, k))
            for k, v in self._frame_annotations.items()
        ]

    @property
    def labels(self) -> Sequence[models.ILabel]:
        """
        Returns the labels configured in the task.

        Clients must not modify the object returned by this property or its components.
        """
        return self._labels

    @property
    def samples(self) -> Sequence[Sample]:
        """
        Returns a sequence of all samples, in order of their frame indices.

        Note that the frame indices may not be contiguous, as deleted frames will not be included.

        Clients must not modify the object returned by this property or its components.
        """
        return self._samples

    def _load_frame_image(self, frame_index: int) -> PIL.Image.Image:
        """
        Loads the image for the frame with index `frame_index` from the
        chunk archive stored in the chunk directory.

        `frame_index` must be the index of a frame that is part of this dataset
        (i.e. not a deleted frame).
        """
        assert frame_index in self._frame_annotations

        # The chunk number and the frame's position inside that chunk both
        # follow from the task's fixed chunk size.
        chunk_index, member_index = divmod(frame_index, self._task.data_chunk_size)

        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                image = PIL.Image.open(chunk_member)
                # Read the image data now, while the archive member is still open;
                # the returned image must remain usable after both files are closed.
                image.load()

        return image
8 changes: 6 additions & 2 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
#
# SPDX-License-Identifier: MIT

from .caching import UpdatePolicy
from .common import FrameAnnotations, Target, UnsupportedDatasetError
from .common import Target
from .project_dataset import ProjectVisionDataset
from .task_dataset import TaskVisionDataset
from .transforms import ExtractBoundingBoxes, ExtractSingleLabelIndex, LabeledBoxes

# isort: split
# Compatibility imports
from ..datasets.caching import UpdatePolicy
from ..datasets.common import FrameAnnotations, UnsupportedDatasetError
21 changes: 2 additions & 19 deletions cvat-sdk/cvat_sdk/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,11 @@
#
# SPDX-License-Identifier: MIT

from typing import List, Mapping
from typing import Mapping

import attrs
import attrs.validators

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
pass


@attrs.frozen
class FrameAnnotations:
"""
Contains annotations that pertain to a single frame.
"""

tags: List[models.LabeledImage] = attrs.Factory(list)
shapes: List[models.LabeledShape] = attrs.Factory(list)
from cvat_sdk.datasets.common import FrameAnnotations


@attrs.frozen
Expand Down
2 changes: 1 addition & 1 deletion cvat-sdk/cvat_sdk/pytorch/project_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.caching import UpdatePolicy, make_cache_manager
from cvat_sdk.pytorch.task_dataset import TaskVisionDataset


Expand Down
Loading

0 comments on commit 3f7e852

Please sign in to comment.