sort formulas and text in a line && bug fix #3568

Open. Wants to merge 1 commit into base: develop.
3 changes: 2 additions & 1 deletion paddlex/configs/pipelines/layout_parsing_v2.yaml
@@ -21,6 +21,7 @@ SubModules:
layout_merge_bboxes_mode:
1: "large" # image
18: "large" # chart
7: "large" # formula

SubPipelines:
DocPreprocessor:
@@ -45,7 +46,7 @@ SubPipelines:
SubModules:
TextDetection:
module_name: text_detection
- model_name: PP-OCRv4_mobile_det
+ model_name: PP-OCRv4_server_det
model_dir: null
limit_side_len: 960
limit_type: max
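
The added `7: "large"` entry gives formula boxes the same merge behavior as images (1) and charts (18). As a rough illustration of what the "large" mode implies (a sketch of the idea, not PaddleX's actual merging code), overlapping boxes of such a category are resolved in favor of the larger one:

# Sketch of the idea behind the "large" merge mode -- not PaddleX's merging code.
def merge_large(box_a, box_b):
    """Keep the larger of two overlapping boxes (x_min, y_min, x_max, y_max)."""
    area = lambda b: max(0, b[2] - b[0]) * max(0, b[3] - b[1])
    return box_a if area(box_a) >= area(box_b) else box_b

print(merge_large([0, 0, 100, 40], [10, 5, 60, 30]))  # -> [0, 0, 100, 40]
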
11 changes: 9 additions & 2 deletions paddlex/inference/pipelines/layout_parsing/pipeline_v2.py
@@ -310,14 +310,17 @@ def get_layout_parsing_res(
del overall_ocr_res["rec_polys"][matched_idx]
del overall_ocr_res["rec_scores"][matched_idx]

- if sub_ocr_res["rec_boxes"] is not []:
+ if sub_ocr_res["rec_boxes"].size > 0:
sub_ocr_res["rec_labels"] = ["text"] * len(sub_ocr_res["rec_texts"])

overall_ocr_res["dt_polys"].extend(sub_ocr_res["dt_polys"])
overall_ocr_res["rec_texts"].extend(sub_ocr_res["rec_texts"])
overall_ocr_res["rec_boxes"] = np.concatenate(
[overall_ocr_res["rec_boxes"], sub_ocr_res["rec_boxes"]], axis=0
)
overall_ocr_res["rec_polys"].extend(sub_ocr_res["rec_polys"])
overall_ocr_res["rec_scores"].extend(sub_ocr_res["rec_scores"])
overall_ocr_res["rec_labels"].extend(sub_ocr_res["rec_labels"])

for formula_res in formula_res_list:
x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
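
The condition replaced above was a real bug: `x is not []` is an identity comparison against a freshly created list, so it is True for every object, including an empty array. The new `.size > 0` check tests emptiness, as a quick demonstration shows:

import numpy as np

rec_boxes = np.empty((0, 4))
print(rec_boxes is not [])  # True  -- identity vs. a brand-new list, always True
print(rec_boxes.size > 0)   # False -- the emptiness test actually intended
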
@@ -332,10 +335,12 @@ def get_layout_parsing_res(
overall_ocr_res["rec_boxes"] = np.vstack(
(overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
)
overall_ocr_res["rec_labels"].append("formula")
overall_ocr_res["rec_polys"].append(poly_points)
overall_ocr_res["rec_scores"].append(1)

parsing_res_list = get_single_block_parsing_res(
self.general_ocr_pipeline,
overall_ocr_res=overall_ocr_res,
layout_det_res=layout_det_res,
table_res_list=table_res_list,
@@ -473,7 +478,7 @@ def predict(
if not self.check_model_settings_valid(model_settings):
yield {"error": "the input params for model settings are invalid!"}

- for img_id, batch_data in enumerate(self.batch_sampler(input)):
+ for batch_data in self.batch_sampler(input):
image_array = self.img_reader(batch_data.instances)[0]

if model_settings["use_doc_preprocessor"]:
@@ -536,6 +541,8 @@ def predict(
else:
overall_ocr_res = {}

overall_ocr_res["rec_labels"] = ["text"] * len(overall_ocr_res["rec_texts"])

if model_settings["use_table_recognition"]:
table_overall_ocr_res = copy.deepcopy(overall_ocr_res)
for formula_res in formula_res_list:
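
Taken together, the changes in this file keep a `rec_labels` list parallel to `rec_texts` in `overall_ocr_res`: every OCR span starts out as "text", and each appended formula box is labeled "formula" with a recognition score of 1. A minimal sketch of the resulting structure, with invented values:

# Sketch of overall_ocr_res after these changes (values invented for illustration).
overall_ocr_res = {
    "rec_texts":  ["Einstein showed that", "E = mc^2", "holds."],
    "rec_labels": ["text",                 "formula",  "text"],
    "rec_scores": [0.98,                   1,          0.97],
}
# Downstream, the line sorter re-recognizes only spans labeled "text".
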
2 changes: 1 addition & 1 deletion paddlex/inference/pipelines/layout_parsing/result_v2.py
@@ -312,7 +312,7 @@ def format_table():
"table": format_table,
"reference": lambda: block["block_content"],
"algorithm": lambda: block["block_content"].strip("\n"),
"seal": lambda: format_image("block_content"),
"seal": lambda: f"Words of Seals:\n{block['block_content']}",
}
parsing_res_list = obj["parsing_res_list"]
markdown_content = ""
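
With this change a seal block is rendered into the Markdown output as its recognized text under a "Words of Seals:" prefix rather than being embedded as an image. A small sketch of the dispatch, with an invented block:

block = {"block_label": "seal", "block_content": "GUANGZHOU TRADING CO., LTD."}
handlers = {"seal": lambda: f"Words of Seals:\n{block['block_content']}"}
print(handlers[block["block_label"]]())
# Words of Seals:
# GUANGZHOU TRADING CO., LTD.
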
177 changes: 159 additions & 18 deletions paddlex/inference/pipelines/layout_parsing/utils.py
@@ -26,6 +26,7 @@
import uuid
import re
from pathlib import Path
from copy import deepcopy
from typing import Optional, Union, List, Tuple, Dict, Any
from ..ocr.result import OCRResult
from ...models.object_detection.result import DetResult
@@ -253,6 +254,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = False):
span[1] = "\n" + span[1]
if append:
span[1] = span[1] + "\n"
return span
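
Returning the span lets callers rebind the result, which `_format_line` below now relies on. Its effect on a span, sketched:

span = [[12, 0, 80, 14], "indented first line"]
span = _adjust_span_text(span, prepend=True)
print(repr(span[1]))  # '\nindented first line' -- the newline marks a segment break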


def _format_line(
@@ -278,17 +280,127 @@ def _format_line(

if not is_reference:
if first_span[0][0] - layout_min > 10:
- _adjust_span_text(first_span, prepend=True)
+ first_span = _adjust_span_text(first_span, prepend=True)
if layout_max - end_span[0][2] > 10:
- _adjust_span_text(end_span, append=True)
+ end_span = _adjust_span_text(end_span, append=True)
else:
if first_span[0][0] - layout_min < 5:
- _adjust_span_text(first_span, prepend=True)
+ first_span = _adjust_span_text(first_span, prepend=True)
if layout_max - end_span[0][2] > 20:
- _adjust_span_text(end_span, append=True)
+ end_span = _adjust_span_text(end_span, append=True)

line[0] = first_span
line[-1] = end_span

return line


def split_boxes_if_x_contained(boxes, offset=1e-5):
"""
Check if there is any complete containment in the x-direction
between the bounding boxes and split the containing box accordingly.

Args:
boxes (list of lists): Each element is a list containing a length-4 ndarray (x_min, y_min, x_max, y_max), the recognized text, and a label.
offset (float): A small offset applied at the split points so that the resulting fragments do not overlap the contained box.
Returns:
A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
"""

def is_x_contained(box_a, box_b):
"""Check if box_a completely contains box_b in the x-direction."""
return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]

new_boxes = []

for i in range(len(boxes)):
box_a = boxes[i]
is_split = False
for j in range(len(boxes)):
if i == j:
continue
box_b = boxes[j]
if is_x_contained(box_a, box_b):
is_split = True
# Split box_a based on the x-coordinates of box_b
if box_a[0][0] < box_b[0][0]:
w = box_b[0][0] - offset - box_a[0][0]
if w > 1:
new_boxes.append(
[
np.array(
[
box_a[0][0],
box_a[0][1],
box_b[0][0] - offset,
box_a[0][3],
]
),
box_a[1],
box_a[2],
]
)
if box_a[0][2] > box_b[0][2]:
w = box_a[0][2] - box_b[0][2] + offset
if w > 1:
box_a = [
np.array(
[
box_b[0][2] + offset,
box_a[0][1],
box_a[0][2],
box_a[0][3],
]
),
box_a[1],
box_a[2],
]
if j == len(boxes) - 1 and is_split:
new_boxes.append(box_a)
if not is_split:
new_boxes.append(box_a)

return new_boxes
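
For example, given a hypothetical line where a formula box is x-contained in a wider text box, the text box is split into left and right fragments while the formula box is kept whole. The fragments still carry the original, now stale, text, which is why `_sort_line_by_x_projection` below re-recognizes them:

line = [
    [np.array([0, 0, 100, 12]), "energy mass relation", "text"],
    [np.array([40, 0, 60, 12]), "E = mc^2", "formula"],
]
for box, text, label in split_boxes_if_x_contained(line):
    print(box.tolist(), repr(text), label)
# [0.0, 0.0, 39.99999, 12.0] 'energy mass relation' text    (left fragment, stale text)
# [60.00001, 0.0, 100.0, 12.0] 'energy mass relation' text  (right fragment, stale text)
# [40, 0, 60, 12] 'E = mc^2' formula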


def _sort_line_by_x_projection(
input_img: np.ndarray,
general_ocr_pipeline: Any,
line: List[List[Union[List[int], str]]],
) -> List[List[Union[List[int], str]]]:
"""
Sort a line of text spans by their horizontal position, splitting spans that contain a formula in the x-direction and re-recognizing the resulting text fragments.

Args:
input_img (ndarray): The input image used for OCR.
general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
line (list): A list of spans, where each span is a list containing a bounding box, the recognized text, and a label.

Returns:
list: The sorted line of text spans.
"""
splited_boxes = split_boxes_if_x_contained(line)
splited_lines = []
if len(line) != len(splited_boxes):
splited_boxes.sort(key=lambda span: span[0][0])
text_rec_model = general_ocr_pipeline.text_rec_model
for span in splited_boxes:
if span[2] == "text":
crop_img = input_img[
int(span[0][1]) : int(span[0][3]),
int(span[0][0]) : int(span[0][2]),
]
span[1] = next(text_rec_model([crop_img]))["rec_text"]
splited_lines.append(span)
else:
splited_lines = line

return splited_lines
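
A sketch of the call with a stubbed recognizer, reusing `line` from the example above (the stub only stands in for `general_ocr_pipeline.text_rec_model`, which is a real PaddleX component):

class StubRecModel:
    # Stands in for general_ocr_pipeline.text_rec_model; illustrative only.
    def __call__(self, imgs):
        for _ in imgs:
            yield {"rec_text": "re-recognized fragment"}

class StubPipeline:
    text_rec_model = StubRecModel()

img = np.zeros((20, 120, 3), dtype=np.uint8)
sorted_line = _sort_line_by_x_projection(img, StubPipeline(), line)
# The two "text" fragments are re-recognized from their crops, sorted by x;
# the "formula" span keeps its original text.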


def _sort_ocr_res_by_y_projection(
input_img: np.ndarray,
general_ocr_pipeline: Any,
label: Any,
block_bbox: Tuple[int, int, int, int],
ocr_res: Dict[str, List[Any]],
@@ -298,6 +410,8 @@ def _sort_ocr_res_by_y_projection(
Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.

Args:
input_img (ndarray): The input image used for OCR.
general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
label (Any): The label of the layout block being sorted; "reference" blocks
receive different line-formatting thresholds in _format_line.
block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -318,12 +432,13 @@ def _sort_ocr_res_by_y_projection(

boxes = ocr_res["boxes"]
rec_texts = ocr_res["rec_texts"]
rec_labels = ocr_res["rec_labels"]

x_min, _, x_max, _ = block_bbox
inline_x_min = min([box[0] for box in boxes])
inline_x_max = max([box[2] for box in boxes])

- spans = list(zip(boxes, rec_texts))
+ spans = list(zip(boxes, rec_texts, rec_labels))

spans.sort(key=lambda span: span[0][1])
spans = [list(span) for span in spans]
@@ -350,16 +465,21 @@
if current_line:
lines.append(current_line)

new_lines = []
for line in lines:
line.sort(key=lambda span: span[0][0])

ocr_labels = [span[2] for span in line]
if "formula" in ocr_labels:
line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
if label == "reference":
line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
else:
line = _format_line(line, x_min, x_max)
new_lines.append(line)

# Flatten lines back into a single list for boxes and texts
ocr_res["boxes"] = [span[0] for line in lines for span in line]
ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]

return ocr_res
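
The grouping step that precedes this sorting, condensed to its core rule: spans are ordered by top y coordinate, and a new line starts whenever a span's vertical overlap with the current line falls below the threshold (0.7 at the call site). A simplified sketch, not the library code:

def group_into_lines(spans, overlap_ratio=0.7):
    # Simplified sketch of y-projection line grouping; the real function also
    # formats lines and re-sorts those containing formulas.
    spans = sorted(spans, key=lambda s: s[0][1])  # sort by top y
    lines, current = [], [spans[0]]
    for span in spans[1:]:
        top, bottom = span[0][1], span[0][3]
        cur_top = min(s[0][1] for s in current)
        cur_bottom = max(s[0][3] for s in current)
        overlap = max(0, min(bottom, cur_bottom) - max(top, cur_top))
        if overlap / max(bottom - top, 1e-6) >= overlap_ratio:
            current.append(span)   # same visual line
        else:
            lines.append(current)  # open a new line
            current = [span]
    lines.append(current)
    return lines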

@@ -418,6 +538,7 @@ def handle_spaces_(text: str) -> str:


def get_single_block_parsing_res(
general_ocr_pipeline: Any,
overall_ocr_res: OCRResult,
layout_det_res: DetResult,
table_res_list: list,
@@ -452,10 +573,16 @@
input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
seal_index = 0

for box_info in layout_det_res["boxes"]:
layout_det_res_list, _ = _remove_overlap_blocks(
deepcopy(layout_det_res["boxes"]),
threshold=0.5,
smaller=True,
)

for box_info in layout_det_res_list:
block_bbox = box_info["coordinate"]
label = box_info["label"]
rec_res = {"boxes": [], "rec_texts": [], "flag": False}
rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
seg_start_flag = True
seg_end_flag = True

@@ -504,10 +631,15 @@
rec_res["rec_texts"].append(
overall_ocr_res["rec_texts"][box_no],
)
rec_res["rec_labels"].append(
overall_ocr_res["rec_labels"][box_no],
)
rec_res["flag"] = True

if rec_res["flag"]:
- rec_res = _sort_ocr_res_by_y_projection(label, block_bbox, rec_res, 0.7)
+ rec_res = _sort_ocr_res_by_y_projection(
+ input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
+ )
rec_res_first_bbox = rec_res["boxes"][0]
rec_res_end_bbox = rec_res["boxes"][-1]
if rec_res_first_bbox[0] - block_bbox[0] < 10:
@@ -548,6 +680,20 @@
},
)

if len(layout_det_res_list) == 0:
for ocr_rec_box, ocr_rec_text in zip(
overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]
):
single_block_layout_parsing_res.append(
{
"block_label": "text",
"block_content": ocr_rec_text,
"block_bbox": ocr_rec_box,
"seg_start_flag": True,
"seg_end_flag": True,
},
)

single_block_layout_parsing_res = get_layout_ordering(
single_block_layout_parsing_res,
no_mask_labels=[
@@ -910,8 +1056,8 @@ def _remove_overlap_blocks(
continue
# Check for overlap and determine which block to remove
overlap_box_index = _get_minbox_if_overlap_by_ratio(
block1["block_bbox"],
block2["block_bbox"],
block1["coordinate"],
block2["coordinate"],
threshold,
smaller=smaller,
)
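
The fix above is a key rename: the boxes now passed to `_remove_overlap_blocks` come straight from layout detection, which stores coordinates under "coordinate" rather than "block_bbox". The removal criterion itself, sketched under the assumption that `_get_minbox_if_overlap_by_ratio` measures intersection against the smaller box's area:

# Assumed behavior (sketch): report the smaller box for removal when the
# intersection covers more than `threshold` of its area.
def overlap_over_smaller_area(box_a, box_b):
    ix = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    iy = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    area = lambda b: (b[2] - b[0]) * (b[3] - b[1])
    return (ix * iy) / min(area(box_a), area(box_b))
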
@@ -1419,11 +1565,6 @@ def get_layout_ordering(
vision_labels = ["image", "table", "seal", "chart", "figure"]
vision_title_labels = ["table_title", "chart_title", "figure_title"]

- parsing_res_list, _ = _remove_overlap_blocks(
- parsing_res_list,
- threshold=0.5,
- smaller=True,
- )
parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)

parsing_res_by_pre_cuts_list = []