diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index da637858f2..bfec013db2 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -358,11 +358,11 @@ def build_target( # Draw polygon on gt if it is valid if len(shrinked) == 0: - seg_mask[box[1]: box[3] + 1, box[0]: box[2] + 1] = False + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue shrinked = np.array(shrinked[0]).reshape(-1, 2) if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: - seg_mask[box[1]: box[3] + 1, box[0]: box[2] + 1] = False + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False continue cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1)