From 4c634e0f171bac81211aedba84a09ceb5b721616 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Wed, 12 Sep 2018 13:32:57 -0700 Subject: [PATCH] Revert "Change the way NDArrayIter handle the last batch" (#12537) * Revert "Removing the re-size for validation data, which breaking the validation accuracy of CIFAR training (#12362)" This reverts commit ceabcaac77543d99246415b2fb2d8c973a830453. * Revert "[MXNET-580] Add SN-GAN example (#12419)" This reverts commit 46a5cee2515a1ac0a1ae5afbe7e639debb998587. * Revert "Remove regression checks for website links (#12507)" This reverts commit 619bc3ea3c9093b72634d16e91596b3a65f3f1fc. * Revert "Revert "Fix flaky test: test_mkldnn.test_activation #12377 (#12418)" (#12516)" This reverts commit 7ea05333efc8ca868443b89233b101d068f6af9f. * Revert "further bump up tolerance for sparse dot (#12527)" This reverts commit 90599e1038a4ff6604e9ed0d55dc274c2df635f8. * Revert "Fix broken URLs (#12508)" This reverts commit 3d83c896fd8b237c53003888e35a4d792c1e5389. * Revert "Temporarily disable flaky tests (#12520)" This reverts commit 35ca13c3b5a0e57d904d1fead079152a15dfeac4. * Revert "Add support for more req patterns for bilinear sampler backward (#12386)" This reverts commit 4ee866fc75307b284cc0eae93d0cf4dad3b62533. * Revert "Change the way NDArrayIter handle the last batch (#12285)" This reverts commit 597a637fb1b8fa5b16331218cda8be61ce0ee202. --- CONTRIBUTORS.md | 1 - python/mxnet/{io => }/io.py | 280 +++++++++++++++---------------- python/mxnet/io/__init__.py | 29 ---- python/mxnet/io/utils.py | 86 ---------- tests/python/unittest/test_io.py | 122 ++++++-------- 5 files changed, 190 insertions(+), 328 deletions(-) rename python/mxnet/{io => }/io.py (82%) delete mode 100644 python/mxnet/io/__init__.py delete mode 100644 python/mxnet/io/utils.py diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 1c005d57c4a6..8d8aeaca73e4 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -178,4 +178,3 @@ List of Contributors * [Aaron Markham](https://github.com/aaronmarkham) * [Sam Skalicky](https://github.com/samskalicky) * [Per Goncalves da Silva](https://github.com/perdasilva) -* [Cheng-Che Lee](https://github.com/stu1130) diff --git a/python/mxnet/io/io.py b/python/mxnet/io.py similarity index 82% rename from python/mxnet/io/io.py rename to python/mxnet/io.py index 2ae3e70045fb..884e9294741a 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io.py @@ -17,26 +17,30 @@ """Data iterators for common data formats.""" from __future__ import absolute_import -from collections import namedtuple +from collections import OrderedDict, namedtuple import sys import ctypes import logging import threading +try: + import h5py +except ImportError: + h5py = None import numpy as np - -from ..base import _LIB -from ..base import c_str_array, mx_uint, py_str -from ..base import DataIterHandle, NDArrayHandle -from ..base import mx_real_t -from ..base import check_call, build_param_doc as _build_param_doc -from ..ndarray import NDArray -from ..ndarray.sparse import CSRNDArray -from ..ndarray import _ndarray_cls -from ..ndarray import array -from ..ndarray import concat - -from .utils import init_data, has_instance, getdata_by_idx +from .base import _LIB +from .base import c_str_array, mx_uint, py_str +from .base import DataIterHandle, NDArrayHandle +from .base import mx_real_t +from .base import check_call, build_param_doc as _build_param_doc +from .ndarray import NDArray +from .ndarray.sparse import CSRNDArray +from .ndarray.sparse import array as sparse_array +from .ndarray import _ndarray_cls +from .ndarray import array +from .ndarray import concatenate +from .ndarray import arange +from .ndarray.random import shuffle as random_shuffle class DataDesc(namedtuple('DataDesc', ['name', 'shape'])): """DataDesc is used to store name, shape, type and layout @@ -485,6 +489,59 @@ def getindex(self): def getpad(self): return self.current_batch.pad +def _init_data(data, allow_empty, default_name): + """Convert data into canonical form.""" + assert (data is not None) or allow_empty + if data is None: + data = [] + + if isinstance(data, (np.ndarray, NDArray, h5py.Dataset) + if h5py else (np.ndarray, NDArray)): + data = [data] + if isinstance(data, list): + if not allow_empty: + assert(len(data) > 0) + if len(data) == 1: + data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type + else: + data = OrderedDict( # pylint: disable=redefined-variable-type + [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)]) + if not isinstance(data, dict): + raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \ + "a list of them or dict with them as values") + for k, v in data.items(): + if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray): + try: + data[k] = array(v) + except: + raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \ + "should be NDArray, numpy.ndarray or h5py.Dataset") + + return list(sorted(data.items())) + +def _has_instance(data, dtype): + """Return True if ``data`` has instance of ``dtype``. + This function is called after _init_data. + ``data`` is a list of (str, NDArray)""" + for item in data: + _, arr = item + if isinstance(arr, dtype): + return True + return False + +def _shuffle(data, idx): + """Shuffle the data.""" + shuffle_data = [] + + for k, v in data: + if (isinstance(v, h5py.Dataset) if h5py else False): + shuffle_data.append((k, v)) + elif isinstance(v, CSRNDArray): + shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context))) + else: + shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) + + return shuffle_data class NDArrayIter(DataIter): """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset`` @@ -544,22 +601,6 @@ class NDArrayIter(DataIter): ... >>> batchidx # Remaining examples are discarded. So, 10/3 batches are created. 3 - >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over') - >>> batchidx = 0 - >>> for batch in dataiter: - ... batchidx += 1 - ... - >>> batchidx # Remaining examples are rolled over to the next iteration. - 3 - >>> dataiter.reset() - >>> dataiter.next().data[0].asnumpy() - [[[ 36. 37.] - [ 38. 39.]] - [[ 0. 1.] - [ 2. 3.]] - [[ 4. 5.] - [ 6. 7.]]] - (3L, 2L, 2L) `NDArrayIter` also supports multiple input and labels. @@ -592,11 +633,8 @@ class NDArrayIter(DataIter): Only supported if no h5py.Dataset inputs are used. last_batch_handle : str, optional How to handle the last batch. This parameter can be 'pad', 'discard' or - 'roll_over'. - If 'pad', the last batch will be padded with data starting from the begining - If 'discard', the last batch will be discarded - If 'roll_over', the remaining elements will be rolled over to the next iteration and - note that it is intended for training and can cause problems if used for prediction. + 'roll_over'. 'roll_over' is intended for training and can cause problems + if used for prediction. data_name : str, optional The data name. label_name : str, optional @@ -607,28 +645,36 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False, label_name='softmax_label'): super(NDArrayIter, self).__init__(batch_size) - self.data = init_data(data, allow_empty=False, default_name=data_name) - self.label = init_data(label, allow_empty=True, default_name=label_name) + self.data = _init_data(data, allow_empty=False, default_name=data_name) + self.label = _init_data(label, allow_empty=True, default_name=label_name) - if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and + if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and (last_batch_handle != 'discard')): raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \ " with `last_batch_handle` set to `discard`.") - self.idx = np.arange(self.data[0][1].shape[0]) - self.shuffle = shuffle - self.last_batch_handle = last_batch_handle - self.batch_size = batch_size - self.cursor = -self.batch_size - self.num_data = self.idx.shape[0] - # shuffle - self.reset() + # shuffle data + if shuffle: + tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32) + self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy() + self.data = _shuffle(self.data, self.idx) + self.label = _shuffle(self.label, self.idx) + else: + self.idx = np.arange(self.data[0][1].shape[0]) + + # batching + if last_batch_handle == 'discard': + new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size + self.idx = self.idx[:new_n] self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label] self.num_source = len(self.data_list) - # used for 'roll_over' - self._cache_data = None - self._cache_label = None + self.num_data = self.idx.shape[0] + assert self.num_data >= batch_size, \ + "batch_size needs to be smaller than data size." + self.cursor = -batch_size + self.batch_size = batch_size + self.last_batch_handle = last_batch_handle @property def provide_data(self): @@ -648,126 +694,74 @@ def provide_label(self): def hard_reset(self): """Ignore roll over data and set to start.""" - if self.shuffle: - self._shuffle_data() self.cursor = -self.batch_size - self._cache_data = None - self._cache_label = None def reset(self): - """Resets the iterator to the beginning of the data.""" - if self.shuffle: - self._shuffle_data() - # the range below indicate the last batch - if self.last_batch_handle == 'roll_over' and \ - self.num_data - self.batch_size < self.cursor < self.num_data: - # (self.cursor - self.num_data) represents the data we have for the last batch - self.cursor = self.cursor - self.num_data - self.batch_size + if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data: + self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size else: self.cursor = -self.batch_size def iter_next(self): - """Increments the coursor by batch_size for next batch - and check current cursor if it exceed the number of data points.""" self.cursor += self.batch_size return self.cursor < self.num_data def next(self): - """Returns the next batch of data.""" - if not self.iter_next(): - raise StopIteration - data = self.getdata() - label = self.getlabel() - # iter should stop when last batch is not complete - if data[0].shape[0] != self.batch_size: - # in this case, cache it for next epoch - self._cache_data = data - self._cache_label = label + if self.iter_next(): + return DataBatch(data=self.getdata(), label=self.getlabel(), \ + pad=self.getpad(), index=None) + else: raise StopIteration - return DataBatch(data=data, label=label, \ - pad=self.getpad(), index=None) - - def _getdata(self, data_source, start=None, end=None): - """Load data from underlying arrays.""" - assert start is not None or end is not None, 'should at least specify start or end' - start = start if start is not None else 0 - end = end if end is not None else data_source[0][1].shape[0] - s = slice(start, end) - return [ - x[1][s] - if isinstance(x[1], (np.ndarray, NDArray)) else - # h5py (only supports indices in increasing order) - array(x[1][sorted(self.idx[s])][[ - list(self.idx[s]).index(i) - for i in sorted(self.idx[s]) - ]]) for x in data_source - ] - def _concat(self, first_data, second_data): - """Helper function to concat two NDArrays.""" - return [ - concat(first_data[0], second_data[0], dim=0) - ] - - def _batchify(self, data_source): + def _getdata(self, data_source): """Load data from underlying arrays, internal use only.""" - assert self.cursor < self.num_data, 'DataIter needs reset.' - # first batch of next epoch with 'roll_over' - if self.last_batch_handle == 'roll_over' and \ - -self.batch_size < self.cursor < 0: - assert self._cache_data is not None or self._cache_label is not None, \ - 'next epoch should have cached data' - cache_data = self._cache_data if self._cache_data is not None else self._cache_label - second_data = self._getdata( - data_source, end=self.cursor + self.batch_size) - if self._cache_data is not None: - self._cache_data = None - else: - self._cache_label = None - return self._concat(cache_data, second_data) - # last batch with 'pad' - elif self.last_batch_handle == 'pad' and \ - self.cursor + self.batch_size > self.num_data: - pad = self.batch_size - self.num_data + self.cursor - first_data = self._getdata(data_source, start=self.cursor) - second_data = self._getdata(data_source, end=pad) - return self._concat(first_data, second_data) - # normal case + assert(self.cursor < self.num_data), "DataIter needs reset." + if self.cursor + self.batch_size <= self.num_data: + return [ + # np.ndarray or NDArray case + x[1][self.cursor:self.cursor + self.batch_size] + if isinstance(x[1], (np.ndarray, NDArray)) else + # h5py (only supports indices in increasing order) + array(x[1][sorted(self.idx[ + self.cursor:self.cursor + self.batch_size])][[ + list(self.idx[self.cursor: + self.cursor + self.batch_size]).index(i) + for i in sorted(self.idx[ + self.cursor:self.cursor + self.batch_size]) + ]]) for x in data_source + ] else: - if self.cursor + self.batch_size < self.num_data: - end_idx = self.cursor + self.batch_size - # get incomplete last batch - else: - end_idx = self.num_data - return self._getdata(data_source, self.cursor, end_idx) + pad = self.batch_size - self.num_data + self.cursor + return [ + # np.ndarray or NDArray case + concatenate([x[1][self.cursor:], x[1][:pad]]) + if isinstance(x[1], (np.ndarray, NDArray)) else + # h5py (only supports indices in increasing order) + concatenate([ + array(x[1][sorted(self.idx[self.cursor:])][[ + list(self.idx[self.cursor:]).index(i) + for i in sorted(self.idx[self.cursor:]) + ]]), + array(x[1][sorted(self.idx[:pad])][[ + list(self.idx[:pad]).index(i) + for i in sorted(self.idx[:pad]) + ]]) + ]) for x in data_source + ] def getdata(self): - """Get data.""" - return self._batchify(self.data) + return self._getdata(self.data) def getlabel(self): - """Get label.""" - return self._batchify(self.label) + return self._getdata(self.label) def getpad(self): - """Get pad value of DataBatch.""" if self.last_batch_handle == 'pad' and \ self.cursor + self.batch_size > self.num_data: return self.cursor + self.batch_size - self.num_data - # check the first batch - elif self.last_batch_handle == 'roll_over' and \ - -self.batch_size < self.cursor < 0: - return -self.cursor else: return 0 - def _shuffle_data(self): - """Shuffle the data.""" - # shuffle index - np.random.shuffle(self.idx) - # get the data by corresponding index - self.data = getdata_by_idx(self.data, self.idx) - self.label = getdata_by_idx(self.label, self.idx) class MXDataIter(DataIter): """A python wrapper a C++ data iterator. @@ -779,7 +773,7 @@ class MXDataIter(DataIter): underlying C++ data iterators. Usually you don't need to interact with `MXDataIter` directly unless you are - implementing your own data iterators in C+ +. To do that, please refer to + implementing your own data iterators in C++. To do that, please refer to examples under the `src/io` folder. Parameters diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py deleted file mode 100644 index 5c5e2e68d84a..000000000000 --- a/python/mxnet/io/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# coding: utf-8 -# pylint: disable=wildcard-import -""" Data iterators for common data formats and utility functions.""" -from __future__ import absolute_import - -from . import io -from .io import * - -from . import utils -from .utils import * diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py deleted file mode 100644 index 872e6410d7de..000000000000 --- a/python/mxnet/io/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""utility functions for io.py""" -from collections import OrderedDict - -import numpy as np -try: - import h5py -except ImportError: - h5py = None - -from ..ndarray.sparse import CSRNDArray -from ..ndarray.sparse import array as sparse_array -from ..ndarray import NDArray -from ..ndarray import array - -def init_data(data, allow_empty, default_name): - """Convert data into canonical form.""" - assert (data is not None) or allow_empty - if data is None: - data = [] - - if isinstance(data, (np.ndarray, NDArray, h5py.Dataset) - if h5py else (np.ndarray, NDArray)): - data = [data] - if isinstance(data, list): - if not allow_empty: - assert(len(data) > 0) - if len(data) == 1: - data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type - else: - data = OrderedDict( # pylint: disable=redefined-variable-type - [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)]) - if not isinstance(data, dict): - raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + - "a list of them or dict with them as values") - for k, v in data.items(): - if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray): - try: - data[k] = array(v) - except: - raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + - "should be NDArray, numpy.ndarray or h5py.Dataset") - - return list(sorted(data.items())) - - -def has_instance(data, dtype): - """Return True if ``data`` has instance of ``dtype``. - This function is called after _init_data. - ``data`` is a list of (str, NDArray)""" - for item in data: - _, arr = item - if isinstance(arr, dtype): - return True - return False - - -def getdata_by_idx(data, idx): - """Shuffle the data.""" - shuffle_data = [] - - for k, v in data: - if (isinstance(v, h5py.Dataset) if h5py else False): - shuffle_data.append((k, v)) - elif isinstance(v, CSRNDArray): - shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context))) - else: - shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) - - return shuffle_data diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 69b7a0d562a2..872763f5a783 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -106,88 +106,80 @@ def check_cifar10_exception(): pass assertRaises(MXNetError, check_cifar10_exception) -def _init_NDArrayIter_data(): +def test_NDArrayIter(): data = np.ones([1000, 2, 2]) - labels = np.ones([1000, 1]) + label = np.ones([1000, 1]) for i in range(1000): data[i] = i / 100 - labels[i] = i / 100 - return data, labels - - -def _test_last_batch_handle(data, labels): - # Test the three parameters 'pad', 'discard', 'roll_over' - last_batch_handle_list = ['pad', 'discard' , 'roll_over'] - labelcount_list = [(124, 100), (100, 96), (100, 96)] - batch_count_list = [8, 7, 7] - - for idx in range(len(last_batch_handle_list)): - dataiter = mx.io.NDArrayIter( - data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx]) - batch_count = 0 - labelcount = [0 for i in range(10)] - for batch in dataiter: - label = batch.label[0].asnumpy().flatten() - # check data if it matches corresponding labels - assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), last_batch_handle_list[idx] - for i in range(label.shape[0]): - labelcount[int(label[i])] += 1 - # keep the last batch of 'pad' to be used later - # to test first batch of roll_over in second iteration - batch_count += 1 - if last_batch_handle_list[idx] == 'pad' and \ - batch_count == 8: - cache = batch.data[0].asnumpy() - # check if batchifying functionality work properly - assert labelcount[0] == labelcount_list[idx][0], last_batch_handle_list[idx] - assert labelcount[8] == labelcount_list[idx][1], last_batch_handle_list[idx] - assert batch_count == batch_count_list[idx] - # roll_over option - dataiter.reset() - assert np.array_equal(dataiter.next().data[0].asnumpy(), cache) - - -def _test_shuffle(data, labels): - dataiter = mx.io.NDArrayIter(data, labels, 1, False) - batch_list = [] + label[i] = i / 100 + dataiter = mx.io.NDArrayIter( + data, label, 128, True, last_batch_handle='pad') + batchidx = 0 for batch in dataiter: - # cache the original data - batch_list.append(batch.data[0].asnumpy()) - dataiter = mx.io.NDArrayIter(data, labels, 1, True) - idx_list = dataiter.idx - i = 0 + batchidx += 1 + assert(batchidx == 8) + dataiter = mx.io.NDArrayIter( + data, label, 128, False, last_batch_handle='pad') + batchidx = 0 + labelcount = [0 for i in range(10)] for batch in dataiter: - # check if each data point have been shuffled to corresponding positions - assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]]) - i += 1 - + label = batch.label[0].asnumpy().flatten() + assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()) + for i in range(label.shape[0]): + labelcount[int(label[i])] += 1 -def test_NDArrayIter(): - data, labels = _init_NDArrayIter_data() - _test_last_batch_handle(data, labels) - _test_shuffle(data, labels) + for i in range(10): + if i == 0: + assert(labelcount[i] == 124) + else: + assert(labelcount[i] == 100) def test_NDArrayIter_h5py(): if not h5py: return - data, labels = _init_NDArrayIter_data() + data = np.ones([1000, 2, 2]) + label = np.ones([1000, 1]) + for i in range(1000): + data[i] = i / 100 + label[i] = i / 100 try: - os.remove('ndarraytest.h5') + os.remove("ndarraytest.h5") except OSError: pass - with h5py.File('ndarraytest.h5') as f: - f.create_dataset('data', data=data) - f.create_dataset('label', data=labels) + with h5py.File("ndarraytest.h5") as f: + f.create_dataset("data", data=data) + f.create_dataset("label", data=label) + + dataiter = mx.io.NDArrayIter( + f["data"], f["label"], 128, True, last_batch_handle='pad') + batchidx = 0 + for batch in dataiter: + batchidx += 1 + assert(batchidx == 8) + + dataiter = mx.io.NDArrayIter( + f["data"], f["label"], 128, False, last_batch_handle='pad') + labelcount = [0 for i in range(10)] + for batch in dataiter: + label = batch.label[0].asnumpy().flatten() + assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()) + for i in range(label.shape[0]): + labelcount[int(label[i])] += 1 - _test_last_batch_handle(f['data'], f['label']) try: os.remove("ndarraytest.h5") except OSError: pass + for i in range(10): + if i == 0: + assert(labelcount[i] == 124) + else: + assert(labelcount[i] == 100) + def test_NDArrayIter_csr(): # creating toy data @@ -208,20 +200,12 @@ def test_NDArrayIter_csr(): {'data': train_data}, dns, batch_size) except ImportError: pass - # scipy.sparse.csr_matrix with shuffle - num_batch = 0 - csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size, - shuffle=True, last_batch_handle='discard')) - for _ in csr_iter: - num_batch += 1 - - assert(num_batch == num_rows // batch_size) # CSRNDArray with shuffle csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size, shuffle=True, last_batch_handle='discard')) num_batch = 0 - for _ in csr_iter: + for batch in csr_iter: num_batch += 1 assert(num_batch == num_rows // batch_size)