diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index b846700fb503..5c4f01e2fd24 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -140,10 +140,6 @@ def imdecode(buf, *args, **kwargs): 'if you would like to input type str, please convert to bytes') buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8) - if len(buf) == 0: - # empty buf causes OpenCV crash. - raise ValueError("input buf cannot be empty.") - return _internal._cvimdecode(buf, *args, **kwargs) diff --git a/src/io/image_io.cc b/src/io/image_io.cc index a996a2208d79..b3f7c40b2b1a 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -143,11 +143,8 @@ void ImdecodeImpl(int flag, bool to_rgb, void* data, size_t size, cv::Mat dst; if (out->is_none()) { cv::Mat res = cv::imdecode(buf, flag); - if (res.empty()) { - LOG(INFO) << "Decoding failed. Invalid image file."; - *out = NDArray(); - return; - } + CHECK(!res.empty()) << "Decoding failed. Invalid image file."; + *out = NDArray(mshadow::Shape3(res.rows, res.cols, flag == 0 ? 1 : 3), Context::CPU(), false, mshadow::kUint8); dst = cv::Mat(out->shape()[0], out->shape()[1], flag == 0 ? CV_8U : CV_8UC3, @@ -189,6 +186,8 @@ void Imdecode(const nnvm::NodeAttrs& attrs, uint8_t* str_img = inputs[0].data().dptr(); size_t len = inputs[0].shape().Size(); + CHECK(len > 0) << "Input cannot be an empty buffer"; + TShape oshape(3); oshape[2] = param.flag == 0 ? 1 : 3; if (get_jpeg_size(str_img, len, &oshape[1], &oshape[0])) { diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index c8022b67bee8..3e35d6de2bcb 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -104,10 +104,15 @@ def test_imdecode_bytearray(self): cv_image = cv2.imread(img) assert_almost_equal(image.asnumpy(), cv_image) - @raises(ValueError) + @raises(mx.base.MXNetError) def test_imdecode_empty_buffer(self): mx.image.imdecode(b'', to_rgb=0) + @raises(mx.base.MXNetError) + def test_imdecode_invalid_image(self): + image = mx.image.imdecode(b'clearly not image content') + assert_equal(image, None) + def test_scale_down(self): assert mx.image.scale_down((640, 480), (720, 120)) == (640, 106) assert mx.image.scale_down((360, 1000), (480, 500)) == (360, 375)