
feat: add RandomCrop module in transforms #448

Merged (15 commits, Aug 31, 2021)
5 changes: 5 additions & 0 deletions doctr/datasets/detection.py
@@ -33,11 +33,13 @@ def __init__(
img_folder: str,
label_folder: str,
sample_transforms: Optional[Callable[[Any], Any]] = None,
geometric_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
rotated_bbox: bool = False,
**kwargs: Any,
) -> None:
super().__init__(img_folder, **kwargs)
self.sample_transforms = sample_transforms
self.geometric_transforms = geometric_transforms

self.data: List[Tuple[str, Dict[str, Any]]] = []
np_dtype = np.float16 if self.fp16 else np.float32
@@ -73,4 +75,7 @@ def __getitem__(
boxes[..., [0, 2]] /= w
boxes[..., [1, 3]] /= h

if self.geometric_transforms is not None:
img, boxes = self.geometric_transforms(img, boxes)

return img, dict(boxes=boxes.clip(0, 1), flags=target['flags'])
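Note: sample_transforms still receives the image only, while the new geometric_transforms hook receives the image together with its boxes, which are already normalized at that point. A minimal usage sketch, assuming the class in this file is DetectionDataset (the name is not shown in the diff) and using placeholder folder paths:

from doctr.datasets import DetectionDataset  # class name assumed from the file path
import doctr.transforms as T

train_set = DetectionDataset(
    img_folder="path/to/images",    # placeholder
    label_folder="path/to/labels",  # placeholder
    sample_transforms=T.ColorInversion(),  # image-only augmentation
    geometric_transforms=T.RandomRotate(max_angle=5., expand=True),  # applied to image and boxes
)
img, target = train_set[0]
# target["boxes"] holds relative coordinates clipped to [0, 1]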
2 changes: 1 addition & 1 deletion doctr/transforms/functional/pytorch.py
@@ -32,7 +32,7 @@ def rotate(
img: torch.Tensor,
boxes: np.ndarray,
angle: float,
expand: bool = False,
expand: bool = True,
) -> Tuple[torch.Tensor, np.ndarray]:
"""Rotate image around the center, interpolation=NEAREST, pad with 0 (black)

2 changes: 1 addition & 1 deletion doctr/transforms/functional/tensorflow.py
@@ -33,7 +33,7 @@ def rotate(
img: tf.Tensor,
boxes: np.ndarray,
angle: float,
expand: bool = False,
expand: bool = True,
) -> Tuple[tf.Tensor, np.ndarray]:
"""Rotate image around the center, interpolation=NEAREST, pad with 0 (black)

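Both backends now default to expand=True. A minimal sketch of what the flag changes, shown for the TensorFlow backend (the PyTorch signature is identical); only the expand=False case has a fixed output shape, so the expanded shape is just printed:

import numpy as np
import tensorflow as tf
from doctr.transforms.functional import rotate

img = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([[15, 20, 35, 30]])  # xmin, ymin, xmax, ymax (absolute)

r_img, r_boxes = rotate(img, boxes, angle=12., expand=False)
assert r_img.shape == (50, 50, 3)  # canvas unchanged, corners may be cut off

r_img_exp, _ = rotate(img, boxes, angle=12., expand=True)
# with expand=True the canvas is padded so the rotated content is not cropped,
# hence the output shape generally differs from the input shape
print(r_img_exp.shape)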
36 changes: 30 additions & 6 deletions doctr/transforms/modules/base.py
@@ -11,7 +11,7 @@
from .. import functional as F


__all__ = ['ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate']
__all__ = ['ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate', 'RandomCrop']


class ColorInversion(NestedObject):
@@ -89,21 +89,45 @@ def __call__(self, img: Any) -> Any:


class RandomRotate(NestedObject):
"""Randomly rotate a tensor image
"""Randomly rotate a tensor image and its boxes

Args:
max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
[-max_angle, max_angle]
expand: whether the image should be padded before the rotation
"""
def __init__(self, max_angle: float = 25., expand: bool = False) -> None:
def __init__(self, max_angle: float = 5., expand: bool = True) -> None:
self.max_angle = max_angle
self.expand = expand

def extra_repr(self) -> str:
return f"max_angle={self.max_angle}, expand={self.expand}"

def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
def __call__(self, img: Any, boxes: np.ndarray) -> Tuple[Any, np.ndarray]:
angle = random.uniform(-self.max_angle, self.max_angle)
img, target['boxes'] = F.rotate(img, target['boxes'], angle, self.expand)
return img, target
r_img, r_boxes = F.rotate(img, boxes, angle, self.expand)
return r_img, r_boxes


class RandomCrop(NestedObject):
"""Randomly crop a tensor image and its boxes

Args:
min_wh: float, min relative width/height of the crop
max_wh: float, max relative width/height of the crop
"""
def __init__(self, min_wh: float = 0.4, max_wh: float = 0.8) -> None:
self.min_wh = min_wh
self.max_wh = max_wh

def extra_repr(self) -> str:
return f"min_wh={self.min_wh}, max_wh={self.max_wh}"

def __call__(self, img: Any, boxes: np.ndarray) -> Tuple[Any, np.ndarray]:
h, w = img.shape[:2]
crop_w = random.uniform(self.min_wh, self.max_wh)
crop_h = random.uniform(self.min_wh, self.max_wh)
start_x, start_y = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h)
crop_box = (int(start_x * w), int(start_y * h), int((start_x + crop_w) * w), int((start_y + crop_h) * h))
cropped_img, crop_boxes = F.crop_detection(img, boxes, crop_box)
return cropped_img, crop_boxes
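Both modules now share the (img, boxes) -> (img, boxes) calling convention expected by the geometric_transforms hook above. A minimal standalone sketch with TensorFlow tensors, using the same box values as the tests below:

import numpy as np
import tensorflow as tf
import doctr.transforms as T

img = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([[15, 20, 35, 30]])  # xmin, ymin, xmax, ymax (absolute)

cropper = T.RandomCrop(min_wh=0.4, max_wh=0.8)
c_img, c_boxes = cropper(img, boxes)
# each side keeps between 40% and 80% of the original extent, so for a 50x50
# input the cropped height/width land roughly in the [20, 40] pixel range

rotator = T.RandomRotate(max_angle=5., expand=True)
r_img, r_boxes = rotator(img, boxes)
# rotated boxes carry a fifth value (the rotation angle), cf. the rotate tests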
22 changes: 17 additions & 5 deletions test/pytorch/test_transforms_pt.py
@@ -3,7 +3,7 @@
import math
import torch
import numpy as np
from doctr.transforms import Resize, ColorInversion, RandomRotate
from doctr.transforms import Resize, ColorInversion, RandomRotate, RandomCrop
from doctr.transforms.functional import rotate, crop_detection


@@ -86,7 +86,7 @@ def test_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotate(input_t, boxes, angle=12.)
r_img, r_boxes = rotate(input_t, boxes, angle=12., expand=False)
assert r_img.shape == (3, 50, 50)
assert r_img[0, 0, 0] == 0.
assert r_boxes.all() == np.array([[25., 25., 20., 10., 12.]]).all()
@@ -110,19 +110,19 @@ def test_rotate():


def test_random_rotate():
rotator = RandomRotate(max_angle=10.)
rotator = RandomRotate(max_angle=10., expand=False)
input_t = torch.ones((3, 50, 50), dtype=torch.float32)
boxes = np.array([
[15, 20, 35, 30]
])
r_img, target = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, boxes=boxes)
assert r_img.shape == input_t.shape
assert abs(r_boxes[-1, -1]) <= 10.

# FP16 (only on GPU)
if torch.cuda.is_available():
input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda()
r_img, _ = rotator(input_t, dict(boxes=boxes))
r_img, _ = rotator(input_t, boxes=boxes)
assert r_img.dtype == torch.float16


@@ -148,3 +148,15 @@ def test_crop_detection():
img = torch.ones((3, 50, 50), dtype=torch.float16)
c_img, _ = crop_detection(img, abs_boxes, crop_box)
assert c_img.dtype == torch.float16


def test_random_crop():
cropper = RandomCrop()
input_t = torch.ones((50, 50, 3), dtype=torch.float32)
boxes = np.array([
[15, 20, 35, 30]
])
c_img, _ = cropper(input_t, boxes=boxes)
new_h, new_w = c_img.shape[:2]
assert 20 <= new_h <= 40
assert 20 <= new_w <= 40
20 changes: 16 additions & 4 deletions test/tensorflow/test_transforms_tf.py
@@ -251,7 +251,7 @@ def test_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotate(input_t, boxes, angle=12.)
r_img, r_boxes = rotate(input_t, boxes, angle=12., expand=False)
assert r_img.shape == (50, 50, 3)
assert r_img[0, 0, 0] == 0.
assert r_boxes.all() == np.array([[25., 25., 20., 10., 12.]]).all()
@@ -274,18 +274,18 @@ def test_rotate():


def test_random_rotate():
rotator = T.RandomRotate(max_angle=10.)
rotator = T.RandomRotate(max_angle=10., expand=False)
input_t = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([
[15, 20, 35, 30]
])
r_img, target = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, boxes=boxes)
assert r_img.shape == input_t.shape
assert abs(r_boxes[-1, -1]) <= 10.

# FP16
input_t = tf.ones((50, 50, 3), dtype=tf.float16)
r_img, _ = rotator(input_t, dict(boxes=boxes))
r_img, _ = rotator(input_t, boxes=boxes)
assert r_img.dtype == tf.float16


@@ -311,3 +311,15 @@ def test_crop_detection():
img = tf.ones((50, 50, 3), dtype=tf.float16)
c_img, _ = crop_detection(img, rel_boxes, crop_box)
assert c_img.dtype == tf.float16


def test_random_crop():
cropper = T.RandomCrop()
input_t = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([
[15, 20, 35, 30]
])
c_img, _ = cropper(input_t, boxes=boxes)
new_h, new_w = c_img.shape[:2]
assert 20 <= new_h <= 40
assert 20 <= new_w <= 40