diff --git a/tasks/ofa_task.py b/tasks/ofa_task.py index 2771d351..45fc1fa8 100644 --- a/tasks/ofa_task.py +++ b/tasks/ofa_task.py @@ -132,7 +132,7 @@ def get_batch_iterator( total_row_count = dataset.dataset.get_total_row_count() num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences) if len(batch_sampler) < num_batches: - batch_sampler.append([]) + batch_sampler.append([1]) # return a reusable, sharded iterator epoch_iter = iterators.EpochBatchIterator(