Skip to content

Commit b499eff

Browse files
committed
sort formulas and text in a line && bug fix
1 parent a4c9257 commit b499eff

File tree

4 files changed

+176
-21
lines changed

4 files changed

+176
-21
lines changed

paddlex/configs/pipelines/layout_parsing_v2.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ SubModules:
2121
layout_merge_bboxes_mode:
2222
1: "large" # image
2323
18: "large" # chart
24+
7: "large" # formula
2425

2526
SubPipelines:
2627
DocPreprocessor:
@@ -45,7 +46,7 @@ SubPipelines:
4546
SubModules:
4647
TextDetection:
4748
module_name: text_detection
48-
model_name: PP-OCRv4_mobile_det
49+
model_name: PP-OCRv4_server_det
4950
model_dir: null
5051
limit_side_len: 960
5152
limit_type: max

paddlex/inference/pipelines/layout_parsing/pipeline_v2.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -310,14 +310,18 @@ def get_layout_parsing_res(
310310
del overall_ocr_res["rec_polys"][matched_idx]
311311
del overall_ocr_res["rec_scores"][matched_idx]
312312

313-
if sub_ocr_res["rec_boxes"] is not []:
313+
if sub_ocr_res["rec_boxes"].size > 0:
314+
sub_ocr_res["rec_labels"] = [
315+
"text" for _ in range(len(sub_ocr_res["rec_texts"]))
316+
]
314317
overall_ocr_res["dt_polys"].extend(sub_ocr_res["dt_polys"])
315318
overall_ocr_res["rec_texts"].extend(sub_ocr_res["rec_texts"])
316319
overall_ocr_res["rec_boxes"] = np.concatenate(
317320
[overall_ocr_res["rec_boxes"], sub_ocr_res["rec_boxes"]], axis=0
318321
)
319322
overall_ocr_res["rec_polys"].extend(sub_ocr_res["rec_polys"])
320323
overall_ocr_res["rec_scores"].extend(sub_ocr_res["rec_scores"])
324+
overall_ocr_res["rec_labels"].extend(sub_ocr_res["rec_labels"])
321325

322326
for formula_res in formula_res_list:
323327
x_min, y_min, x_max, y_max = list(map(int, formula_res["dt_polys"]))
@@ -332,10 +336,12 @@ def get_layout_parsing_res(
332336
overall_ocr_res["rec_boxes"] = np.vstack(
333337
(overall_ocr_res["rec_boxes"], [formula_res["dt_polys"]])
334338
)
339+
overall_ocr_res["rec_labels"].append("formula")
335340
overall_ocr_res["rec_polys"].append(poly_points)
336341
overall_ocr_res["rec_scores"].append(1)
337342

338343
parsing_res_list = get_single_block_parsing_res(
344+
self.general_ocr_pipeline,
339345
overall_ocr_res=overall_ocr_res,
340346
layout_det_res=layout_det_res,
341347
table_res_list=table_res_list,
@@ -474,6 +480,7 @@ def predict(
474480
yield {"error": "the input params for model settings are invalid!"}
475481

476482
for img_id, batch_data in enumerate(self.batch_sampler(input)):
483+
print(batch_data.input_paths[0])
477484
image_array = self.img_reader(batch_data.instances)[0]
478485

479486
if model_settings["use_doc_preprocessor"]:
@@ -536,6 +543,10 @@ def predict(
536543
else:
537544
overall_ocr_res = {}
538545

546+
overall_ocr_res["rec_labels"] = [
547+
"text" for i in range(len(overall_ocr_res["rec_texts"]))
548+
]
549+
539550
if model_settings["use_table_recognition"]:
540551
table_overall_ocr_res = copy.deepcopy(overall_ocr_res)
541552
for formula_res in formula_res_list:

paddlex/inference/pipelines/layout_parsing/result_v2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def format_table():
312312
"table": format_table,
313313
"reference": lambda: block["block_content"],
314314
"algorithm": lambda: block["block_content"].strip("\n"),
315-
"seal": lambda: format_image("block_content"),
315+
"seal": lambda: f"Words of Seals:\n{block['block_content']}",
316316
}
317317
parsing_res_list = obj["parsing_res_list"]
318318
markdown_content = ""

paddlex/inference/pipelines/layout_parsing/utils.py

+161-18
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import uuid
2727
import re
2828
from pathlib import Path
29+
from copy import deepcopy
2930
from typing import Optional, Union, List, Tuple, Dict, Any
3031
from ..ocr.result import OCRResult
3132
from ...models.object_detection.result import DetResult
@@ -253,6 +254,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = Fal
253254
span[1] = "\n" + span[1]
254255
if append:
255256
span[1] = span[1] + "\n"
257+
return span
256258

257259

258260
def _format_line(
@@ -278,17 +280,129 @@ def _format_line(
278280

279281
if not is_reference:
280282
if first_span[0][0] - layout_min > 10:
281-
_adjust_span_text(first_span, prepend=True)
283+
first_span = _adjust_span_text(first_span, prepend=True)
282284
if layout_max - end_span[0][2] > 10:
283-
_adjust_span_text(end_span, append=True)
285+
end_span = _adjust_span_text(end_span, append=True)
284286
else:
285287
if first_span[0][0] - layout_min < 5:
286-
_adjust_span_text(first_span, prepend=True)
288+
first_span = _adjust_span_text(first_span, prepend=True)
287289
if layout_max - end_span[0][2] > 20:
288-
_adjust_span_text(end_span, append=True)
290+
end_span = _adjust_span_text(end_span, append=True)
291+
292+
line[0] = first_span
293+
line[-1] = end_span
294+
295+
return line
296+
297+
298+
def split_boxes_if_x_contained(boxes, offset=1e-5):
    """
    Split every box that completely contains another box in the x-direction.

    For each box A and every other box B whose x-range lies inside A's x-range,
    the part of A to the left of B is emitted as a separate box and A is
    shrunk to the part to the right of B. A piece is only created when it is
    wider than half of A's height. Pieces keep A's text and label.

    Args:
        boxes (list of lists): Each element is ``[bbox, rec_text, label]``
            where ``bbox`` is indexed as ``[x_min, y_min, x_max, y_max]``.
        offset (float): Small gap inserted between a split piece and the
            contained box so that the two do not share an edge.

    Returns:
        list: A new list of boxes, including the split pieces, with the same
        ``rec_text`` and ``label`` entries as the boxes they were cut from.
        Boxes that are not split are returned as the same objects.
    """

    def is_x_contained(box_a, box_b):
        """Check if box_a completely contains box_b in the x-direction."""
        return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]

    new_boxes = []

    for i, box_a in enumerate(boxes):
        # The y-extent never changes while splitting, so the minimum width a
        # piece must exceed (half the box height) is fixed per outer box.
        min_width = (box_a[0][3] - box_a[0][1]) / 2

        for j, box_b in enumerate(boxes):
            if i == j or not is_x_contained(box_a, box_b):
                continue

            # Left remainder: emitted immediately as its own box.
            if box_a[0][0] < box_b[0][0]:
                left_width = box_b[0][0] - offset - box_a[0][0]
                if left_width > min_width:
                    new_boxes.append(
                        [
                            np.array(
                                [
                                    box_a[0][0],
                                    box_a[0][1],
                                    box_b[0][0] - offset,
                                    box_a[0][3],
                                ]
                            ),
                            box_a[1],
                            box_a[2],
                        ]
                    )

            # Right remainder: becomes the new box_a so that later boxes in
            # the scan are compared against what is left of it.
            # NOTE(review): when the right remainder is too narrow, box_a is
            # left unchanged and the original full-width box is still emitted
            # below -- confirm this is intended.
            if box_a[0][2] > box_b[0][2]:
                right_width = box_a[0][2] - box_b[0][2] + offset
                if right_width > min_width:
                    box_a = [
                        np.array(
                            [
                                box_b[0][2] + offset,
                                box_a[0][1],
                                box_a[0][2],
                                box_a[0][3],
                            ]
                        ),
                        box_a[1],
                        box_a[2],
                    ]

        # Always emit what is left of box_a (or the untouched box). Doing this
        # after the scan, instead of when j hits the last index, keeps the
        # remainder when the containing box itself is the last element.
        new_boxes.append(box_a)

    return new_boxes
366+
367+
368+
def _sort_line_by_x_projection(
    input_img: np.ndarray,
    general_ocr_pipeline: Any,
    line: List[List[Union[List[int], str]]],
) -> List[List[Union[List[int], str]]]:
    """
    Split x-overlapping spans of a line, order them left to right and
    re-recognize the text spans.

    The spans are passed through ``split_boxes_if_x_contained``. If that
    changes the number of spans, the result is sorted by the left x-coordinate
    and every span labelled ``"text"`` is cropped from ``input_img`` and
    recognized again with ``general_ocr_pipeline.text_rec_model``. Spans with
    any other label keep their text. If the number of spans is unchanged, the
    original ``line`` is returned untouched.

    Args:
        input_img (ndarray): The input image used for OCR.
        general_ocr_pipeline (Any): The general OCR pipeline used for text
            recognition; its ``text_rec_model`` attribute is called with a
            list of image crops.
        line (list): A list of spans, where each span is a list containing a
            bounding box, its text and its label.

    Returns:
        list: The line of spans, split and sorted by x when a split occurred.
    """
    splited_boxes = split_boxes_if_x_contained(line)

    # NOTE(review): only the span count is compared, so a split that trims a
    # box without adding a new one is ignored -- confirm this is intended.
    if len(line) == len(splited_boxes):
        return line

    splited_boxes.sort(key=lambda span: span[0][0])
    text_rec_model = general_ocr_pipeline.text_rec_model
    for span in splited_boxes:
        if span[2] == "text":
            # The stored text belonged to the box before it was cut, so
            # recognize the crop of the current box again.
            crop_img = input_img[
                int(span[0][1]) : int(span[0][3]),
                int(span[0][0]) : int(span[0][2]),
            ]
            span[1] = next(text_rec_model([crop_img]))["rec_text"]

    return splited_boxes
289401

290402

291403
def _sort_ocr_res_by_y_projection(
404+
input_img: np.ndarray,
405+
general_ocr_pipeline: Any,
292406
label: Any,
293407
block_bbox: Tuple[int, int, int, int],
294408
ocr_res: Dict[str, List[Any]],
@@ -298,6 +412,8 @@ def _sort_ocr_res_by_y_projection(
298412
Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.
299413
300414
Args:
415+
input_img (ndarray): The input image used for OCR.
416+
general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
301417
label (Any): The label associated with the OCR results. It's not used in the function but might be
302418
relevant for other parts of the calling context.
303419
block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -318,12 +434,13 @@ def _sort_ocr_res_by_y_projection(
318434

319435
boxes = ocr_res["boxes"]
320436
rec_texts = ocr_res["rec_texts"]
437+
rec_labels = ocr_res["rec_labels"]
321438

322439
x_min, _, x_max, _ = block_bbox
323440
inline_x_min = min([box[0] for box in boxes])
324441
inline_x_max = max([box[2] for box in boxes])
325442

326-
spans = list(zip(boxes, rec_texts))
443+
spans = list(zip(boxes, rec_texts, rec_labels))
327444

328445
spans.sort(key=lambda span: span[0][1])
329446
spans = [list(span) for span in spans]
@@ -350,16 +467,21 @@ def _sort_ocr_res_by_y_projection(
350467
if current_line:
351468
lines.append(current_line)
352469

470+
new_lines = []
353471
for line in lines:
354472
line.sort(key=lambda span: span[0][0])
473+
474+
ocr_labels = [span[2] for span in line]
475+
if "formula" in ocr_labels:
476+
line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
355477
if label == "reference":
356478
line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
357479
else:
358480
line = _format_line(line, x_min, x_max)
481+
new_lines.append(line)
359482

360-
# Flatten lines back into a single list for boxes and texts
361-
ocr_res["boxes"] = [span[0] for line in lines for span in line]
362-
ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
483+
ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
484+
ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]
363485

364486
return ocr_res
365487

@@ -418,6 +540,7 @@ def handle_spaces_(text: str) -> str:
418540

419541

420542
def get_single_block_parsing_res(
543+
general_ocr_pipeline: Any,
421544
overall_ocr_res: OCRResult,
422545
layout_det_res: DetResult,
423546
table_res_list: list,
@@ -452,10 +575,16 @@ def get_single_block_parsing_res(
452575
input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
453576
seal_index = 0
454577

455-
for box_info in layout_det_res["boxes"]:
578+
layout_det_res_list, _ = _remove_overlap_blocks(
579+
deepcopy(layout_det_res["boxes"]),
580+
threshold=0.5,
581+
smaller=True,
582+
)
583+
584+
for box_info in layout_det_res_list:
456585
block_bbox = box_info["coordinate"]
457586
label = box_info["label"]
458-
rec_res = {"boxes": [], "rec_texts": [], "flag": False}
587+
rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
459588
seg_start_flag = True
460589
seg_end_flag = True
461590

@@ -504,10 +633,15 @@ def get_single_block_parsing_res(
504633
rec_res["rec_texts"].append(
505634
overall_ocr_res["rec_texts"][box_no],
506635
)
636+
rec_res["rec_labels"].append(
637+
overall_ocr_res["rec_labels"][box_no],
638+
)
507639
rec_res["flag"] = True
508640

509641
if rec_res["flag"]:
510-
rec_res = _sort_ocr_res_by_y_projection(label, block_bbox, rec_res, 0.7)
642+
rec_res = _sort_ocr_res_by_y_projection(
643+
input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
644+
)
511645
rec_res_first_bbox = rec_res["boxes"][0]
512646
rec_res_end_bbox = rec_res["boxes"][-1]
513647
if rec_res_first_bbox[0] - block_bbox[0] < 10:
@@ -548,6 +682,20 @@ def get_single_block_parsing_res(
548682
},
549683
)
550684

685+
if len(layout_det_res_list) == 0:
686+
for ocr_rec_box, ocr_rec_text in zip(
687+
overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]
688+
):
689+
single_block_layout_parsing_res.append(
690+
{
691+
"block_label": "text",
692+
"block_content": ocr_rec_text,
693+
"block_bbox": ocr_rec_box,
694+
"seg_start_flag": True,
695+
"seg_end_flag": True,
696+
},
697+
)
698+
551699
single_block_layout_parsing_res = get_layout_ordering(
552700
single_block_layout_parsing_res,
553701
no_mask_labels=[
@@ -910,8 +1058,8 @@ def _remove_overlap_blocks(
9101058
continue
9111059
# Check for overlap and determine which block to remove
9121060
overlap_box_index = _get_minbox_if_overlap_by_ratio(
913-
block1["block_bbox"],
914-
block2["block_bbox"],
1061+
block1["coordinate"],
1062+
block2["coordinate"],
9151063
threshold,
9161064
smaller=smaller,
9171065
)
@@ -1419,11 +1567,6 @@ def get_layout_ordering(
14191567
vision_labels = ["image", "table", "seal", "chart", "figure"]
14201568
vision_title_labels = ["table_title", "chart_title", "figure_title"]
14211569

1422-
parsing_res_list, _ = _remove_overlap_blocks(
1423-
parsing_res_list,
1424-
threshold=0.5,
1425-
smaller=True,
1426-
)
14271570
parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)
14281571

14291572
parsing_res_by_pre_cuts_list = []

0 commit comments

Comments
 (0)