feat: Added plot_samples support to visualize the images along with the targets (#704)

* feat: Added inference script for artefact detection

* chore: Moved artefact inference script to doctr/scripts

* feat: Added plot_samples script

* fix: Fixed the utils script

* chore: Deleted unnecessary print commands

* fix: Addressed the review comments

* chore: Increased line width of cv2 bounding box
SiddhantBahuguna committed Dec 14, 2021
1 parent f119778 commit 618b0fa
Showing 2 changed files with 40 additions and 5 deletions.
references/obj_detection/train_pytorch.py (14 changes: 9 additions & 5 deletions)
@@ -24,7 +24,7 @@
from doctr.datasets import DocArtefacts
from doctr.models import obj_detection
from doctr.utils import DetectionMetric
from utils import plot_recorder
from utils import plot_recorder, plot_samples


def record_lr(
@@ -104,7 +104,6 @@ def record_lr(

def convert_to_abs_coords(targets, img_shape):
height, width = img_shape[-2:]

for idx, t in enumerate(targets):
targets[idx]['boxes'][:, 0::2] = (t['boxes'][:, 0::2] * width).round()
targets[idx]['boxes'][:, 1::2] = (t['boxes'][:, 1::2] * height).round()
@@ -119,7 +118,6 @@ def convert_to_abs_coords(targets, img_shape):


def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=False):

if amp:
scaler = torch.cuda.amp.GradScaler()

@@ -181,7 +179,6 @@ def evaluate(model, val_loader, metric, amp=False):


def main(args):

print(args)

if not isinstance(args.workers, int):
@@ -260,6 +257,12 @@ def main(args):
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")

if args.show_samples:
images, targets = next(iter(train_loader))
targets = convert_to_abs_coords(targets, images.shape)
plot_samples(images, targets, train_set.CLASSES)
return

# Backbone freezing
if args.freeze_backbone:
for p in model.backbone.parameters():
@@ -282,7 +285,6 @@

# W&B
if args.wb:

run = wandb.init(
name=exp_name,
project="object-detection",
@@ -347,6 +349,8 @@ def parse_args():
parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading')
parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint')
parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop")
parser.add_argument('--show-samples', dest='show_samples', action='store_true',
help='Display unnormalized training samples')
parser.add_argument('--freeze-backbone', dest='freeze_backbone', action='store_true',
help='freeze model backbone for fine-tuning')
parser.add_argument('--wb', dest='wb', action='store_true',
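The new --show-samples branch in main() takes a single batch from the train loader, rescales the relative box coordinates to absolute pixel values with convert_to_abs_coords, renders the batch with plot_samples, and returns without training. A minimal sketch of the scaling step, reproducing only the two lines shown in the hunk above (the tensor shapes and box values are illustrative assumptions; the parts of convert_to_abs_coords not shown in this diff may do more, e.g. type conversion):

import torch

# Toy batch: N, C, H, W images and one target whose boxes are in relative [0, 1] coordinates
images = torch.rand(2, 3, 256, 512)
targets = [{'boxes': torch.tensor([[0.25, 0.50, 0.75, 1.00]]), 'labels': torch.tensor([1])}]

# Same scaling as convert_to_abs_coords: x values by width, y values by height
height, width = images.shape[-2:]
for idx, t in enumerate(targets):
    targets[idx]['boxes'][:, 0::2] = (t['boxes'][:, 0::2] * width).round()
    targets[idx]['boxes'][:, 1::2] = (t['boxes'][:, 1::2] * height).round()

print(targets[0]['boxes'])  # tensor([[128., 128., 384., 256.]])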
references/obj_detection/utils.py (31 changes: 31 additions & 0 deletions)
@@ -3,8 +3,39 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Dict, List

import cv2
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cm import get_cmap


def plot_samples(images, targets: List[Dict[str, np.ndarray]], classes: List[str]) -> None:
cmap = get_cmap('gist_rainbow', len(classes))
# Unnormalize image
nb_samples = min(len(images), 4)
_, axes = plt.subplots(1, nb_samples, figsize=(20, 5))
for idx in range(nb_samples):
img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8)
if img.shape[0] == 3 and img.shape[2] != 3:
img = img.transpose(1, 2, 0)
target = img.copy()
for box, class_idx in zip(targets[idx]['boxes'].numpy(), targets[idx]['labels']):
r, g, b, _ = cmap(class_idx.numpy())
color = int(round(255 * r)), int(round(255 * g)), int(round(255 * b))
cv2.rectangle(target, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, 2)
text_size, _ = cv2.getTextSize(classes[class_idx], cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
text_w, text_h = text_size
cv2.rectangle(target, (int(box[0]), int(box[1])), (int(box[0]) + text_w, int(box[1]) - text_h), color, -1)
cv2.putText(target, classes[class_idx], (int(box[0]), int(box[1])), cv2.FONT_HERSHEY_SIMPLEX, 1,
(255, 255, 255), 2)

axes[idx].imshow(target)
# Disable axis
for ax in axes.ravel():
ax.axis('off')
plt.show()


def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None:
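plot_samples unnormalizes up to four images from the batch, converts them to HWC uint8, draws each box plus a filled class-label banner with cv2 using one gist_rainbow colour per class, and displays the grid with matplotlib. A minimal usage sketch, assuming it is run from references/obj_detection so that utils.py is importable; the class names, image sizes, and boxes below are illustrative assumptions, not the actual DocArtefacts values:

import torch

from utils import plot_samples  # utils.py as added in this commit

# Synthetic batch: CHW float images in [0, 1], absolute pixel boxes, integer labels
classes = ['background', 'qr_code', 'bar_code', 'logo', 'photo']
images = torch.rand(2, 3, 512, 512)
targets = [
    {'boxes': torch.tensor([[60., 80., 220., 200.]]), 'labels': torch.tensor([1])},
    {'boxes': torch.tensor([[300., 150., 480., 320.]]), 'labels': torch.tensor([3])},
]

# One subplot per sample, each box drawn in its class colour with the class name above it
plot_samples(images, targets, classes)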
