Skip to content

Commit

Permalink
Implement Mask RCNN via Tensorflow serverless function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikita Manovich committed Jul 8, 2020
1 parent ecb8a72 commit 9b0faf6
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 430 deletions.
111 changes: 104 additions & 7 deletions serverless/tensorflow/matterport/mask_rcnn/nuclio/function.yaml
Original file line number Diff line number Diff line change
@@ -1,35 +1,132 @@
metadata:
name: tf-maskrcnn
name: tf.matterport.mask_rcnn
namespace: cvat
annotations:
name: Mask RCNN via Tensorflow
type: detector
framework: tensorflow
spec: |
[
{ "id": 0, "name": "BG" },
{ "id": 1, "name": "person" },
{ "id": 2, "name": "bicycle" },
{ "id": 3, "name": "car" },
{ "id": 4, "name": "motorcycle" },
{ "id": 5, "name": "airplane" },
{ "id": 6, "name": "bus" },
{ "id": 7, "name": "train" },
{ "id": 8, "name": "truck" },
{ "id": 9, "name": "boat" },
{ "id": 10, "name": "traffic_light" },
{ "id": 11, "name": "fire_hydrant" },
{ "id": 12, "name": "stop_sign" },
{ "id": 13, "name": "parking_meter" },
{ "id": 14, "name": "bench" },
{ "id": 15, "name": "bird" },
{ "id": 16, "name": "cat" },
{ "id": 17, "name": "dog" },
{ "id": 18, "name": "horse" },
{ "id": 19, "name": "sheep" },
{ "id": 20, "name": "cow" },
{ "id": 21, "name": "elephant" },
{ "id": 22, "name": "bear" },
{ "id": 23, "name": "zebra" },
{ "id": 24, "name": "giraffe" },
{ "id": 25, "name": "backpack" },
{ "id": 26, "name": "umbrella" },
{ "id": 27, "name": "handbag" },
{ "id": 28, "name": "tie" },
{ "id": 29, "name": "suitcase" },
{ "id": 30, "name": "frisbee" },
{ "id": 31, "name": "skis" },
{ "id": 32, "name": "snowboard" },
{ "id": 33, "name": "sports_ball" },
{ "id": 34, "name": "kite" },
{ "id": 35, "name": "baseball_bat" },
{ "id": 36, "name": "baseball_glove" },
{ "id": 37, "name": "skateboard" },
{ "id": 38, "name": "surfboard" },
{ "id": 39, "name": "tennis_racket" },
{ "id": 40, "name": "bottle" },
{ "id": 41, "name": "wine_glass" },
{ "id": 42, "name": "cup" },
{ "id": 43, "name": "fork" },
{ "id": 44, "name": "knife" },
{ "id": 45, "name": "spoon" },
{ "id": 46, "name": "bowl" },
{ "id": 47, "name": "banana" },
{ "id": 48, "name": "apple" },
{ "id": 49, "name": "sandwich" },
{ "id": 50, "name": "orange" },
{ "id": 51, "name": "broccoli" },
{ "id": 52, "name": "carrot" },
{ "id": 53, "name": "hot_dog" },
{ "id": 54, "name": "pizza" },
{ "id": 55, "name": "donut" },
{ "id": 56, "name": "cake" },
{ "id": 57, "name": "chair" },
{ "id": 58, "name": "couch" },
{ "id": 59, "name": "potted_plant" },
{ "id": 60, "name": "bed" },
{ "id": 61, "name": "dining_table" },
{ "id": 62, "name": "toilet" },
{ "id": 63, "name": "tv" },
{ "id": 64, "name": "laptop" },
{ "id": 65, "name": "mouse" },
{ "id": 66, "name": "remote" },
{ "id": 67, "name": "keyboard" },
{ "id": 68, "name": "cell_phone" },
{ "id": 69, "name": "microwave" },
{ "id": 70, "name": "oven" },
{ "id": 71, "name": "toaster" },
{ "id": 72, "name": "sink" },
{ "id": 73, "name": "refrigerator" },
{ "id": 74, "name": "book" },
{ "id": 75, "name": "clock" },
{ "id": 76, "name": "vase" },
{ "id": 77, "name": "scissors" },
{ "id": 78, "name": "teddy_bear" },
{ "id": 79, "name": "hair_drier" },
{ "id": 80, "name": "toothbrush" }
]
spec:
description: TensorFlow MASK RCNN
description: |
An implementation of Mask RCNN on Python 3, Keras, and TensorFlow.
runtime: "python:3.6"
handler: main:handler
eventTimeout: 30s
env:
- name: MASK_RCNN_PATH
- name: MASK_RCNN_DIR
value: /opt/nuclio/Mask_RCNN

build:
image: cvat/tf-maskrcnn
image: cvat/tf.matterport.mask_rcnn
baseImage: tensorflow/tensorflow:2.1.0-py3

