chore: Cleaned and improved the code
SiddhantBahuguna committed Dec 23, 2021
1 parent e90d9b2 commit 878bc83
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions scripts/detect_artefacts.py
@@ -13,27 +13,28 @@

 import cv2
 import matplotlib.pyplot as plt
 from matplotlib.pyplot import figure
 import torch

-from doctr.models import obj_detection
 from doctr.io.image import read_img_as_tensor
+from doctr.models import obj_detection

+CLASSES = ["__background__", "QR Code", "Barcode", "Logo", "Photo"]
+CM = [(255, 255, 255), (0, 0, 150), (0, 0, 0), (0, 150, 0), (150, 0, 0)]

-def plot_predictions(image, tg_boxes, tg_labels, cl_map, cm):
-    for ind_2, val_2 in enumerate(tg_boxes):
+def plot_predictions(image, boxes, labels):
+    for box, label in zip(boxes, labels):
         # Bounding box around artefacts
-        cv2.rectangle(image, (val_2[0], val_2[1]), (val_2[2], val_2[3]),
-                      cm[tg_labels[ind_2]], 2)
-        text_size, _ = cv2.getTextSize(cl_map[int(tg_labels[ind_2])], cv2.FONT_HERSHEY_SIMPLEX, 2, 2)
+        cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]),
+                      CM[label], 2)
+        text_size, _ = cv2.getTextSize(CLASSES[label], cv2.FONT_HERSHEY_SIMPLEX, 2, 2)
         text_w, text_h = text_size
         # Filled rectangle above bounding box
-        cv2.rectangle(image, (val_2[0], val_2[1]), (val_2[0] + text_w, val_2[1] - text_h),
-                      cm[tg_labels[ind_2]], -1)
+        cv2.rectangle(image, (box[0], box[1]), (box[0] + text_w, box[1] - text_h),
+                      CM[label], -1)
         # Text bearing the name of the artefact detected
-        cv2.putText(image, cl_map[int(tg_labels[ind_2])], (int(val_2[0]), int(val_2[1])),
+        cv2.putText(image, CLASSES[label], (int(box[0]), int(box[1])),
                     cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3)
     figure(figsize=(9, 7), dpi=100)
     plt.axis('off')
     plt.imshow(image)
     plt.show()
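
Aside (not part of the commit): the new module-level CLASSES / CM lists replace the cl_map / cm dicts that used to be passed into plot_predictions, reserving index 0 for the background class so a predicted label can be used directly as an index. A small self-contained check of that equivalence, with the old dicts copied from the removed lines:

    CLASSES = ["__background__", "QR Code", "Barcode", "Logo", "Photo"]
    CM = [(255, 255, 255), (0, 0, 150), (0, 0, 0), (0, 150, 0), (150, 0, 0)]
    cm = {1: (0, 0, 150), 2: (0, 0, 0), 3: (0, 150, 0), 4: (150, 0, 0)}
    cl_map = {1: "QR_Code", 2: "Bar_Code", 3: "Logo", 4: "Photo"}

    for label in range(1, 5):
        assert CM[label] == cm[label]  # colours are unchanged
    print([CLASSES[i] for i in range(1, 5)])  # ['QR Code', 'Barcode', 'Logo', 'Photo'] -- names slightly reworded vs cl_map
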
@@ -55,28 +56,25 @@ def main(args):
         args.device = 0
     else:
         logging.warning("No accessible GPU, target device set to CPU.")
-    img = read_img_as_tensor(args.img_path).unsqueeze(0)
     if torch.cuda.is_available():
         torch.cuda.set_device(args.device)
         model = model.cuda()

-    cm = {1: (0, 0, 150), 2: (0, 0, 0), 3: (0, 150, 0), 4: (150, 0, 0)}
-    cl_map = {1: "QR_Code", 2: "Bar_Code", 3: "Logo", 4: "Photo"}
+    img = read_img_as_tensor(args.img_path).unsqueeze(0)
     if torch.cuda.is_available():
         img = img.cuda()

     pred = model(img)
-    tg_labels = pred[0]['labels'].detach().cpu().numpy()
-    tg_labels = tg_labels.round().astype(int)
-    tg_boxes = pred[0]['boxes'].detach().cpu().numpy()
-    tg_boxes = tg_boxes.round().astype(int)
+    labels = pred[0]['labels'].detach().cpu().numpy()
+    labels = labels.round().astype(int)
+    boxes = pred[0]['boxes'].detach().cpu().numpy()
+    boxes = boxes.round().astype(int)
     img = img.cpu().permute(0, 2, 3, 1).numpy()[0].copy()
-    plot_predictions(img, tg_boxes, tg_labels, cl_map, cm)
+    plot_predictions(img, boxes, labels)


 def parse_args():
     parser = argparse.ArgumentParser(description="Artefact detection model to use",
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('arch', type=str, help='text-detection model to train')
+    parser.add_argument('arch', type=str, help='Artefact detection model to use')
     parser.add_argument('img_path', type=str, help='path to the image')
     parser.add_argument('--device', default=None, type=int, help='device')
     args = parser.parse_args()
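
For context (not part of the commit), a condensed sketch of the inference flow the cleaned-up main() performs, here on CPU; the architecture name fasterrcnn_mobilenet_v3_large_fpn and the image path are illustrative assumptions rather than values taken from the diff:

    import torch

    from doctr.io.image import read_img_as_tensor
    from doctr.models import obj_detection

    # Assumed arch name -- check doctr.models.obj_detection for what the installed version exposes
    model = obj_detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True).eval()
    img = read_img_as_tensor("path/to/page.jpg").unsqueeze(0)  # (1, 3, H, W) float tensor, illustrative path
    with torch.no_grad():
        pred = model(img)
    labels = pred[0]['labels'].detach().cpu().numpy().round().astype(int)  # indices into CLASSES
    boxes = pred[0]['boxes'].detach().cpu().numpy().round().astype(int)    # (N, 4) xmin, ymin, xmax, ymax in pixels

Running the script itself would then look something like python scripts/detect_artefacts.py fasterrcnn_mobilenet_v3_large_fpn path/to/page.jpg --device 0, with the same caveat that the arch name and path are placeholders.
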
