Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add RandomCrop module in transforms #448

Merged
merged 15 commits into from
Aug 31, 2021
37 changes: 32 additions & 5 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import random
import math
from typing import List, Any, Callable, Dict, Tuple
import numpy as np

from doctr.utils.repr import NestedObject
from .. import functional as F


__all__ = ['ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate']
__all__ = ['ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate', 'RandomCrop']


class ColorInversion(NestedObject):
Expand Down Expand Up @@ -89,14 +90,14 @@ def __call__(self, img: Any) -> Any:


class RandomRotate(NestedObject):
"""Randomly rotate a tensor image
"""Randomly rotate a tensor image and its boxes
Args:
max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
[-max_angle, max_angle]
expand: whether the image should be padded before the rotation
"""
def __init__(self, max_angle: float = 25., expand: bool = False) -> None:
def __init__(self, max_angle: float = 5., expand: bool = False) -> None:
self.max_angle = max_angle
self.expand = expand

Expand All @@ -105,5 +106,31 @@ def extra_repr(self) -> str:

def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
angle = random.uniform(-self.max_angle, self.max_angle)
img, target['boxes'] = F.rotate(img, target['boxes'], angle, self.expand)
return img, target
r_img, r_boxes = F.rotate(img, target["boxes"], angle, self.expand)
return r_img, dict(boxes=r_boxes)


class RandomCrop(NestedObject):
"""Randomly crop a tensor image and its boxes
Args:
scale: tuple of floats, relative (min_area, max_area) of the crop
ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w
"""
def __init__(self, scale: Tuple[float, float] = (0.08, 1.), ratio: Tuple[float, float] = (0.75, 1.33)) -> None:
self.scale = scale
self.ratio = ratio

def extra_repr(self) -> str:
return f"scale={self.scale}, ratio={self.ratio}"

def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
h, w = img.shape[:2]
random_scale = random.uniform(self.scale[0], self.scale[1])
random_ratio = random.uniform(self.ratio[0], self.ratio[1])
crop_h = math.sqrt(random_scale * random_ratio)
crop_w = math.sqrt(random_scale / random_ratio)
start_x, start_y = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h)
crop_box = (int(start_x * w), int(start_y * h), int((start_x + crop_w) * w), int((start_y + crop_h) * h))
croped_img, crop_boxes = F.crop_detection(img, target["boxes"], crop_box)
return croped_img, dict(boxes=crop_boxes)
22 changes: 17 additions & 5 deletions test/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import torch
import numpy as np
from doctr.transforms import Resize, ColorInversion, RandomRotate
from doctr.transforms import Resize, ColorInversion, RandomRotate, RandomCrop
from doctr.transforms.functional import rotate, crop_detection


Expand Down Expand Up @@ -86,7 +86,7 @@ def test_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotate(input_t, boxes, angle=12.)
r_img, r_boxes = rotate(input_t, boxes, angle=12., expand=False)
assert r_img.shape == (3, 50, 50)
assert r_img[0, 0, 0] == 0.
assert r_boxes.all() == np.array([[25., 25., 20., 10., 12.]]).all()
Expand All @@ -110,14 +110,14 @@ def test_rotate():


def test_random_rotate():
rotator = RandomRotate(max_angle=10.)
rotator = RandomRotate(max_angle=10., expand=False)
input_t = torch.ones((3, 50, 50), dtype=torch.float32)
boxes = np.array([
[15, 20, 35, 30]
])
r_img, target = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, dict(boxes=boxes))
assert r_img.shape == input_t.shape
assert abs(target["boxes"][-1, -1]) <= 10.
assert abs(r_boxes["boxes"][-1, -1]) <= 10.

# FP16 (only on GPU)
if torch.cuda.is_available():
Expand Down Expand Up @@ -148,3 +148,15 @@ def test_crop_detection():
img = torch.ones((3, 50, 50), dtype=torch.float16)
c_img, _ = crop_detection(img, abs_boxes, crop_box)
assert c_img.dtype == torch.float16


def test_random_crop():
cropper = RandomCrop()
input_t = torch.ones((50, 50, 3), dtype=torch.float32)
boxes = np.array([
[15, 20, 35, 30]
])
c_img, _ = cropper(input_t, dict(boxes=boxes))
new_h, new_w = c_img.shape[:2]
assert new_h >= 3
assert new_w >= 3
20 changes: 16 additions & 4 deletions test/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotate(input_t, boxes, angle=12.)
r_img, r_boxes = rotate(input_t, boxes, angle=12., expand=False)
assert r_img.shape == (50, 50, 3)
assert r_img[0, 0, 0] == 0.
assert r_boxes.all() == np.array([[25., 25., 20., 10., 12.]]).all()
Expand All @@ -274,14 +274,14 @@ def test_rotate():


def test_random_rotate():
rotator = T.RandomRotate(max_angle=10.)
rotator = T.RandomRotate(max_angle=10., expand=False)
input_t = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([
[15, 20, 35, 30]
])
r_img, target = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, dict(boxes=boxes))
assert r_img.shape == input_t.shape
assert abs(target["boxes"][-1, -1]) <= 10.
assert abs(r_boxes["boxes"][-1, -1]) <= 10.

# FP16
input_t = tf.ones((50, 50, 3), dtype=tf.float16)
Expand Down Expand Up @@ -311,3 +311,15 @@ def test_crop_detection():
img = tf.ones((50, 50, 3), dtype=tf.float16)
c_img, _ = crop_detection(img, rel_boxes, crop_box)
assert c_img.dtype == tf.float16


def test_random_crop():
cropper = T.RandomCrop()
input_t = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([
[15, 20, 35, 30]
])
c_img, _ = cropper(input_t, dict(boxes=boxes))
new_h, new_w = c_img.shape[:2]
assert new_h >= 3
assert new_w >= 3