TableOCR
Classes
TableOCR
Represents a table extractor for extracting tables from a document.
Methods
__init__(self, text_ocr, spellchecker=None, device=None, model_dir=None, root='/home/docs/.pix2text', structure_thresholds=None, table_expansion_margin=10, threshold_percentage=0.1, **kwargs)
special
Initialize a TableOCR object.
Source code in pix2text/table_ocr.py
def __init__(
    self,
    text_ocr: TextOcrEngine,
    spellchecker=None,
    device: str = None,
    model_dir: Optional[Union[str, Path]] = None,
    root: Union[str, Path] = data_dir(),
    structure_thresholds=None,
    table_expansion_margin=10,
    threshold_percentage=0.10,
    **kwargs,
):
    """
    Set up the table recognizer: keep the OCR helpers, resolve the device and
    load the table-structure detection model in evaluation mode.

    Args:
        text_ocr: text OCR engine, stored as ``self.text_ocr``.
        spellchecker: optional spellchecker, stored as-is.
        device: device specification handed to ``select_device``.
        model_dir: location passed to
            ``AutoModelForObjectDetection.from_pretrained``; when ``None`` it is
            obtained from ``self._prepare_model_files(root, None)``.
        root: root directory forwarded to ``_prepare_model_files`` when
            ``model_dir`` is not given.
        structure_thresholds: per-class thresholds; a falsy value selects
            ``DEFAULT_STRUCTURE_THRESHOLDS``.
        table_expansion_margin: margin used to slightly expand the table
            bounding box for better cropping.
        threshold_percentage: fraction (e.g. 0.10 for 10%) of the average
            height used as the threshold for starting a new row.
        **kwargs: accepted but not used by this method.
    """
    self.text_ocr = text_ocr
    self.spellchecker = spellchecker
    self.str_device = select_device(device)

    # Class-name <-> class-index lookups for the structure model
    name_to_idx = get_class_map('structure')
    self.str_class_name2idx = name_to_idx
    self.str_class_idx2name = {idx: name for name, idx in name_to_idx.items()}

    if structure_thresholds:
        self.str_class_thresholds = structure_thresholds
    else:
        self.str_class_thresholds = DEFAULT_STRUCTURE_THRESHOLDS

    if model_dir is None:
        model_dir = self._prepare_model_files(root, None)

    # Load the model that identifies table structures and switch it to eval mode
    detector = AutoModelForObjectDetection.from_pretrained(model_dir)
    self.str_model = detector.to(self.str_device)
    self.str_model.eval()

    self._table_expansion_margin = table_expansion_margin
    self._threshold_percentage = threshold_percentage
    self.test = []