directives:
postCopy:
- kind: WORKDIR
value: /opt/nuclio
- kind: RUN
value: apt update && apt install --no-install-recommends -y git curl
value: apt update && apt install --no-install-recommends -y git curl libsm6 libxext6 libxrender-dev
- kind: RUN
value: git clone https://github.com/matterport/Mask_RCNN.git
- kind: RUN
value: curl -L https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5 -o Mask_RCNN/mask_rcnn_coco.h5
- kind: RUN
value: pip3 install -r Mask_RCNN/requirements.txt
- kind: RUN
value: pip3 install pycocotools tensorflow==1.13.1 keras==2.1.0 pillow pyyaml

triggers:
myHttpTrigger:
maxWorkers: 2
kind: "http"
workerAvailabilityTimeoutMilliseconds: 10000
attributes:
maxRequestBodySize: 33554432 # 32MB

platform:
restartPolicy:
name: always
maximumRetryCount: 3
28 changes: 19 additions & 9 deletions serverless/tensorflow/matterport/mask_rcnn/nuclio/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,31 @@
import base64
from PIL import Image
import io
from model_loader import ModelLoader
import numpy as np
import yaml


def init_context(context):
context.logger.info("Init context... 0%")
maskrcnn_handler = None
setattr(context.user_data, 'maskrcnn_handler', maskrcnn_handler)

functionconfig = yaml.safe_load(open("/opt/nuclio/function.yaml"))
labels_spec = functionconfig['metadata']['annotations']['spec']
labels = {item['id']: item['name'] for item in json.loads(labels_spec)}

model_handler = ModelLoader(labels)
setattr(context.user_data, 'model_handler', model_handler)

context.logger.info("Init context...100%")

def handler(context, event):
context.logger.info("call handler")
context.logger.info("Run tf.matterport.mask_rcnn model")
data = event.body
buf = io.BytesIO(base64.b64decode(data["image"]))
buf = io.BytesIO(base64.b64decode(data["image"].encode('utf-8')))
threshold = float(data.get("threshold", 0.2))
image = Image.open(buf)

objects = context.user_data.maskrcnn_handler.handle(image)
return context.Response(body=json.dumps(objects),
headers={},
content_type='application/json',
status_code=200)
results = context.user_data.model_handler.infer(np.array(image), threshold)

return context.Response(body=json.dumps(results), headers={},
content_type='application/json', status_code=200)
104 changes: 0 additions & 104 deletions serverless/tensorflow/matterport/mask_rcnn/nuclio/mask_rcnn.py

This file was deleted.

80 changes: 80 additions & 0 deletions serverless/tensorflow/matterport/mask_rcnn/nuclio/model_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (C) 2018-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import numpy as np
import sys
from skimage.measure import find_contours, approximate_polygon

# workarround for tf.placeholder() is not compatible with eager execution
# https://github.com/tensorflow/tensorflow/issues/18165
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
#import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

# The directory should contain a clone of
# https://github.com/matterport/Mask_RCNN repository and
# downloaded mask_rcnn_coco.h5 model.
MASK_RCNN_DIR = os.environ.get('MASK_RCNN_DIR')
if MASK_RCNN_DIR:
sys.path.append(MASK_RCNN_DIR) # To find local version of the library
sys.path.append(os.path.join(MASK_RCNN_DIR, 'samples/coco'))

from mrcnn import model as modellib, utils
import coco

class ModelLoader:
def __init__(self, labels):
COCO_MODEL_PATH = os.path.join(MASK_RCNN_DIR, "mask_rcnn_coco.h5")
if COCO_MODEL_PATH is None:
raise OSError('Model path env not found in the system.')

class InferenceConfig(coco.CocoConfig):
# Set batch size to 1 since we'll be running inference on
# one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
GPU_COUNT = 1
IMAGES_PER_GPU = 1

# Print config details
self.config = InferenceConfig()
self.config.display()

self.model = modellib.MaskRCNN(mode="inference",
config=self.config, model_dir=MASK_RCNN_DIR)
self.model.load_weights(COCO_MODEL_PATH, by_name=True)
self.labels = labels

def infer(self, image, threshold):
output = self.model.detect([image], verbose=1)[0]

results = []
MASK_THRESHOLD = 0.5
for i, box in enumerate(output["rois"]):
score = output["scores"][i]
class_id = output["class_ids"][i]
mask = output["masks"][:,:,i]
if score >= threshold:
mask = mask.astype(np.uint8)
contours = find_contours(mask, MASK_THRESHOLD)
# only one contour exist in our case
contour = contours[0]
contour = np.flip(contour, axis=1)
# Approximate the contour and reduce the number of points
contour = approximate_polygon(contour, tolerance=2.5)
polygon = contour.ravel().tolist()
if len(contour) < 3:
continue
label = self.labels[class_id]

results.append({
"confidence": str(score),
"label": label,
"points": polygon,
"type": "polygon",
})

return results


Loading

0 comments on commit 9b0faf6

Please sign in to comment.