Skip to content

Commit

Permalink
Bugfix in NER baselines inference in merging strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
simsa-st committed Mar 7, 2023
1 parent 8aa4d1e commit fb91d58
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 64 deletions.
73 changes: 41 additions & 32 deletions baselines/NER/docile_inference_NER_multilabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ def get_sorted_field_candidates(ocr_fields):
return fields, clusters


def _join_texts(text1: str, text2: str, separator: str) -> str:
return (
f"{text1}{separator}{text2}"
if text1 != "" and text2 != ""
else text1
if text1 != ""
else text2
)


def merge_text_boxes(text_boxes, merge_strategy="naive"):
# group by fieldtype:
groups = {}
Expand Down Expand Up @@ -263,30 +273,29 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
# horizontal merging
after_horizontal_merging = []
for fields in textline_group.values():
not_processed = [fields[0]]
processed = [fields[0]]

while len(not_processed):
new_field = FieldWithGroups.from_dict(not_processed.pop(0).to_dict())
fields = sorted(fields, key=lambda f: f.bbox.left)
processed = []

# Iterate over the fields and check whether each one can be merged with any of the fields that follow it.
for field_to_process in fields:
if field_to_process in processed:
continue
processed.append(field_to_process)
new_field = FieldWithGroups.from_dict(field_to_process.to_dict())
glued_count = 1
for field in fields:
if field not in processed and field != new_field:
# if (field.bbox.left - new_field.bbox.right) <= (field.bbox.width/len(field.text))*1.5:
if (
field.bbox.left - new_field.bbox.right
) <= field.bbox.height * 1.25:
new_field = dataclasses.replace(
new_field,
bbox=new_field.bbox.union(field.bbox),
score=new_field.score + field.score,
text=f"{new_field.text} {field.text}",
)
processed.append(field)
glued_count += 1
else:
not_processed.append(field)
processed.append(new_field)

if field in processed:
continue
# if (field.bbox.left - new_field.bbox.right) <= (field.bbox.width/len(field.text))*1.5:
if field.bbox.left - new_field.bbox.right <= field.bbox.height * 1.25:
new_field = dataclasses.replace(
new_field,
bbox=new_field.bbox.union(field.bbox),
score=new_field.score + field.score,
text=_join_texts(new_field.text, field.text, " "),
)
processed.append(field)
glued_count += 1
after_horizontal_merging.append((new_field, glued_count))