recognize(self, img, tokens=None, out_objects=False, out_cells=True, out_html=False, out_csv=False, out_markdown=True, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
img | | | required |
tokens | | | None |
out_objects | | | False |
out_cells | | | True |
out_html | | | False |
out_csv | | | False |
out_markdown | | | True |
**kwargs | | | {} |
Source code in pix2text/table_ocr.py
def recognize(
    self,
    img,
    tokens=None,
    out_objects=False,
    out_cells=True,
    out_html=False,
    out_csv=False,
    out_markdown=True,
    **kwargs,
) -> Dict[str, Any]:
    """
    Detect the table structure in an image and export it in the requested formats.

    Args:
        img: a ``str``/``Path`` (loaded through ``read_img``) or an already
            loaded image object exposing ``.size``.
        tokens: optional list of word tokens forwarded to the structure
            post-processing; ``None`` is treated as an empty list.
        out_objects (bool): include the detected objects under ``'objects'``.
        out_cells (bool): include the per-table cell lists under ``'cells'``.
        out_html (bool): include one HTML export per table under ``'html'``.
        out_csv (bool): include one CSV export per table under ``'csv'``.
        out_markdown (bool): include one Markdown export per table under
            ``'markdown'``.
        **kwargs ():
            * save_analysis_res (str): Save the parsed result image in this
              file; default value is `None`, which means not to save

    Returns:
        Dict mapping each requested format name to its result. An empty dict
        is returned when no structure model is loaded or when no output
        format is requested.
    """
    out_formats = {}
    if self.str_model is None:
        print("No structure model loaded.")
        return out_formats

    # Markdown counts as an export format too; leaving it out of these guards
    # made the default `out_markdown=True` silently produce nothing.
    wants_table_exports = out_html or out_csv or out_markdown
    if not (out_objects or out_cells or wants_table_exports):
        print("No output format specified")
        return out_formats

    if isinstance(img, (str, Path)):
        img = read_img(img, return_type='Image')

    # Transform the image how the model expects it
    img_tensor = structure_transform(img)

    # Run input image through the model
    with torch.no_grad():
        outputs = self.str_model(img_tensor.unsqueeze(0).to(self.str_device))

    # Post-process detected objects, assign class labels
    objects = outputs_to_objects(outputs, img.size, self.str_class_idx2name)
    if out_objects:
        out_formats['objects'] = objects
    if not (out_cells or wants_table_exports):
        return out_formats

    # Further process the detected objects so they correspond to a consistent table
    tokens = tokens or []
    tables_structure = objects_to_structures(
        objects, tokens, self.str_class_thresholds
    )

    # Enumerate all table cells: grid cells and spanning cells
    tables_cells = [
        structure_to_cells(structure, tokens)[0] for structure in tables_structure
    ]
    for cells in tables_cells:
        self._ocr_texts(img, cells)

    if out_cells:
        out_formats['cells'] = tables_cells
    # Only the first table is visualized; skip when nothing was detected
    # instead of failing on `tables_cells[0]`.
    if kwargs.get('save_analysis_res') and tables_cells:
        visualize_cells(img, tables_cells[0], kwargs['save_analysis_res'])
    if not wants_table_exports:
        return out_formats

    # Convert cells to HTML
    if out_html:
        out_formats['html'] = [cells_to_html(cells) for cells in tables_cells]

    # Convert cells to CSV, including flattening multi-row column headers to a single row
    if out_csv:
        out_formats['csv'] = [cells_to_csv(cells) for cells in tables_cells]

    # Convert cells to Markdown
    if out_markdown:
        out_formats['markdown'] = [
            cells_to_markdown(cells) for cells in tables_cells
        ]

    return out_formats
Functions
align_headers(headers, rows)
Adjust the header boundary to be the convex hull of the rows it intersects at least 50% of the height of.
For now, we are not supporting tables with multiple headers, so we need to eliminate anything besides the top-most header.
Source code in pix2text/table_ocr.py
def align_headers(headers, rows):
    """
    Adjust the header boundary to be the convex hull of the rows it intersects
    at least 50% of the height of.

    For now, we are not supporting tables with multiple headers, so we need to
    eliminate anything besides the top-most header.

    Args:
        headers: detected header objects, each a dict with a ``'bbox'`` entry.
        rows: row objects, each a dict with a ``'bbox'`` entry. Every row gets
            its ``'column header'`` flag reset to ``False`` and then set to
            ``True`` for the rows that end up inside the header.

    Returns:
        A list holding a single ``{'bbox': [...]}`` header covering the
        contiguous header rows starting at row 0, or an empty list when no
        row overlaps any header by at least half of its height.
    """
    aligned_headers = []

    for row in rows:
        row['column header'] = False

    header_row_nums = []
    for header in headers:
        for row_num, row in enumerate(rows):
            row_height = row['bbox'][3] - row['bbox'][1]
            if row_height <= 0:
                # Degenerate row: nothing to overlap with, and dividing by
                # its height would fail.
                continue
            min_row_overlap = max(row['bbox'][1], header['bbox'][1])
            max_row_overlap = min(row['bbox'][3], header['bbox'][3])
            overlap_height = max_row_overlap - min_row_overlap
            if overlap_height / row_height >= 0.5:
                header_row_nums.append(row_num)

    if len(header_row_nums) == 0:
        return aligned_headers

    header_rect = Rect()
    if header_row_nums[0] > 0:
        # Rows above the first detected header row are treated as header too.
        # Prepend only the rows strictly above it: repeating the first index
        # would break the contiguity check below and drop later header rows.
        header_row_nums = list(range(header_row_nums[0])) + header_row_nums

    last_row_num = -1
    for row_num in header_row_nums:
        if row_num == last_row_num + 1:
            row = rows[row_num]
            row['column header'] = True
            header_rect = header_rect.include_rect(row['bbox'])
            last_row_num = row_num
        else:
            # Break as soon as a non-header row is encountered.
            # This ignores any subsequent rows in the table labeled as a header.
            # Having more than 1 header is not supported currently.
            break

    header = {'bbox': list(header_rect)}
    aligned_headers.append(header)

    return aligned_headers
etree_to_markdown_table(etree)
将XML ElementTree对象转换为Markdown格式的表格。
Parameters:
Name | Type | Description | Default |
---|---|---|---|
etree | xml.etree.ElementTree.Element | XML表格的根元素。 | required |
Returns:
Type | Description |
---|---|
str | Markdown格式的表格字符串。 |
Source code in pix2text/table_ocr.py
def etree_to_markdown_table(etree):
    """
    Convert an XML ElementTree table element into a Markdown table.

    Args:
        etree (xml.etree.ElementTree.Element): root element of the XML table.

    Returns:
        str: the table as a Markdown string. When the root element is not a
        ``table`` or contains no ``tr`` rows, an "Invalid XML input: ..."
        message is returned instead.
    """
    if etree.tag != 'table':
        return "Invalid XML input: root element is not a table."

    rows = etree.findall('.//tr')
    if not rows:
        return "Invalid XML input: no rows found."

    def _cell_text(elem):
        # `.text` is None for empty elements such as <td/>; newlines would
        # break the single-line Markdown row.
        return (elem.text or '').replace('\n', ' ')

    markdown_table = []

    headers = [_cell_text(th) for th in etree.findall('.//th')]
    if headers:
        markdown_table.append("| " + " | ".join(headers) + " |")
        markdown_table.append("| " + " | ".join(["---"] * len(headers)) + " |")

    for row in rows:
        cells = [_cell_text(td) for td in row.findall('td')]
        if not cells:
            # Rows without <td> children (e.g. the header row) add no data line
            continue
        markdown_table.append("| " + " | ".join(cells) + " |")

    return "\n".join(markdown_table)
iob(bbox1, bbox2)
Compute the intersection area over box area, for bbox1.
Source code in pix2text/table_ocr.py
def iob(bbox1, bbox2):
    """
    Return the area of the intersection of ``bbox1`` and ``bbox2`` divided by
    the area of ``bbox1``; 0 when ``bbox1`` has no positive area.
    """
    # Separate Rect instances are built for the overlap and for the area
    overlap = Rect(bbox1).intersect(bbox2)
    box_area = Rect(bbox1).get_area()
    return overlap.get_area() / box_area if box_area > 0 else 0
objects_to_crops(img, tokens, objects, class_thresholds, padding=10)
Process the bounding boxes produced by the table detection model into cropped table images and cropped tokens.
Source code in pix2text/table_ocr.py
def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
    """
    Process the bounding boxes produced by the table detection model into
    cropped table images and cropped tokens.

    Args:
        img: image object providing ``crop(bbox)``; the crop must provide
            ``rotate`` and ``size`` for objects labelled ``'table rotated'``.
        tokens: list of token dicts with a ``'bbox'`` entry. The input dicts
            are left untouched; each crop receives its own shifted copies.
        objects: detected objects, dicts with ``'score'``, ``'label'`` and
            ``'bbox'``.
        class_thresholds: mapping from label to the minimum score to keep.
        padding: number of units added on every side of the detected box
            before cropping.

    Returns:
        A list of dicts, one per kept object, with ``'image'`` (the cropped
        image) and ``'tokens'`` (tokens overlapping the padded box by at least
        50% of their own area, in crop-relative coordinates).
    """
    table_crops = []
    for obj in objects:
        if obj['score'] < class_thresholds[obj['label']]:
            continue

        bbox = obj['bbox']
        bbox = [
            bbox[0] - padding,
            bbox[1] - padding,
            bbox[2] + padding,
            bbox[3] + padding,
        ]

        cropped_img = img.crop(bbox)

        # Work on copies: shifting the caller's dicts in place would corrupt
        # the page coordinates used in the overlap test for the next table.
        table_tokens = [
            dict(token) for token in tokens if iob(token['bbox'], bbox) >= 0.5
        ]
        for token in table_tokens:
            token['bbox'] = [
                token['bbox'][0] - bbox[0],
                token['bbox'][1] - bbox[1],
                token['bbox'][2] - bbox[0],
                token['bbox'][3] - bbox[1],
            ]

        # If table is predicted to be rotated, rotate cropped image and tokens/words:
        if obj['label'] == 'table rotated':
            cropped_img = cropped_img.rotate(270, expand=True)
            rotated_width = cropped_img.size[0]
            for token in table_tokens:
                token_bbox = token['bbox']
                token['bbox'] = [
                    rotated_width - token_bbox[3] - 1,
                    token_bbox[0],
                    rotated_width - token_bbox[1] - 1,
                    token_bbox[2],
                ]

        table_crops.append({'image': cropped_img, 'tokens': table_tokens})

    return table_crops
objects_to_structures(objects, tokens, class_thresholds)
Process the bounding boxes produced by the table structure recognition model into a consistent set of table structures (rows, columns, spanning cells, headers). This entails resolving conflicts/overlaps, and ensuring the boxes meet certain alignment conditions (for example: rows should all have the same width, etc.).
Source code in pix2text/table_ocr.py
def objects_to_structures(objects, tokens, class_thresholds):
    """
    Turn the raw boxes from the table structure recognition model into one
    *consistent* structure per detected table (rows, columns, spanning cells,
    headers).

    For every object labelled ``'table'``, the objects and tokens lying at
    least 50% inside it are collected, rows and columns are refined and
    aligned through ``postprocess``, and the table box is shrunk to the row
    and column extent.

    Returns:
        A list with one dict per table, keyed by ``'rows'``, ``'columns'``,
        ``'column headers'`` and ``'spanning cells'``.
    """

    def with_label(candidates, label):
        return [item for item in candidates if item['label'] == label]

    table_structures = []
    for table in with_label(objects, 'table'):
        table_bbox = table['bbox']
        table_objects = [
            obj for obj in objects if iob(obj['bbox'], table_bbox) >= 0.5
        ]
        table_tokens = [
            tok for tok in tokens if iob(tok['bbox'], table_bbox) >= 0.5
        ]

        columns = with_label(table_objects, 'table column')
        rows = with_label(table_objects, 'table row')
        column_headers = with_label(table_objects, 'table column header')
        spanning_cells = with_label(table_objects, 'table spanning cell')
        projected_row_headers = with_label(
            table_objects, 'table projected row header'
        )

        # Projected row headers are carried along as flagged spanning cells
        for cell in spanning_cells:
            cell['projected row header'] = False
        for cell in projected_row_headers:
            cell['projected row header'] = True
        spanning_cells += projected_row_headers

        # A row is a header row when at least half of it lies inside a header box
        for row in rows:
            row['column header'] = any(
                iob(row['bbox'], hdr['bbox']) >= 0.5 for hdr in column_headers
            )

        # Refine table structures
        rows = postprocess.refine_rows(
            rows, table_tokens, class_thresholds['table row']
        )
        columns = postprocess.refine_columns(
            columns, table_tokens, class_thresholds['table column']
        )

        # Shrink the table box to the total height of the rows and the total
        # width of the columns
        row_rect = Rect()
        for row in rows:
            row_rect.include_rect(row['bbox'])
        column_rect = Rect()
        for column in columns:
            column_rect.include_rect(column['bbox'])
        extent = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]]
        table['row_column_bbox'] = extent
        table['bbox'] = extent

        # Process the rows and columns into a complete segmented table
        columns = postprocess.align_columns(columns, extent)
        rows = postprocess.align_rows(rows, extent)

        structure = {
            'rows': rows,
            'columns': columns,
            'column headers': column_headers,
            'spanning cells': spanning_cells,
        }
        if len(rows) > 0 and len(columns) > 1:
            structure = refine_table_structure(structure, class_thresholds)

        table_structures.append(structure)

    return table_structures
