Skip to content

Commit

Permalink
[MXNET-1033] Fix a bug in MultiboxTarget GPU implementation (apache#1…
Browse files Browse the repository at this point in the history
…2840)

* remove num_labels check in multibox_target

* add unit test

* test both cpu and gpu

* add contrib operator to GPU unit test

* do not test all contrib operator in gpu
  • Loading branch information
apeforest authored and ChaiBapchya committed Oct 30, 2018
1 parent 709632e commit 3610c2b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/operator/contrib/multibox_target.cu
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ inline void MultiBoxTargetForward(const Tensor<gpu, 2, DType> &loc_target,
const int num_anchors = anchors.size(0);
const int num_classes = cls_preds.size(1);
CHECK_GE(num_batches, 1);
CHECK_GT(num_labels, 2);
CHECK_GE(num_anchors, 1);
CHECK_EQ(variances.ndim(), 4);

Expand Down
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from test_sparse_operator import *
from test_ndarray import *
from test_subgraph_op import *
from test_contrib_operator import test_multibox_target_op

set_default_context(mx.gpu(0))
del test_support_vector_machine_l1_svm # noqa
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_contrib_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,22 @@ def assert_match(inputs, x, y, threshold, is_ascend=False):
assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False)
assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True)

def test_multibox_target_op():
anchors = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]], ctx=default_context()).reshape((1, -1, 4))
cls_pred = mx.nd.array(list(range(10)), ctx=default_context()).reshape((1, -1, 2))
label = mx.nd.array([1, 0.1, 0.1, 0.5, 0.6], ctx=default_context()).reshape((1, -1, 5))

loc_target, loc_mask, cls_target = \
mx.nd.contrib.MultiBoxTarget(anchors, label, cls_pred,
overlap_threshold=0.5,
negative_mining_ratio=3,
negative_mining_thresh=0.4)
expected_loc_target = np.array([[5.0, 2.5000005, 3.4657357, 4.581454, 0., 0., 0., 0.]])
expected_loc_mask = np.array([[1, 1, 1, 1, 0, 0, 0, 0]])
expected_cls_target = np.array([[2, 0]])
assert_allclose(loc_target.asnumpy(), expected_loc_target, rtol=1e-5, atol=1e-5)
assert_array_equal(loc_mask.asnumpy(), expected_loc_mask)
assert_array_equal(cls_target.asnumpy(), expected_cls_target)

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 3610c2b

Please sign in to comment.