Skip to content

Commit

Permalink
feat: add RandomCrop module in transforms (#448)
Browse files Browse the repository at this point in the history
* feat: add sar ckpt + perf

* feat: add randomcrop

* fix: test

* fix: test

* fix: scale and ratio

* fix: unused file

* fix: unused file

* fix: typo

* fix: typo

* fix: requested changes

* fix: flake8

* fix: test

* feat: add entry in docs
  • Loading branch information
charlesmindee committed Aug 31, 2021
1 parent 422da0e commit 51c961f
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Here are all transformations that are available through DocTR:
.. autoclass:: RandomGamma
.. autoclass:: RandomJpegQuality
.. autoclass:: RandomRotate
.. autoclass:: RandomCrop


Composing transformations
Expand Down
37 changes: 32 additions & 5 deletions doctr/transforms/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

import random
import math
from typing import List, Any, Callable, Dict, Tuple
import numpy as np

from doctr.utils.repr import NestedObject
from .. import functional as F


__all__ = ['ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate']
__all__ = ['ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate', 'RandomCrop']


class ColorInversion(NestedObject):
Expand Down Expand Up @@ -89,14 +90,14 @@ def __call__(self, img: Any) -> Any:


class RandomRotate(NestedObject):
"""Randomly rotate a tensor image
"""Randomly rotate a tensor image and its boxes
Args:
max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in
[-max_angle, max_angle]
expand: whether the image should be padded before the rotation
"""
def __init__(self, max_angle: float = 25., expand: bool = False) -> None:
def __init__(self, max_angle: float = 5., expand: bool = False) -> None:
self.max_angle = max_angle
self.expand = expand

Expand All @@ -105,5 +106,31 @@ def extra_repr(self) -> str:

def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
angle = random.uniform(-self.max_angle, self.max_angle)
img, target['boxes'] = F.rotate(img, target['boxes'], angle, self.expand)
return img, target
r_img, r_boxes = F.rotate(img, target["boxes"], angle, self.expand)
return r_img, dict(boxes=r_boxes)


class RandomCrop(NestedObject):
    """Randomly crop a tensor image and its boxes

    Args:
        scale: tuple of floats, relative (min_area, max_area) of the crop
        ratio: tuple of floats, relative (min_ratio, max_ratio) where ratio = h/w
    """
    def __init__(self, scale: Tuple[float, float] = (0.08, 1.), ratio: Tuple[float, float] = (0.75, 1.33)) -> None:
        self.scale = scale
        self.ratio = ratio

    def extra_repr(self) -> str:
        return f"scale={self.scale}, ratio={self.ratio}"

    def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]:
        """Sample a random crop window and apply it to the image and its boxes

        Args:
            img: image to crop; its height and width are read from ``img.shape[:2]``
            target: dictionary holding the boxes of the image under the "boxes" key

        Returns:
            the cropped image and a new dictionary with the cropped boxes under the "boxes" key
        """
        # NOTE(review): the first two dimensions are taken as (height, width) — confirm for channels-first tensors
        h, w = img.shape[:2]
        random_scale = random.uniform(self.scale[0], self.scale[1])
        random_ratio = random.uniform(self.ratio[0], self.ratio[1])
        # Relative sides of the crop. A large scale combined with a ratio far from 1 pushes one side
        # above 1 (e.g. scale=1., ratio=1.33), which would make the start offset negative and the
        # crop box leave the image: clamp each side to the full image extent.
        crop_h = min(math.sqrt(random_scale * random_ratio), 1.)
        crop_w = min(math.sqrt(random_scale / random_ratio), 1.)
        # Top-left corner, sampled so that the whole window fits inside the image
        start_x, start_y = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h)
        # Absolute (xmin, ymin, xmax, ymax) coordinates of the crop
        crop_box = (
            int(start_x * w),
            int(start_y * h),
            int((start_x + crop_w) * w),
            int((start_y + crop_h) * h),
        )
        cropped_img, crop_boxes = F.crop_detection(img, target["boxes"], crop_box)
        return cropped_img, dict(boxes=crop_boxes)
22 changes: 17 additions & 5 deletions test/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import torch
import numpy as np
from doctr.transforms import Resize, ColorInversion, RandomRotate
from doctr.transforms import Resize, ColorInversion, RandomRotate, RandomCrop
from doctr.transforms.functional import rotate, crop_detection


Expand Down Expand Up @@ -86,7 +86,7 @@ def test_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotate(input_t, boxes, angle=12.)
r_img, r_boxes = rotate(input_t, boxes, angle=12., expand=False)
assert r_img.shape == (3, 50, 50)
assert r_img[0, 0, 0] == 0.
assert r_boxes.all() == np.array([[25., 25., 20., 10., 12.]]).all()
Expand All @@ -110,14 +110,14 @@ def test_rotate():


def test_random_rotate():
rotator = RandomRotate(max_angle=10.)
rotator = RandomRotate(max_angle=10., expand=False)
input_t = torch.ones((3, 50, 50), dtype=torch.float32)
boxes = np.array([
[15, 20, 35, 30]
])
r_img, target = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, dict(boxes=boxes))
assert r_img.shape == input_t.shape
assert abs(target["boxes"][-1, -1]) <= 10.
assert abs(r_boxes["boxes"][-1, -1]) <= 10.

# FP16 (only on GPU)
if torch.cuda.is_available():
Expand Down Expand Up @@ -148,3 +148,15 @@ def test_crop_detection():
img = torch.ones((3, 50, 50), dtype=torch.float16)
c_img, _ = crop_detection(img, abs_boxes, crop_box)
assert c_img.dtype == torch.float16


def test_random_crop():
    """Smoke test: RandomCrop on a (50, 50, 3) tensor keeps at least 3 pixels on each of the first two axes."""
    transform = RandomCrop()
    image = torch.ones((50, 50, 3), dtype=torch.float32)
    target = dict(boxes=np.array([[15, 20, 35, 30]]))
    cropped, _ = transform(image, target)
    height, width = cropped.shape[:2]
    assert height >= 3
    assert width >= 3
20 changes: 16 additions & 4 deletions test/tensorflow/test_transforms_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_rotate():
boxes = np.array([
[15, 20, 35, 30]
])
r_img, r_boxes = rotate(input_t, boxes, angle=12.)
r_img, r_boxes = rotate(input_t, boxes, angle=12., expand=False)
assert r_img.shape == (50, 50, 3)
assert r_img[0, 0, 0] == 0.
assert r_boxes.all() == np.array([[25., 25., 20., 10., 12.]]).all()
Expand All @@ -274,14 +274,14 @@ def test_rotate():


def test_random_rotate():
rotator = T.RandomRotate(max_angle=10.)
rotator = T.RandomRotate(max_angle=10., expand=False)
input_t = tf.ones((50, 50, 3), dtype=tf.float32)
boxes = np.array([
[15, 20, 35, 30]
])
r_img, target = rotator(input_t, dict(boxes=boxes))
r_img, r_boxes = rotator(input_t, dict(boxes=boxes))
assert r_img.shape == input_t.shape
assert abs(target["boxes"][-1, -1]) <= 10.
assert abs(r_boxes["boxes"][-1, -1]) <= 10.

# FP16
input_t = tf.ones((50, 50, 3), dtype=tf.float16)
Expand Down Expand Up @@ -311,3 +311,15 @@ def test_crop_detection():
img = tf.ones((50, 50, 3), dtype=tf.float16)
c_img, _ = crop_detection(img, rel_boxes, crop_box)
assert c_img.dtype == tf.float16


def test_random_crop():
    """Smoke test: T.RandomCrop on a (50, 50, 3) tensor keeps at least 3 pixels on each of the first two axes."""
    transform = T.RandomCrop()
    image = tf.ones((50, 50, 3), dtype=tf.float32)
    target = dict(boxes=np.array([[15, 20, 35, 30]]))
    cropped, _ = transform(image, target)
    height, width = cropped.shape[:2]
    assert height >= 3
    assert width >= 3

0 comments on commit 51c961f

Please sign in to comment.