# vertical merging
Expand All @@ -296,11 +305,14 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
for j, (field2, _gc2) in enumerate(after_horizontal_merging):
# ignore the same field (diagonal in the adjacency matrix)
if field1 != field2:
y_dist = abs(field2.bbox.top - field1.bbox.bottom)
if field1.bbox.left < field2.bbox.left: # field1, ..., field2
x_dist = field2.bbox.left - field1.bbox.right
else: # field2, ..., field1
x_dist = field1.bbox.left - field2.bbox.right
y_dist = max(
field2.bbox.top - field1.bbox.bottom,
field1.bbox.top - field2.bbox.bottom,
)
x_dist = max(
field2.bbox.left - field1.bbox.right,
field1.bbox.left - field2.bbox.right,
)
# if (y_dist < field1.bbox.height*1.2):
if (y_dist < field1.bbox.height * 1.2) and (
x_dist <= field1.bbox.height * 1.25
Expand All @@ -318,6 +330,7 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
#
# Merge found components
for idxs in components.values():
idxs = sorted(idxs, key=lambda i: after_horizontal_merging[i][0].bbox.top)
tmp_field = after_horizontal_merging[idxs[0]][0]
new_field = FieldWithGroups(
fieldtype=tmp_field.fieldtype,
Expand All @@ -334,11 +347,7 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
new_field,
bbox=new_field.bbox.union(field.bbox),
score=new_field.score + field.score,
text=(
f"{new_field.text}\n{field.text}"
if new_field.text
else f"{field.text}"
),
text=_join_texts(new_field.text, field.text, "\n"),
)
glued_count += gc
new_field = dataclasses.replace(new_field, score=new_field.score / glued_count)
Expand Down
73 changes: 41 additions & 32 deletions baselines/NER/docile_inference_NER_multilabel_layoutLMv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ def get_sorted_field_candidates(ocr_fields):
return fields, clusters


def _join_texts(text1: str, text2: str, separator: str) -> str:
return (
f"{text1}{separator}{text2}"
if text1 != "" and text2 != ""
else text1
if text1 != ""
else text2
)


def merge_text_boxes(text_boxes, merge_strategy="naive"):
# group by fieldtype:
groups = {}
Expand Down Expand Up @@ -274,30 +284,29 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
# horizontal merging
after_horizontal_merging = []
for fields in textline_group.values():
not_processed = [fields[0]]
processed = [fields[0]]

while len(not_processed):
new_field = FieldWithGroups.from_dict(not_processed.pop(0).to_dict())
fields = sorted(fields, key=lambda f: f.bbox.left)
processed = []

# Iterate over the fields and check whether each one can be merged with any of the fields that follow it.
for field_to_process in fields:
if field_to_process in processed:
continue
processed.append(field_to_process)
new_field = FieldWithGroups.from_dict(field_to_process.to_dict())
glued_count = 1
for field in fields:
if field not in processed and field != new_field:
# if (field.bbox.left - new_field.bbox.right) <= (field.bbox.width/len(field.text))*1.5:
if (
field.bbox.left - new_field.bbox.right
) <= field.bbox.height * 1.25:
new_field = dataclasses.replace(
new_field,
bbox=new_field.bbox.union(field.bbox),
score=new_field.score + field.score,
text=f"{new_field.text} {field.text}",
)
processed.append(field)
glued_count += 1
else:
not_processed.append(field)
processed.append(new_field)

if field in processed:
continue
# if (field.bbox.left - new_field.bbox.right) <= (field.bbox.width/len(field.text))*1.5:
if field.bbox.left - new_field.bbox.right <= field.bbox.height * 1.25:
new_field = dataclasses.replace(
new_field,
bbox=new_field.bbox.union(field.bbox),
score=new_field.score + field.score,
text=_join_texts(new_field.text, field.text, " "),
)
processed.append(field)
glued_count += 1
after_horizontal_merging.append((new_field, glued_count))

# vertical merging
Expand All @@ -307,11 +316,14 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
for j, (field2, _gc2) in enumerate(after_horizontal_merging):
# ignore the same field (diagonal in the adjacency matrix)
if field1 != field2:
y_dist = abs(field2.bbox.top - field1.bbox.bottom)
if field1.bbox.left < field2.bbox.left: # field1, ..., field2
x_dist = field2.bbox.left - field1.bbox.right
else: # field2, ..., field1
x_dist = field1.bbox.left - field2.bbox.right
y_dist = max(
field2.bbox.top - field1.bbox.bottom,
field1.bbox.top - field2.bbox.bottom,
)
x_dist = max(
field2.bbox.left - field1.bbox.right,
field1.bbox.left - field2.bbox.right,
)
# if (y_dist < field1.bbox.height*1.2):
if (y_dist < field1.bbox.height * 1.2) and (
x_dist <= field1.bbox.height * 1.25
Expand All @@ -329,6 +341,7 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
#
# Merge found components
for idxs in components.values():
idxs = sorted(idxs, key=lambda i: after_horizontal_merging[i][0].bbox.top)
tmp_field = after_horizontal_merging[idxs[0]][0]
new_field = FieldWithGroups(
fieldtype=tmp_field.fieldtype,
Expand All @@ -345,11 +358,7 @@ def merge_text_boxes(text_boxes, merge_strategy="naive"):
new_field,
bbox=new_field.bbox.union(field.bbox),
score=new_field.score + field.score,
text=(
f"{new_field.text}\n{field.text}"
if new_field.text
else f"{field.text}"
),
text=_join_texts(new_field.text, field.text, "\n"),
)
glued_count += gc
new_field = dataclasses.replace(new_field, score=new_field.score / glued_count)
Expand Down

0 comments on commit fb91d58

Please sign in to comment.