Skip to content

Commit d8e1199

Browse files
committed
Sort formulas and text within a line, and fix bugs
1 parent a4c9257 commit d8e1199

File tree

4 files changed

+171
-22
lines changed

4 files changed

+171
-22
lines changed

paddlex/configs/pipelines/layout_parsing_v2.yaml

+2-1
Original file line number · Diff line number · Diff line change
@@ -21,6 +21,7 @@ SubModules:
2121
layout_merge_bboxes_mode:
2222
1: "large" # image
2323
18: "large" # chart
24+
7: "large" # formula
2425

2526
SubPipelines:
2627
DocPreprocessor:
@@ -45,7 +46,7 @@ SubPipelines:
4546
SubModules:
4647
TextDetection:
4748
module_name: text_detection
48-
model_name: PP-OCRv4_mobile_det
49+
model_name: PP-OCRv4_server_det
4950
model_dir: null
5051
limit_side_len: 960
5152
limit_type: max

paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

+9-2
Original file line number · Diff line number · Diff line change
@@ -310,14 +310,17 @@ def get_layout_parsing_res(
310310
del overall_ocr_res["rec_polys"][matched_idx]
311311
del overall_ocr_res["rec_scores"][matched_idx]
312312

313-
if sub_ocr_res["rec_boxes"] is not []:
313+
if sub_ocr_res["rec_boxes"].size > 0:
314+
sub_ocr_res["rec_labels"] = ["text"] * len(sub_ocr_res["rec_texts"])
315+
314316
overall_ocr_res["dt_polys"].extend(sub_ocr_res["dt_polys"])
315317
overall_ocr_res["rec_texts"].extend(sub_ocr_res["rec_texts"])
316318
overall_ocr_res["rec_boxes"] = np.concatenate(
317319
[overall_ocr_res["rec_boxes"], sub_ocr_res["rec_boxes"]], axis=0
318320
)
319321
overall_ocr_res["rec_polys"].extend(sub_ocr_res["rec_polys"])
320322
overall_ocr_res["rec_scores"].extend(sub_ocr_res["rec_scores"])
323+
overall_ocr_res["rec_labels"].extend(sub_ocr_res["rec_labels"])
321324

322325
for formula_res in formula_res_list:
323326
x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
@@ -332,10 +335,12 @@ def get_layout_parsing_res(
332335
overall_ocr_res["rec_boxes"] = np.vstack(
333336
(overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
334337
)
338+
overall_ocr_res["rec_labels"].append("formula")
335339
overall_ocr_res["rec_polys"].append(poly_points)
336340
overall_ocr_res["rec_scores"].append(1)
337341

338342
parsing_res_list = get_single_block_parsing_res(
343+
self.general_ocr_pipeline,
339344
overall_ocr_res=overall_ocr_res,
340345
layout_det_res=layout_det_res,
341346
table_res_list=table_res_list,
@@ -473,7 +478,7 @@ def predict(
473478
if not self.check_model_settings_valid(model_settings):
474479
yield {"error": "the input params for model settings are invalid!"}
475480

476-
for img_id, batch_data in enumerate(self.batch_sampler(input)):
481+
for batch_data in self.batch_sampler(input):
477482
image_array = self.img_reader(batch_data.instances)[0]
478483

479484
if model_settings["use_doc_preprocessor"]:
@@ -536,6 +541,8 @@ def predict(
536541
else:
537542
overall_ocr_res = {}
538543

544+
overall_ocr_res["rec_labels"] = ["text"] * len(overall_ocr_res["rec_texts"])
545+
539546
if model_settings["use_table_recognition"]:
540547
table_overall_ocr_res = copy.deepcopy(overall_ocr_res)
541548
for formula_res in formula_res_list:

paddlex/inference/pipelines/layout_parsing/result_v2.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -312,7 +312,7 @@ def format_table():
312312
"table": format_table,
313313
"reference": lambda: block["block_content"],
314314
"algorithm": lambda: block["block_content"].strip("\n"),
315-
"seal": lambda: format_image("block_content"),
315+
"seal": lambda: f"Words of Seals:\n{block['block_content']}",
316316
}
317317
parsing_res_list = obj["parsing_res_list"]
318318
markdown_content = ""

paddlex/inference/pipelines/layout_parsing/utils.py

+159-18
Original file line number · Diff line number · Diff line change
@@ -26,6 +26,7 @@
2626
import uuid
2727
import re
2828
from pathlib import Path
29+
from copy import deepcopy
2930
from typing import Optional, Union, List, Tuple, Dict, Any
3031
from ..ocr.result import OCRResult
3132
from ...models.object_detection.result import DetResult
@@ -253,6 +254,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = Fal
253254
span[1] = "\n" + span[1]
254255
if append:
255256
span[1] = span[1] + "\n"
257+
return span
256258

257259

258260
def _format_line(
@@ -278,17 +280,127 @@ def _format_line(
278280

279281
if not is_reference:
280282
if first_span[0][0] - layout_min > 10:
281-
_adjust_span_text(first_span, prepend=True)
283+
first_span = _adjust_span_text(first_span, prepend=True)
282284
if layout_max - end_span[0][2] > 10:
283-
_adjust_span_text(end_span, append=True)
285+
end_span = _adjust_span_text(end_span, append=True)
284286
else:
285287
if first_span[0][0] - layout_min < 5:
286-
_adjust_span_text(first_span, prepend=True)
288+
first_span = _adjust_span_text(first_span, prepend=True)
287289
if layout_max - end_span[0][2] > 20:
288-
_adjust_span_text(end_span, append=True)
290+
end_span = _adjust_span_text(end_span, append=True)
291+
292+
line[0] = first_span
293+
line[-1] = end_span
294+
295+
return line
296+
297+
298+
def split_boxes_if_x_contained(boxes, offset=1e-5):
299+
"""
300+
Check if there is any complete containment in the x-direction
301+
between the bounding boxes and split the containing box accordingly.
302+
303+
Args:
304+
boxes (list of lists): Each element is a list containing an ndarray of length 4, a description, and a label.
305+
offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
306+
Returns:
307+
A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
308+
"""
309+
310+
def is_x_contained(box_a, box_b):
311+
"""Check if box_a completely contains box_b in the x-direction."""
312+
return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]
313+
314+
new_boxes = []
315+
316+
for i in range(len(boxes)):
317+
box_a = boxes[i]
318+
is_split = False
319+
for j in range(len(boxes)):
320+
if i == j:
321+
continue
322+
box_b = boxes[j]
323+
if is_x_contained(box_a, box_b):
324+
is_split = True
325+
# Split box_a based on the x-coordinates of box_b
326+
if box_a[0][0] < box_b[0][0]:
327+
w = box_b[0][0] - offset - box_a[0][0]
328+
if w > 1:
329+
new_boxes.append(
330+
[
331+
np.array(
332+
[
333+
box_a[0][0],
334+
box_a[0][1],
335+
box_b[0][0] - offset,
336+
box_a[0][3],
337+
]
338+
),
339+
box_a[1],
340+
box_a[2],
341+
]
342+
)
343+
if box_a[0][2] > box_b[0][2]:
344+
w = box_a[0][2] - box_b[0][2] + offset
345+
if w > 1:
346+
box_a = [
347+
np.array(
348+
[
349+
box_b[0][2] + offset,
350+
box_a[0][1],
351+
box_a[0][2],
352+
box_a[0][3],
353+
]
354+
),
355+
box_a[1],
356+
box_a[2],
357+
]
358+
if j == len(boxes) - 1 and is_split:
359+
new_boxes.append(box_a)
360+
if not is_split:
361+
new_boxes.append(box_a)
362+
363+
return new_boxes
364+
365+
366+
def _sort_line_by_x_projection(
367+
input_img: np.ndarray,
368+
general_ocr_pipeline: Any,
369+
line: List[List[Union[List[int], str]]],
370+
) -> None:
371+
"""
372+
Sort a line of text spans based on their vertical position within the layout bounding box.
373+
374+
Args:
375+
input_img (ndarray): The input image used for OCR.
376+
general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
377+
line (list): A list of spans, where each span is a list containing a bounding box and text.
378+
379+
Returns:
380+
list: The sorted line of text spans.
381+
"""
382+
splited_boxes = split_boxes_if_x_contained(line)
383+
splited_lines = []
384+
if len(line) != len(splited_boxes):
385+
splited_boxes.sort(key=lambda span: span[0][0])
386+
text_rec_model = general_ocr_pipeline.text_rec_model
387+
for span in splited_boxes:
388+
if span[2] == "text":
389+
crop_img = input_img[
390+
int(span[0][1]) : int(span[0][3]),
391+
int(span[0][0]) : int(span[0][2]),
392+
]
393+
span[1] = next(text_rec_model([crop_img]))["rec_text"]
394+
splited_lines.append(span)
395+
else:
396+
splited_lines = line
397+
398+
return splited_lines
289399

290400

291401
def _sort_ocr_res_by_y_projection(
402+
input_img: np.ndarray,
403+
general_ocr_pipeline: Any,
292404
label: Any,
293405
block_bbox: Tuple[int, int, int, int],
294406
ocr_res: Dict[str, List[Any]],
@@ -298,6 +410,8 @@ def _sort_ocr_res_by_y_projection(
298410
Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.
299411
300412
Args:
413+
input_img (ndarray): The input image used for OCR.
414+
general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
301415
label (Any): The label associated with the OCR results. It's not used in the function but might be
302416
relevant for other parts of the calling context.
303417
block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -318,12 +432,13 @@ def _sort_ocr_res_by_y_projection(
318432

319433
boxes = ocr_res["boxes"]
320434
rec_texts = ocr_res["rec_texts"]
435+
rec_labels = ocr_res["rec_labels"]
321436

322437
x_min, _, x_max, _ = block_bbox
323438
inline_x_min = min([box[0] for box in boxes])
324439
inline_x_max = max([box[2] for box in boxes])
325440

326-
spans = list(zip(boxes, rec_texts))
441+
spans = list(zip(boxes, rec_texts, rec_labels))
327442

328443
spans.sort(key=lambda span: span[0][1])
329444
spans = [list(span) for span in spans]
@@ -350,16 +465,21 @@ def _sort_ocr_res_by_y_projection(
350465
if current_line:
351466
lines.append(current_line)
352467

468+
new_lines = []
353469
for line in lines:
354470
line.sort(key=lambda span: span[0][0])
471+
472+
ocr_labels = [span[2] for span in line]
473+
if "formula" in ocr_labels:
474+
line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
355475
if label == "reference":
356476
line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
357477
else:
358478
line = _format_line(line, x_min, x_max)
479+
new_lines.append(line)
359480

360-
# Flatten lines back into a single list for boxes and texts
361-
ocr_res["boxes"] = [span[0] for line in lines for span in line]
362-
ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
481+
ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
482+
ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]
363483

364484
return ocr_res
365485

@@ -418,6 +538,7 @@ def handle_spaces_(text: str) -> str:
418538

419539

420540
def get_single_block_parsing_res(
541+
general_ocr_pipeline: Any,
421542
overall_ocr_res: OCRResult,
422543
layout_det_res: DetResult,
423544
table_res_list: list,
@@ -452,10 +573,16 @@ def get_single_block_parsing_res(
452573
input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
453574
seal_index = 0
454575

455-
for box_info in layout_det_res["boxes"]:
576+
layout_det_res_list, _ = _remove_overlap_blocks(
577+
deepcopy(layout_det_res["boxes"]),
578+
threshold=0.5,
579+
smaller=True,
580+
)
581+
582+
for box_info in layout_det_res_list:
456583
block_bbox = box_info["coordinate"]
457584
label = box_info["label"]
458-
rec_res = {"boxes": [], "rec_texts": [], "flag": False}
585+
rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
459586
seg_start_flag = True
460587
seg_end_flag = True
461588

@@ -504,10 +631,15 @@ def get_single_block_parsing_res(
504631
rec_res["rec_texts"].append(
505632
overall_ocr_res["rec_texts"][box_no],
506633
)
634+
rec_res["rec_labels"].append(
635+
overall_ocr_res["rec_labels"][box_no],
636+
)
507637
rec_res["flag"] = True
508638

509639
if rec_res["flag"]:
510-
rec_res = _sort_ocr_res_by_y_projection(label, block_bbox, rec_res, 0.7)
640+
rec_res = _sort_ocr_res_by_y_projection(
641+
input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
642+
)
511643
rec_res_first_bbox = rec_res["boxes"][0]
512644
rec_res_end_bbox = rec_res["boxes"][-1]
513645
if rec_res_first_bbox[0] - block_bbox[0] < 10:
@@ -548,6 +680,20 @@ def get_single_block_parsing_res(
548680
},
549681
)
550682

683+
if len(layout_det_res_list) == 0:
684+
for ocr_rec_box, ocr_rec_text in zip(
685+
overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]
686+
):
687+
single_block_layout_parsing_res.append(
688+
{
689+
"block_label": "text",
690+
"block_content": ocr_rec_text,
691+
"block_bbox": ocr_rec_box,
692+
"seg_start_flag": True,
693+
"seg_end_flag": True,
694+
},
695+
)
696+
551697
single_block_layout_parsing_res = get_layout_ordering(
552698
single_block_layout_parsing_res,
553699
no_mask_labels=[
@@ -910,8 +1056,8 @@ def _remove_overlap_blocks(
9101056
continue
9111057
# Check for overlap and determine which block to remove
9121058
overlap_box_index = _get_minbox_if_overlap_by_ratio(
913-
block1["block_bbox"],
914-
block2["block_bbox"],
1059+
block1["coordinate"],
1060+
block2["coordinate"],
9151061
threshold,
9161062
smaller=smaller,
9171063
)
@@ -1419,11 +1565,6 @@ def get_layout_ordering(
14191565
vision_labels = ["image", "table", "seal", "chart", "figure"]
14201566
vision_title_labels = ["table_title", "chart_title", "figure_title"]
14211567

1422-
parsing_res_list, _ = _remove_overlap_blocks(
1423-
parsing_res_list,
1424-
threshold=0.5,
1425-
smaller=True,
1426-
)
14271568
parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)
14281569

14291570
parsing_res_by_pre_cuts_list = []

0 commit comments

Comments (0)