Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added random noise augmentation to object detection #654

Merged
merged 30 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
327c244
chore: Added rotations as augmentations
SiddhantBahuguna Nov 24, 2021
77402de
chore: Added minor augmentations
SiddhantBahuguna Nov 29, 2021
d82786f
fix: Added the height, width of the image for random-horizontal/verti…
SiddhantBahuguna Nov 29, 2021
1ee4810
Merge branch 'main' into minor_aug
SiddhantBahuguna Nov 29, 2021
41b9dd8
chore: added rotations
SiddhantBahuguna Dec 2, 2021
1c71c3c
Merge branch 'main' into minor_aug
SiddhantBahuguna Dec 2, 2021
e8e3791
feat: Added shadows as augmentations
SiddhantBahuguna Dec 7, 2021
e513c04
feat: Added shadows as augmentations
SiddhantBahuguna Dec 7, 2021
349ea72
chore: Used numpy for random bool response
SiddhantBahuguna Dec 7, 2021
d3e537b
fix: Fixed the conflicts with the main
SiddhantBahuguna Dec 17, 2021
fad5b7b
fix: Merged main with this PR fixing timeouts in CI
SiddhantBahuguna Dec 18, 2021
2617e27
feat: Removed geometric augmentations.
SiddhantBahuguna Dec 20, 2021
76cffad
chore: Added photometric augmentation in dataloader
SiddhantBahuguna Dec 23, 2021
49a2a56
fix: Fixed .gitignore
SiddhantBahuguna Dec 23, 2021
2581ca8
chore: Merged main with the current branch
SiddhantBahuguna Dec 23, 2021
287ee76
chore: Removed unnecessary functions related to photometric augmentat…
SiddhantBahuguna Dec 23, 2021
7507cd5
chore: Reverted last commit
SiddhantBahuguna Dec 23, 2021
d4985ab
feat: Removed shadows and added randomgaussian noise
SiddhantBahuguna Dec 24, 2021
2492ab5
chore: Merged main with the PR
SiddhantBahuguna Dec 27, 2021
db8fe38
fix: Fixed style
SiddhantBahuguna Dec 27, 2021
ecff1c8
fix: Fixed the distribution range
SiddhantBahuguna Dec 28, 2021
ad9cf4a
test: Added unittest for GaussianNoise
SiddhantBahuguna Dec 28, 2021
6161e25
fix: Reverted to original script
SiddhantBahuguna Dec 28, 2021
bacbfb5
fix: Reverted to old train script
SiddhantBahuguna Dec 28, 2021
89d5be2
chore: Merged main with the PR
SiddhantBahuguna Dec 29, 2021
5f21008
Merge branch 'main' into minor_aug
SiddhantBahuguna Dec 29, 2021
55d984a
fix: Updated the new checkpoint path
SiddhantBahuguna Dec 29, 2021
76adc41
chore: Merged with main
SiddhantBahuguna Dec 29, 2021
2f260ae
fix: Fixed typo
SiddhantBahuguna Dec 29, 2021
94535b1
style: Fixed brackets indentation
SiddhantBahuguna Dec 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

__all__ = ['Resize']
__all__ = ['Resize', 'RandomGaussianNoise']


class Resize(T.Resize):
def __init__(
self,
size: Tuple[int, int],
interpolation=F.InterpolationMode.BILINEAR,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
self,
size: Tuple[int, int],
interpolation=F.InterpolationMode.BILINEAR,
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
SiddhantBahuguna marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
super().__init__(size, interpolation)
self.preserve_aspect_ratio = preserve_aspect_ratio
Expand Down Expand Up @@ -53,3 +53,15 @@ def __repr__(self) -> str:
if self.preserve_aspect_ratio:
_repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}"
return f"{self.__class__.__name__}({_repr})"


