import uuid
import re
from pathlib import Path
+from copy import deepcopy
from typing import Optional, Union, List, Tuple, Dict, Any
from ..ocr.result import OCRResult
from ...models.object_detection.result import DetResult
@@ -253,6 +254,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = Fal
        span[1] = "\n" + span[1]
    if append:
        span[1] = span[1] + "\n"
+    return span


def _format_line(
@@ -278,17 +280,127 @@ def _format_line(

    if not is_reference:
        if first_span[0][0] - layout_min > 10:
-            _adjust_span_text(first_span, prepend=True)
+            first_span = _adjust_span_text(first_span, prepend=True)
        if layout_max - end_span[0][2] > 10:
-            _adjust_span_text(end_span, append=True)
+            end_span = _adjust_span_text(end_span, append=True)
    else:
        if first_span[0][0] - layout_min < 5:
-            _adjust_span_text(first_span, prepend=True)
+            first_span = _adjust_span_text(first_span, prepend=True)
        if layout_max - end_span[0][2] > 20:
-            _adjust_span_text(end_span, append=True)
+            end_span = _adjust_span_text(end_span, append=True)
+
+    line[0] = first_span
+    line[-1] = end_span
+
+    return line
+
+
+def split_boxes_if_x_contained(boxes, offset=1e-5):
+    """
+    Check if there is any complete containment in the x-direction
+    between the bounding boxes and split the containing box accordingly.
+
+    Args:
+        boxes (list of lists): Each element is a list containing an ndarray of length 4, a description, and a label.
+        offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
+    Returns:
+        A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
+    """
+
+    def is_x_contained(box_a, box_b):
+        """Check if box_a completely contains box_b in the x-direction."""
+        return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]
+
+    new_boxes = []
+
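+    # Compare each box against every other box: when box_a fully contains box_b
+    # along x, keep the part of box_a to the left of box_b as a new box and shrink
+    # box_a to the part on its right, so later boxes can split it further.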
+    for i in range(len(boxes)):
+        box_a = boxes[i]
+        is_split = False
+        for j in range(len(boxes)):
+            if i == j:
+                continue
+            box_b = boxes[j]
+            if is_x_contained(box_a, box_b):
+                is_split = True
+                # Split box_a based on the x-coordinates of box_b
+                if box_a[0][0] < box_b[0][0]:
+                    w = box_b[0][0] - offset - box_a[0][0]
+                    if w > 1:
+                        new_boxes.append(
+                            [
+                                np.array(
+                                    [
+                                        box_a[0][0],
+                                        box_a[0][1],
+                                        box_b[0][0] - offset,
+                                        box_a[0][3],
+                                    ]
+                                ),
+                                box_a[1],
+                                box_a[2],
+                            ]
+                        )
+                if box_a[0][2] > box_b[0][2]:
+                    w = box_a[0][2] - box_b[0][2] + offset
+                    if w > 1:
+                        box_a = [
+                            np.array(
+                                [
+                                    box_b[0][2] + offset,
+                                    box_a[0][1],
+                                    box_a[0][2],
+                                    box_a[0][3],
+                                ]
+                            ),
+                            box_a[1],
+                            box_a[2],
+                        ]
+            if j == len(boxes) - 1 and is_split:
+                new_boxes.append(box_a)
+        if not is_split:
+            new_boxes.append(box_a)
+
+    return new_boxes
+
+
+def _sort_line_by_x_projection(
+    input_img: np.ndarray,
+    general_ocr_pipeline: Any,
+    line: List[List[Union[List[int], str]]],
+) -> list:
+    """
+    Sort and, if necessary, split the spans of a line along the x-axis within the layout bounding box.
+
+    Args:
+        input_img (ndarray): The input image used for OCR.
+        general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
+        line (list): A list of spans, where each span is a list containing a bounding box, text, and a label.
+
+    Returns:
+        list: The sorted line of text spans.
+    """
+    splited_boxes = split_boxes_if_x_contained(line)
+    splited_lines = []
+    if len(line) != len(splited_boxes):
+        splited_boxes.sort(key=lambda span: span[0][0])
+        text_rec_model = general_ocr_pipeline.text_rec_model
+        for span in splited_boxes:
+            if span[2] == "text":
+                crop_img = input_img[
+                    int(span[0][1]) : int(span[0][3]),
+                    int(span[0][0]) : int(span[0][2]),
+                ]
+                span[1] = next(text_rec_model([crop_img]))["rec_text"]
+            splited_lines.append(span)
+    else:
+        splited_lines = line
+
+    return splited_lines


def _sort_ocr_res_by_y_projection(
+    input_img: np.ndarray,
+    general_ocr_pipeline: Any,
    label: Any,
    block_bbox: Tuple[int, int, int, int],
    ocr_res: Dict[str, List[Any]],
@@ -298,6 +410,8 @@ def _sort_ocr_res_by_y_projection(
    Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.

    Args:
+        input_img (ndarray): The input image used for OCR.
+        general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
        label (Any): The label associated with the OCR results. It's not used in the function but might be
            relevant for other parts of the calling context.
        block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -318,12 +432,13 @@ def _sort_ocr_res_by_y_projection(

    boxes = ocr_res["boxes"]
    rec_texts = ocr_res["rec_texts"]
+    rec_labels = ocr_res["rec_labels"]

    x_min, _, x_max, _ = block_bbox
    inline_x_min = min([box[0] for box in boxes])
    inline_x_max = max([box[2] for box in boxes])

-    spans = list(zip(boxes, rec_texts))
+    spans = list(zip(boxes, rec_texts, rec_labels))

    spans.sort(key=lambda span: span[0][1])
    spans = [list(span) for span in spans]
@@ -350,16 +465,21 @@ def _sort_ocr_res_by_y_projection(
    if current_line:
        lines.append(current_line)

+    new_lines = []
    for line in lines:
        line.sort(key=lambda span: span[0][0])
+
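+        # Lines containing a formula span are re-split along the x-axis and the
+        # text pieces re-recognized, so formula and surrounding text keep their order.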
+        ocr_labels = [span[2] for span in line]
+        if "formula" in ocr_labels:
+            line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
        if label == "reference":
            line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
        else:
            line = _format_line(line, x_min, x_max)
+        new_lines.append(line)

-    # Flatten lines back into a single list for boxes and texts
-    ocr_res["boxes"] = [span[0] for line in lines for span in line]
-    ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
+    ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
+    ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]

    return ocr_res
@@ -418,6 +538,7 @@ def handle_spaces_(text: str) -> str:


def get_single_block_parsing_res(
+    general_ocr_pipeline: Any,
    overall_ocr_res: OCRResult,
    layout_det_res: DetResult,
    table_res_list: list,
@@ -452,10 +573,16 @@ def get_single_block_parsing_res(
    input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
    seal_index = 0

-    for box_info in layout_det_res["boxes"]:
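+    # Filter overlapping layout boxes up front; this overlap removal previously
+    # ran later, inside get_layout_ordering.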
+    layout_det_res_list, _ = _remove_overlap_blocks(
+        deepcopy(layout_det_res["boxes"]),
+        threshold=0.5,
+        smaller=True,
+    )
+
+    for box_info in layout_det_res_list:
        block_bbox = box_info["coordinate"]
        label = box_info["label"]
-        rec_res = {"boxes": [], "rec_texts": [], "flag": False}
+        rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
        seg_start_flag = True
        seg_end_flag = True

@@ -504,10 +631,15 @@ def get_single_block_parsing_res(
                rec_res["rec_texts"].append(
                    overall_ocr_res["rec_texts"][box_no],
                )
+                rec_res["rec_labels"].append(
+                    overall_ocr_res["rec_labels"][box_no],
+                )
                rec_res["flag"] = True

        if rec_res["flag"]:
-            rec_res = _sort_ocr_res_by_y_projection(label, block_bbox, rec_res, 0.7)
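+            # The source image and the general OCR pipeline are passed through so
+            # that formula-containing lines can be re-cropped and re-recognized.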
+            rec_res = _sort_ocr_res_by_y_projection(
+                input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
+            )
            rec_res_first_bbox = rec_res["boxes"][0]
            rec_res_end_bbox = rec_res["boxes"][-1]
            if rec_res_first_bbox[0] - block_bbox[0] < 10:
@@ -548,6 +680,20 @@ def get_single_block_parsing_res(
            },
        )

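+    # Fallback: if layout detection returned no blocks, promote every OCR line
+    # to its own "text" block so downstream ordering still has input.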
+    if len(layout_det_res_list) == 0:
+        for ocr_rec_box, ocr_rec_text in zip(
+            overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]
+        ):
+            single_block_layout_parsing_res.append(
+                {
+                    "block_label": "text",
+                    "block_content": ocr_rec_text,
+                    "block_bbox": ocr_rec_box,
+                    "seg_start_flag": True,
+                    "seg_end_flag": True,
+                },
+            )
+
    single_block_layout_parsing_res = get_layout_ordering(
        single_block_layout_parsing_res,
        no_mask_labels=[
@@ -910,8 +1056,8 @@ def _remove_overlap_blocks(
                continue
            # Check for overlap and determine which block to remove
            overlap_box_index = _get_minbox_if_overlap_by_ratio(
-                block1["block_bbox"],
-                block2["block_bbox"],
+                block1["coordinate"],
+                block2["coordinate"],
                threshold,
                smaller=smaller,
            )
@@ -1419,11 +1565,6 @@ def get_layout_ordering(
    vision_labels = ["image", "table", "seal", "chart", "figure"]
    vision_title_labels = ["table_title", "chart_title", "figure_title"]

-    parsing_res_list, _ = _remove_overlap_blocks(
-        parsing_res_list,
-        threshold=0.5,
-        smaller=True,
-    )

    parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)
    parsing_res_by_pre_cuts_list = []