refine_table_structure(table_structure, class_thresholds)
Apply operations to the detected table structure objects such as thresholding, NMS, and alignment.
Source code in pix2text/table_ocr.py
def refine_table_structure(table_structure, class_thresholds):
    """
    Clean up the detected table structure objects: score thresholding, NMS
    and alignment of column headers and spanning cells.

    The ``table_structure`` dict is updated in place and also returned.
    """
    rows = table_structure["rows"]
    columns = table_structure['columns']

    # Column headers: threshold, suppress overlaps, then snap onto the rows
    headers = postprocess.apply_threshold(
        table_structure['column headers'], class_thresholds["table column header"]
    )
    headers = postprocess.nms(headers)
    headers = align_headers(headers, rows)

    # Spanning cells: separate the projected row headers so that each group
    # is thresholded with its own class threshold
    plain_spans = []
    projected_spans = []
    for elem in table_structure['spanning cells']:
        if elem['projected row header']:
            projected_spans.append(elem)
        else:
            plain_spans.append(elem)
    plain_spans = postprocess.apply_threshold(
        plain_spans, class_thresholds["table spanning cell"]
    )
    projected_spans = postprocess.apply_threshold(
        projected_spans, class_thresholds["table projected row header"]
    )
    plain_spans += projected_spans

    # Align before NMS for spanning cells because alignment brings them into agreement
    # with rows and columns first; if spanning cells still overlap after this operation,
    # the threshold for NMS can basically be lowered to just above 0
    merged_spans = postprocess.align_supercells(plain_spans, rows, columns)
    merged_spans = postprocess.nms_supercells(merged_spans)

    postprocess.header_supercell_tree(merged_spans)

    table_structure['columns'] = columns
    table_structure['rows'] = rows
    table_structure['spanning cells'] = merged_spans
    table_structure['column headers'] = headers

    return table_structure
