Skip to content

Commit

Permalink
[Datumaro] Reducing nesting of tests (#1875)
Browse files Browse the repository at this point in the history
* Add `Dataset.from_iterable` constructor
* Simplify creation of `Dataset` objects in common simple cases
* Refactor tests
  • Loading branch information
KochankovID authored Jul 20, 2020
1 parent e372589 commit 7ecdcf1
Show file tree
Hide file tree
Showing 10 changed files with 1,363 additions and 1,578 deletions.
52 changes: 52 additions & 0 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,33 @@ def __eq__(self, other):
class LabelCategories(Categories):
Category = namedtuple('Category', ['name', 'parent', 'attributes'])

@classmethod
def from_iterable(cls, iterable):
    """Build a categories object from a compact description.

    Args:
        iterable: One of the following:
            - a single str: one category is added with that name;
            - an iterable of str: each string names one category;
            - an iterable of argument sequences: each sequence is
              unpacked as the positional arguments of ``add``.
            Strings and argument sequences may be mixed in one iterable.

    Returns:
        A new instance of ``cls`` populated through ``add``.
    """
    result = cls()

    # A bare string is one category name, not a sequence of characters.
    entries = [iterable] if isinstance(iterable, str) else iterable

    for entry in entries:
        # Names are promoted to a one-element argument tuple.
        add_args = (entry,) if isinstance(entry, str) else entry
        result.add(*add_args)

    return result

def __init__(self, items=None, attributes=None):
super().__init__(attributes=attributes)

Expand Down Expand Up @@ -482,6 +509,31 @@ def iou(self, other):
class PointsCategories(Categories):
Category = namedtuple('Category', ['labels', 'joints'])

@classmethod
def from_iterable(cls, iterable):
    """Build a categories object from a compact description.

    Args:
        iterable: One of the following:
            - a single int: ``add`` is called once with that int as
              its only positional argument;
            - an iterable of int: ``add`` is called once per int;
            - an iterable of argument sequences: each sequence is
              unpacked as the positional arguments of ``add``.
            Ints and argument sequences may be mixed in one iterable.

    Returns:
        A new instance of ``cls`` populated through ``add``.
    """
    result = cls()

    # A bare int stands for a single entry; wrap it so the loop can run.
    entries = [iterable] if isinstance(iterable, int) else iterable

    for entry in entries:
        # Ints are promoted to a one-element argument tuple.
        add_args = (entry,) if isinstance(entry, int) else entry
        result.add(*add_args)

    return result

def __init__(self, items=None, attributes=None):
super().__init__(attributes=attributes)

Expand Down
32 changes: 31 additions & 1 deletion datumaro/datumaro/components/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from datumaro.components.config import Config, DEFAULT_FORMAT
from datumaro.components.config_model import (Model, Source,
PROJECT_DEFAULT_CONFIG, PROJECT_SCHEMA)
from datumaro.components.extractor import Extractor
from datumaro.components.extractor import Extractor, LabelCategories,\
AnnotationType
from datumaro.components.launcher import ModelTransform
from datumaro.components.dataset_filter import \
XPathDatasetFilter, XPathAnnotationsFilter
Expand Down Expand Up @@ -319,6 +320,35 @@ def categories(self):
return self._parent.categories()

class Dataset(Extractor):
@classmethod
def from_iterable(cls, iterable, categories=None):
    """Create a dataset from an iterable of items.

    Args:
        iterable: Iterable object providing the dataset items. A fresh
            iterator is requested from it (via ``iter``) every time the
            wrapping extractor is iterated.
        categories (dict or list, optional): Either a ready-made dict of
            categories, or a list, which is converted with
            ``LabelCategories.from_iterable`` and stored under the
            ``AnnotationType.label`` key. Any other falsy value,
            including the default ``None``, results in an empty dict.

    Returns:
        The result of ``cls.from_extractors`` applied to a temporary
        extractor wrapping ``iterable`` and the resolved categories.
    """
    if isinstance(categories, list):
        # A plain list is shorthand for label categories only.
        label_categories = LabelCategories.from_iterable(categories)
        resolved = {AnnotationType.label: label_categories}
    elif categories:
        resolved = categories
    else:
        resolved = {}

    class tmpExtractor(Extractor):
        # Thin adapter exposing the captured items and categories
        # through the Extractor interface.
        def __iter__(self):
            return iter(iterable)

        def categories(self):
            return resolved

    wrapper = tmpExtractor()
    return cls.from_extractors(wrapper)

@classmethod
def from_extractors(cls, *sources):
# merge categories
Expand Down
Loading

0 comments on commit 7ecdcf1

Please sign in to comment.