26
26
import uuid
27
27
import re
28
28
from pathlib import Path
29
+ from copy import deepcopy
29
30
from typing import Optional , Union , List , Tuple , Dict , Any
30
31
from ..ocr .result import OCRResult
31
32
from ...models .object_detection .result import DetResult
@@ -253,6 +254,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = Fal
253
254
span [1 ] = "\n " + span [1 ]
254
255
if append :
255
256
span [1 ] = span [1 ] + "\n "
257
+ return span
256
258
257
259
258
260
def _format_line (
@@ -278,17 +280,129 @@ def _format_line(
278
280
279
281
if not is_reference :
280
282
if first_span [0 ][0 ] - layout_min > 10 :
281
- _adjust_span_text (first_span , prepend = True )
283
+ first_span = _adjust_span_text (first_span , prepend = True )
282
284
if layout_max - end_span [0 ][2 ] > 10 :
283
- _adjust_span_text (end_span , append = True )
285
+ end_span = _adjust_span_text (end_span , append = True )
284
286
else :
285
287
if first_span [0 ][0 ] - layout_min < 5 :
286
- _adjust_span_text (first_span , prepend = True )
288
+ first_span = _adjust_span_text (first_span , prepend = True )
287
289
if layout_max - end_span [0 ][2 ] > 20 :
288
- _adjust_span_text (end_span , append = True )
290
+ end_span = _adjust_span_text (end_span , append = True )
291
+
292
+ line [0 ] = first_span
293
+ line [- 1 ] = end_span
294
+
295
+ return line
296
+
297
+
298
+ def split_boxes_if_x_contained (boxes , offset = 1e-5 ):
299
+ """
300
+ Check if there is any complete containment in the x-direction
301
+ between the bounding boxes and split the containing box accordingly.
302
+
303
+ Args:
304
+ boxes (list of lists): Each element is a list containing an ndarray of length 4, a description, and a label.
305
+ offset (float): A small offset value to ensure that the split boxes are not too close to the original boxes.
306
+ Returns:
307
+ A new list of boxes, including split boxes, with the same `rec_text` and `label` attributes.
308
+ """
309
+
310
+ def is_x_contained (box_a , box_b ):
311
+ """Check if box_a completely contains box_b in the x-direction."""
312
+ return box_a [0 ][0 ] <= box_b [0 ][0 ] and box_a [0 ][2 ] >= box_b [0 ][2 ]
313
+
314
+ new_boxes = []
315
+
316
+ for i in range (len (boxes )):
317
+ box_a = boxes [i ]
318
+ is_split = False
319
+ for j in range (len (boxes )):
320
+ if i == j :
321
+ continue
322
+ box_b = boxes [j ]
323
+ if is_x_contained (box_a , box_b ):
324
+ is_split = True
325
+ # Split box_a based on the x-coordinates of box_b
326
+ if box_a [0 ][0 ] < box_b [0 ][0 ]:
327
+ w = box_b [0 ][0 ] - offset - box_a [0 ][0 ]
328
+ h = box_a [0 ][3 ] - box_a [0 ][1 ]
329
+ if w > h / 2 :
330
+ new_boxes .append (
331
+ [
332
+ np .array (
333
+ [
334
+ box_a [0 ][0 ],
335
+ box_a [0 ][1 ],
336
+ box_b [0 ][0 ] - offset ,
337
+ box_a [0 ][3 ],
338
+ ]
339
+ ),
340
+ box_a [1 ],
341
+ box_a [2 ],
342
+ ]
343
+ )
344
+ if box_a [0 ][2 ] > box_b [0 ][2 ]:
345
+ w = box_a [0 ][2 ] - box_b [0 ][2 ] + offset
346
+ h = box_a [0 ][3 ] - box_a [0 ][1 ]
347
+ if w > h / 2 :
348
+ box_a = [
349
+ np .array (
350
+ [
351
+ box_b [0 ][2 ] + offset ,
352
+ box_a [0 ][1 ],
353
+ box_a [0 ][2 ],
354
+ box_a [0 ][3 ],
355
+ ]
356
+ ),
357
+ box_a [1 ],
358
+ box_a [2 ],
359
+ ]
360
+ if j == len (boxes ) - 1 and is_split :
361
+ new_boxes .append (box_a )
362
+ if not is_split :
363
+ new_boxes .append (box_a )
364
+
365
+ return new_boxes
366
+
367
+
368
+ def _sort_line_by_x_projection (
369
+ input_img : np .ndarray ,
370
+ general_ocr_pipeline : Any ,
371
+ line : List [List [Union [List [int ], str ]]],
372
+ ) -> None :
373
+ """
374
+ Sort a line of text spans based on their vertical position within the layout bounding box.
375
+
376
+ Args:
377
+ input_img (ndarray): The input image used for OCR.
378
+ general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
379
+ line (list): A list of spans, where each span is a list containing a bounding box and text.
380
+
381
+ Returns:
382
+ list: The sorted line of text spans.
383
+ """
384
+ splited_boxes = split_boxes_if_x_contained (line )
385
+ splited_lines = []
386
+ if len (line ) != len (splited_boxes ):
387
+ splited_boxes .sort (key = lambda span : span [0 ][0 ])
388
+ text_rec_model = general_ocr_pipeline .text_rec_model
389
+ for span in splited_boxes :
390
+ if span [2 ] == "text" :
391
+ crop_img = input_img [
392
+ int (span [0 ][1 ]) : int (span [0 ][3 ]),
393
+ int (span [0 ][0 ]) : int (span [0 ][2 ]),
394
+ ]
395
+ span [1 ] = next (text_rec_model ([crop_img ]))["rec_text" ]
396
+ splited_lines .append (span )
397
+ else :
398
+ splited_lines = line
399
+
400
+ return splited_lines
289
401
290
402
291
403
def _sort_ocr_res_by_y_projection (
404
+ input_img : np .ndarray ,
405
+ general_ocr_pipeline : Any ,
292
406
label : Any ,
293
407
block_bbox : Tuple [int , int , int , int ],
294
408
ocr_res : Dict [str , List [Any ]],
@@ -298,6 +412,8 @@ def _sort_ocr_res_by_y_projection(
298
412
Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.
299
413
300
414
Args:
415
+ input_img (ndarray): The input image used for OCR.
416
+ general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
301
417
label (Any): The label associated with the OCR results. It's not used in the function but might be
302
418
relevant for other parts of the calling context.
303
419
block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -318,12 +434,13 @@ def _sort_ocr_res_by_y_projection(
318
434
319
435
boxes = ocr_res ["boxes" ]
320
436
rec_texts = ocr_res ["rec_texts" ]
437
+ rec_labels = ocr_res ["rec_labels" ]
321
438
322
439
x_min , _ , x_max , _ = block_bbox
323
440
inline_x_min = min ([box [0 ] for box in boxes ])
324
441
inline_x_max = max ([box [2 ] for box in boxes ])
325
442
326
- spans = list (zip (boxes , rec_texts ))
443
+ spans = list (zip (boxes , rec_texts , rec_labels ))
327
444
328
445
spans .sort (key = lambda span : span [0 ][1 ])
329
446
spans = [list (span ) for span in spans ]
@@ -350,16 +467,21 @@ def _sort_ocr_res_by_y_projection(
350
467
if current_line :
351
468
lines .append (current_line )
352
469
470
+ new_lines = []
353
471
for line in lines :
354
472
line .sort (key = lambda span : span [0 ][0 ])
473
+
474
+ ocr_labels = [span [2 ] for span in line ]
475
+ if "formula" in ocr_labels :
476
+ line = _sort_line_by_x_projection (input_img , general_ocr_pipeline , line )
355
477
if label == "reference" :
356
478
line = _format_line (line , inline_x_min , inline_x_max , is_reference = True )
357
479
else :
358
480
line = _format_line (line , x_min , x_max )
481
+ new_lines .append (line )
359
482
360
- # Flatten lines back into a single list for boxes and texts
361
- ocr_res ["boxes" ] = [span [0 ] for line in lines for span in line ]
362
- ocr_res ["rec_texts" ] = [span [1 ] + " " for line in lines for span in line ]
483
+ ocr_res ["boxes" ] = [span [0 ] for line in new_lines for span in line ]
484
+ ocr_res ["rec_texts" ] = [span [1 ] + " " for line in new_lines for span in line ]
363
485
364
486
return ocr_res
365
487
@@ -418,6 +540,7 @@ def handle_spaces_(text: str) -> str:
418
540
419
541
420
542
def get_single_block_parsing_res (
543
+ general_ocr_pipeline : Any ,
421
544
overall_ocr_res : OCRResult ,
422
545
layout_det_res : DetResult ,
423
546
table_res_list : list ,
@@ -452,10 +575,16 @@ def get_single_block_parsing_res(
452
575
input_img = overall_ocr_res ["doc_preprocessor_res" ]["output_img" ]
453
576
seal_index = 0
454
577
455
- for box_info in layout_det_res ["boxes" ]:
578
+ layout_det_res_list , _ = _remove_overlap_blocks (
579
+ deepcopy (layout_det_res ["boxes" ]),
580
+ threshold = 0.5 ,
581
+ smaller = True ,
582
+ )
583
+
584
+ for box_info in layout_det_res_list :
456
585
block_bbox = box_info ["coordinate" ]
457
586
label = box_info ["label" ]
458
- rec_res = {"boxes" : [], "rec_texts" : [], "flag" : False }
587
+ rec_res = {"boxes" : [], "rec_texts" : [], "rec_labels" : [], " flag" : False }
459
588
seg_start_flag = True
460
589
seg_end_flag = True
461
590
@@ -504,10 +633,15 @@ def get_single_block_parsing_res(
504
633
rec_res ["rec_texts" ].append (
505
634
overall_ocr_res ["rec_texts" ][box_no ],
506
635
)
636
+ rec_res ["rec_labels" ].append (
637
+ overall_ocr_res ["rec_labels" ][box_no ],
638
+ )
507
639
rec_res ["flag" ] = True
508
640
509
641
if rec_res ["flag" ]:
510
- rec_res = _sort_ocr_res_by_y_projection (label , block_bbox , rec_res , 0.7 )
642
+ rec_res = _sort_ocr_res_by_y_projection (
643
+ input_img , general_ocr_pipeline , label , block_bbox , rec_res , 0.7
644
+ )
511
645
rec_res_first_bbox = rec_res ["boxes" ][0 ]
512
646
rec_res_end_bbox = rec_res ["boxes" ][- 1 ]
513
647
if rec_res_first_bbox [0 ] - block_bbox [0 ] < 10 :
@@ -548,6 +682,20 @@ def get_single_block_parsing_res(
548
682
},
549
683
)
550
684
685
+ if len (layout_det_res_list ) == 0 :
686
+ for ocr_rec_box , ocr_rec_text in zip (
687
+ overall_ocr_res ["rec_boxes" ], overall_ocr_res ["rec_texts" ]
688
+ ):
689
+ single_block_layout_parsing_res .append (
690
+ {
691
+ "block_label" : "text" ,
692
+ "block_content" : ocr_rec_text ,
693
+ "block_bbox" : ocr_rec_box ,
694
+ "seg_start_flag" : True ,
695
+ "seg_end_flag" : True ,
696
+ },
697
+ )
698
+
551
699
single_block_layout_parsing_res = get_layout_ordering (
552
700
single_block_layout_parsing_res ,
553
701
no_mask_labels = [
@@ -910,8 +1058,8 @@ def _remove_overlap_blocks(
910
1058
continue
911
1059
# Check for overlap and determine which block to remove
912
1060
overlap_box_index = _get_minbox_if_overlap_by_ratio (
913
- block1 ["block_bbox " ],
914
- block2 ["block_bbox " ],
1061
+ block1 ["coordinate " ],
1062
+ block2 ["coordinate " ],
915
1063
threshold ,
916
1064
smaller = smaller ,
917
1065
)
@@ -1419,11 +1567,6 @@ def get_layout_ordering(
1419
1567
vision_labels = ["image" , "table" , "seal" , "chart" , "figure" ]
1420
1568
vision_title_labels = ["table_title" , "chart_title" , "figure_title" ]
1421
1569
1422
- parsing_res_list , _ = _remove_overlap_blocks (
1423
- parsing_res_list ,
1424
- threshold = 0.5 ,
1425
- smaller = True ,
1426
- )
1427
1570
parsing_res_list , pre_cuts = _get_sub_category (parsing_res_list , title_text_labels )
1428
1571
1429
1572
parsing_res_by_pre_cuts_list = []
0 commit comments