structure_to_cells(table_structure, tokens)
Assuming the row, column, spanning cell, and header bounding boxes have been refined into a set of consistent table structures, process these table structures into table cells. This is a universal representation format for the table, which can later be exported to Pandas or CSV formats. Classify the cells as header/access cells or data cells based on if they intersect with the header bounding box.
Source code in pix2text/table_ocr.py
def structure_to_cells(table_structure, tokens):
    """
    Assuming the row, column, spanning cell, and header bounding boxes have
    been refined into a set of consistent table structures, process these
    table structures into table cells. This is a universal representation
    format for the table, which can later be exported to Pandas or CSV formats.
    Classify the cells as header/access cells or data cells
    based on if they intersect with the header bounding box.

    Args:
        table_structure: dict with ``'columns'``, ``'rows'`` and
            ``'spanning cells'`` lists of objects carrying a ``'bbox'``.
            The row and column boxes are tightened in place around the text
            slotted into them.
        tokens: list of token dicts with a ``'bbox'``; they are slotted into
            the cells via ``postprocess.slot_into_containers``.

    Returns:
        ``(cells, confidence_score)``. Each cell is a dict with ``'bbox'``,
        ``'column_nums'``, ``'row_nums'``, ``'column header'``,
        ``'projected row header'``, ``'cell text'`` and ``'spans'`` (grid
        cells also carry ``'subcell'``). ``confidence_score`` is the average
        of the mean and the minimum cell match score, or 0 when it cannot be
        computed.
    """
    columns = table_structure['columns']
    rows = table_structure['rows']
    spanning_cells = table_structure['spanning cells']
    cells = []
    subcells = []

    # Identify complete cells and subcells
    for column_num, column in enumerate(columns):
        for row_num, row in enumerate(rows):
            column_rect = Rect(list(column['bbox']))
            row_rect = Rect(list(row['bbox']))
            cell_rect = row_rect.intersect(column_rect)
            header = 'column header' in row and row['column header']
            cell = {
                'bbox': list(cell_rect),
                'column_nums': [column_num],
                'row_nums': [row_num],
                'column header': header,
            }

            cell['subcell'] = False
            cell_area = cell_rect.get_area()
            # A zero-area grid cell (row and column do not overlap) cannot be
            # covered by a spanning cell; skipping it avoids dividing by zero.
            if cell_area > 0:
                for spanning_cell in spanning_cells:
                    spanning_cell_rect = Rect(list(spanning_cell['bbox']))
                    covered = spanning_cell_rect.intersect(cell_rect).get_area()
                    if covered / cell_area > 0.5:
                        cell['subcell'] = True
                        break

            if cell['subcell']:
                subcells.append(cell)
            else:
                cell['projected row header'] = False
                cells.append(cell)

    # Merge the subcells covered by each spanning cell into one cell
    for spanning_cell in spanning_cells:
        spanning_cell_rect = Rect(list(spanning_cell['bbox']))
        cell_columns = set()
        cell_rows = set()
        cell_rect = None
        header = True
        for subcell in subcells:
            subcell_rect = Rect(list(subcell['bbox']))
            subcell_rect_area = subcell_rect.get_area()
            if (
                subcell_rect.intersect(spanning_cell_rect).get_area()
                / subcell_rect_area
            ) > 0.5:
                if cell_rect is None:
                    cell_rect = Rect(list(subcell['bbox']))
                else:
                    cell_rect.include_rect(Rect(list(subcell['bbox'])))
                cell_rows = cell_rows.union(set(subcell['row_nums']))
                cell_columns = cell_columns.union(set(subcell['column_nums']))
                # By convention here, all subcells must be classified
                # as header cells for a spanning cell to be classified as a header cell;
                # otherwise, this could lead to a non-rectangular header region
                header = (
                    header and 'column header' in subcell and subcell['column header']
                )
        if len(cell_rows) > 0 and len(cell_columns) > 0:
            cell = {
                'bbox': list(cell_rect),
                'column_nums': list(cell_columns),
                'row_nums': list(cell_rows),
                'column header': header,
                'projected row header': spanning_cell['projected row header'],
            }
            cells.append(cell)

    # Compute a confidence score based on how well the page tokens
    # slot into the cells reported by the model
    _, _, cell_match_scores = postprocess.slot_into_containers(cells, tokens)
    try:
        mean_match_score = sum(cell_match_scores) / len(cell_match_scores)
        min_match_score = min(cell_match_scores)
        confidence_score = (mean_match_score + min_match_score) / 2
    except (ZeroDivisionError, ValueError, TypeError):
        # No cells (or unusable scores): report zero confidence
        confidence_score = 0

    # Recompute every cell box as the hull of its columns intersected with
    # the hull of its rows before the final text extraction
    for cell in cells:
        column_rect = Rect()
        for column_num in cell['column_nums']:
            column_rect.include_rect(list(columns[column_num]['bbox']))
        row_rect = Rect()
        for row_num in cell['row_nums']:
            row_rect.include_rect(list(rows[row_num]['bbox']))
        cell_rect = column_rect.intersect(row_rect)
        cell['bbox'] = list(cell_rect)

    span_nums_by_cell, _, _ = postprocess.slot_into_containers(
        cells,
        tokens,
        overlap_threshold=0.001,
        unique_assignment=True,
        forced_assignment=False,
    )

    for cell, cell_span_nums in zip(cells, span_nums_by_cell):
        cell_spans = [tokens[num] for num in cell_span_nums]
        # TODO: Refine how text is extracted; should be character-based, not span-based;
        # but need to associate
        cell['cell text'] = postprocess.extract_text_from_spans(
            cell_spans, remove_integer_superscripts=False
        )
        cell['spans'] = cell_spans

    # Adjust the row, column, and cell bounding boxes to reflect the extracted text
    num_rows = len(rows)
    rows = postprocess.sort_objects_top_to_bottom(rows)
    num_columns = len(columns)
    columns = postprocess.sort_objects_left_to_right(columns)
    min_y_values_by_row = defaultdict(list)
    max_y_values_by_row = defaultdict(list)
    min_x_values_by_column = defaultdict(list)
    max_x_values_by_column = defaultdict(list)
    for cell in cells:
        min_row = min(cell["row_nums"])
        max_row = max(cell["row_nums"])
        min_column = min(cell["column_nums"])
        max_column = max(cell["column_nums"])
        for span in cell['spans']:
            min_x_values_by_column[min_column].append(span['bbox'][0])
            min_y_values_by_row[min_row].append(span['bbox'][1])
            max_x_values_by_column[max_column].append(span['bbox'][2])
            max_y_values_by_row[max_row].append(span['bbox'][3])

    # Rows: horizontal extent from the first/last column, vertical from own text
    for row_num, row in enumerate(rows):
        if len(min_x_values_by_column[0]) > 0:
            row['bbox'][0] = min(min_x_values_by_column[0])
        if len(min_y_values_by_row[row_num]) > 0:
            row['bbox'][1] = min(min_y_values_by_row[row_num])
        if len(max_x_values_by_column[num_columns - 1]) > 0:
            row['bbox'][2] = max(max_x_values_by_column[num_columns - 1])
        if len(max_y_values_by_row[row_num]) > 0:
            row['bbox'][3] = max(max_y_values_by_row[row_num])

    # Columns: vertical extent from the first/last row, horizontal from own text
    for column_num, column in enumerate(columns):
        if len(min_x_values_by_column[column_num]) > 0:
            column['bbox'][0] = min(min_x_values_by_column[column_num])
        if len(min_y_values_by_row[0]) > 0:
            column['bbox'][1] = min(min_y_values_by_row[0])
        if len(max_x_values_by_column[column_num]) > 0:
            column['bbox'][2] = max(max_x_values_by_column[column_num])
        if len(max_y_values_by_row[num_rows - 1]) > 0:
            column['bbox'][3] = max(max_y_values_by_row[num_rows - 1])

    for cell in cells:
        row_rect = Rect()
        column_rect = Rect()
        for row_num in cell['row_nums']:
            row_rect.include_rect(list(rows[row_num]['bbox']))
        for column_num in cell['column_nums']:
            column_rect.include_rect(list(columns[column_num]['bbox']))
        cell_rect = row_rect.intersect(column_rect)
        # Keep the previous box when the tightened rows/columns no longer intersect
        if cell_rect.get_area() > 0:
            cell['bbox'] = list(cell_rect)

    return cells, confidence_score