Skip to content

Commit

Permalink
feat: Added random noise augmentation to object detection (#654)
Browse files Browse the repository at this point in the history
* chore: Added rotations as augmentations

* chore: Added minor augmentations

1) random horizontal flip
2) random vertical flip

* fix: Added the height, width of the image for random-horizontal/vertical_flips

* chore: added rotations

* feat: Added shadows as augmentations

* chore: Used numpy for random bool response

* feat: Removed geometric augmentations.

* chore: Added photometric augmentation in dataloader

* fix: Fixed .gitignore

* chore: Removed unnecessary functions related to photometric augmentations

* chore: Reverted last commit

* feat: Removed shadows and added random Gaussian noise

* fix: Fixed style

* fix: Fixed the distribution range

* test: Added unittest for GaussianNoise

* fix: Reverted to original script

* fix: Reverted to old train script

* chore: Merged with main

* fix: Fixed typo

* style: Fixed brackets indentation
  • Loading branch information
SiddhantBahuguna committed Dec 29, 2021
1 parent c72da96 commit 8e7b0ee
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
31 changes: 30 additions & 1 deletion doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

__all__ = ['Resize']
__all__ = ['Resize', 'GaussianNoise']


class Resize(T.Resize):
Expand Down Expand Up @@ -53,3 +53,32 @@ def __repr__(self) -> str:
if self.preserve_aspect_ratio:
_repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}"
return f"{self.__class__.__name__}({_repr})"


class GaussianNoise(torch.nn.Module):
    """Adds random noise to an input image of type torch.Tensor

    Despite the name, the noise is drawn with ``torch.rand`` and is therefore
    uniformly distributed over ``[mean - std, mean + std)``, not normally
    distributed. Float inputs are treated as living in ``[0, 1]`` and uint8
    inputs in ``[0, 255]``; in both cases the result is clamped back to that
    range and keeps the dtype of the input.

    Example::
        >>> from doctr.transforms import GaussianNoise
        >>> import torch
        >>> transfo = GaussianNoise(0., 1.)
        >>> out = transfo(torch.rand((3, 224, 224)))

    Args:
        mean : center of the noise distribution (expressed on the [0, 1] scale)
        std : half-width of the noise distribution (expressed on the [0, 1] scale)
    """

    def __init__(self, mean: float = 0., std: float = 1.) -> None:
        super().__init__()
        self.std = std
        self.mean = mean

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a noisy copy of ``x`` with the same shape and dtype.

        Args:
            x: image tensor, either uint8 (values in [0, 255]) or floating
                point (values in [0, 1])

        Returns:
            the noisy tensor, clamped to the valid range of its dtype
        """
        if x.dtype == torch.uint8:
            # rand_like cannot sample into an integer tensor, so sample a float
            # tensor of the same shape on the same device instead.
            unit = torch.rand(x.shape, device=x.device)
            # Same symmetric range as the float branch, so that the noise can
            # darken as well as brighten pixels; rescaled to the 0-255 scale.
            noise = self.mean + self.std * (2 * unit - 1)
            return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8)

        # Map U[0, 1) onto U[mean - std, mean + std)
        noise = self.mean + self.std * (2 * torch.rand_like(x) - 1)
        return (x + noise).clamp(0, 1)

    def extra_repr(self) -> str:
        return f"mean={self.mean}, std={self.std}"
5 changes: 3 additions & 2 deletions references/obj_detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,13 @@ def main(args):
return

st = time.time()
# Load both train and val data generators
# Load train data generators
train_set = DocArtefacts(
train=True,
download=True,
img_transforms=Compose([
T.Resize((args.input_size, args.input_size)),
T.RandomApply(T.GaussianNoise(0., 0.25), p=0.5),
ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02),
T.RandomApply(GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3)), .3),
]))
Expand Down Expand Up @@ -334,7 +335,7 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('arch', type=str, help='text-detection model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training')
parser.add_argument('--device', default=None, type=int, help='device')
parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W')
Expand Down
27 changes: 26 additions & 1 deletion tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from doctr.transforms import ColorInversion, RandomCrop, RandomRotate, Resize
from doctr.transforms import ColorInversion, GaussianNoise, RandomCrop, RandomRotate, Resize
from doctr.transforms.functional import crop_detection, rotate


Expand Down Expand Up @@ -164,3 +164,28 @@ def test_random_crop():
new_h, new_w = c_img.shape[:2]
assert new_h >= 3
assert new_w >= 3


@pytest.mark.parametrize(
    "input_dtype,input_shape",
    [
        (torch.float32, (3, 32, 32)),
        (torch.uint8, (3, 32, 32)),
    ],
)
def test_gaussian_noise(input_dtype, input_shape):
    """Check that GaussianNoise keeps type, shape, dtype and value range, and alters the input."""
    is_uint8 = input_dtype == torch.uint8
    noise_layer = GaussianNoise(0., 1.)

    # Start from a float image in [0, 1] and rescale it to [0, 255] for the uint8 case
    sample = torch.rand(input_shape, dtype=torch.float32)
    if is_uint8:
        sample = (255 * sample).round()
    sample = sample.to(dtype=input_dtype)

    noisy = noise_layer(sample)

    # Output container, geometry and dtype are preserved
    assert isinstance(noisy, torch.Tensor)
    assert noisy.shape == input_shape
    assert noisy.dtype == input_dtype
    # At least one value must have been perturbed
    assert torch.any(noisy != sample)
    # Values stay inside the valid range of the dtype
    assert torch.all(noisy >= 0)
    if is_uint8:
        assert torch.all(noisy <= 255)
    else:
        assert torch.all(noisy <= 1.)

0 comments on commit 8e7b0ee

Please sign in to comment.