class RandomGaussianNoise:
    """Add Gaussian noise to a tensor.

    On every call, element-wise samples are drawn from a standard normal distribution,
    scaled by ``std``, shifted by ``mean`` and added to the input tensor.

    Args:
        mean: offset added to every noise sample
        std: scale applied to every noise sample
    """

    def __init__(self, mean: float = 0.5, std: float = 1.5) -> None:
        self.std = std
        self.mean = mean

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` with freshly sampled noise added (the input is not modified)."""
        # Sample directly on the input's device: a CPU-only noise tensor cannot be added
        # to a CUDA tensor.
        noise = torch.randn(tensor.shape, device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self) -> str:
        # Same "ClassName(args)" layout as the other transforms in this module
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
SiddhantBahuguna marked this conversation as resolved.
Show resolved Hide resolved
25 changes: 10 additions & 15 deletions references/obj_detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from fastprogress.fastprogress import master_bar, progress_bar
from torch.optim.lr_scheduler import MultiplicativeLR, StepLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms import Compose

from doctr import transforms as T
from doctr.datasets import DocArtefacts
Expand Down Expand Up @@ -104,33 +105,28 @@ def record_lr(

def convert_to_abs_coords(targets, img_shape):
    """Scale box coordinates by the image size and convert targets to tensors.

    Args:
        targets: sequence of dicts with a 'boxes' array of shape (N, 4) and a 'labels'
            sequence (boxes are presumably expressed relative to the image size,
            to be confirmed against the dataset)
        img_shape: shape of the image batch; its last two entries are read as (height, width)

    Returns:
        a new list of dicts with 'boxes' as float32 tensors and 'labels' as int64 tensors
    """
    height, width = img_shape[-2:]

    abs_targets = []
    for t in targets:
        # torch.tensor always copies, so the caller's arrays are left untouched and the
        # scaling is applied exactly once.
        boxes = torch.tensor(t['boxes'])
        # Even columns are multiplied by the width, odd columns by the height
        boxes[:, 0::2] = (boxes[:, 0::2] * width).round()
        boxes[:, 1::2] = (boxes[:, 1::2] * height).round()
        abs_targets.append({
            "boxes": boxes.to(dtype=torch.float32),
            "labels": torch.tensor(t['labels']).to(dtype=torch.long),
        })

    return abs_targets


def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=False):
def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp):
SiddhantBahuguna marked this conversation as resolved.
Show resolved Hide resolved
if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
train_iter = iter(train_loader)
# Iterate over the batches of the dataset
for images, targets in progress_bar(train_iter, parent=mb):

targets = convert_to_abs_coords(targets, images.shape)
if torch.cuda.is_available():
images = images.cuda()
targets = [{k: v.cuda() for k, v in t.items()} for t in targets]

optimizer.zero_grad()
if amp:
with torch.cuda.amp.autocast():
Expand All @@ -145,7 +141,6 @@ def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=False):
loss = sum(v for v in loss_dict.values())
loss.backward()
optimizer.step()

mb.child.comment = f'Training loss: {loss.item()}'
scheduler.step()

Expand All @@ -156,8 +151,8 @@ def evaluate(model, val_loader, metric, amp=False):
metric.reset()
val_iter = iter(val_loader)
for images, targets in val_iter:

images, targets = next(val_iter)
# batch_transforms
targets = convert_to_abs_coords(targets, images.shape)
if torch.cuda.is_available():
images = images.cuda()
Expand Down Expand Up @@ -238,11 +233,12 @@ def main(args):
return

st = time.time()
# Load both train and val data generators

train_set = DocArtefacts(
train=True,
download=True,
sample_transforms=T.Resize((args.input_size, args.input_size)),
sample_transforms=Compose([T.Resize((args.input_size, args.input_size)),
T.RandomGaussianNoise(0.5, 1.5)])
)

train_loader = DataLoader(
Expand Down Expand Up @@ -305,7 +301,6 @@ def main(args):

mb = master_bar(range(args.epochs))
max_score = 0.

for epoch in mb:
fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=args.amp)
# Validation loop at the end of each epoch
Expand Down Expand Up @@ -340,7 +335,7 @@ def parse_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('arch', type=str, help='text-detection model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training')
parser.add_argument('--device', default=None, type=int, help='device')
parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W')
Expand Down
1 change: 1 addition & 0 deletions references/obj_detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.


SiddhantBahuguna marked this conversation as resolved.
Show resolved Hide resolved
from typing import Dict, List

import cv2
Expand Down