LatexOCR
Classes
LatexOCR
Recognize math formula images and convert them to LaTeX in the simplest way.
Methods
__init__(self, *, model_name='mfr', model_backend='onnx', device=None, context=None, model_dir=None, root='/home/docs/.pix2text', more_processor_configs=None, more_model_configs=None, **kwargs)
special
Initialize a LatexOCR model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_name |
str |
The name of the model. Defaults to 'mfr'. |
'mfr' |
model_backend |
str |
The model backend, either 'onnx' or 'pytorch'. Defaults to 'onnx'. |
'onnx' |
device |
str |
What device to use for computation; supports `['cpu', 'cuda', 'gpu']`. Defaults to None, which selects the device automatically. |
None |
context |
str |
Deprecated; use `device` instead. |
None |
model_dir |
Optional[Union[str, Path]] |
The model file directory. Defaults to None. |
None |
root |
Union[str, Path] |
The model root directory. Defaults to data_dir(). |
'/home/docs/.pix2text' |
more_processor_configs |
Optional[Dict[str, Any]] |
Additional processor configurations. Defaults to None. |
None |
more_model_configs |
Optional[Dict[str, Any]] |
Additional model configurations (for the ONNX backend, e.g. `provider`, `session_options`, `provider_options`). Defaults to None. |
None |
**kwargs |
Additional arguments, currently not used. |
{} |
Source code in pix2text/latex_ocr.py
def __init__(
    self,
    *,
    model_name: str = 'mfr',
    model_backend: str = 'onnx',
    device: Optional[str] = None,
    context: Optional[str] = None,  # deprecated, use `device` instead
    model_dir: Optional[Union[str, Path]] = None,
    root: Union[str, Path] = data_dir(),
    more_processor_configs: Optional[Dict[str, Any]] = None,
    more_model_configs: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """Initialize a LatexOCR model.

    Args:
        model_name (str, optional): The name of the model. Defaults to 'mfr'.
        model_backend (str, optional): The model backend, either 'onnx' or 'pytorch'. Defaults to 'onnx'.
        device (str, optional): What device to use for computation, supports `['cpu', 'cuda', 'gpu']`;
            defaults to None, which selects the device automatically.
        context (str, optional): Deprecated, use `device` instead. Only used when `device` is None.
        model_dir (Optional[Union[str, Path]], optional): The model file directory. Defaults to None,
            in which case the model files are prepared under `root`.
        root (Union[str, Path], optional): The model root directory. Defaults to `data_dir()`;
            note that this default is evaluated once, when the function is defined.
        more_processor_configs (Optional[Dict[str, Any]], optional): Additional processor configurations.
            Defaults to None.
        more_model_configs (Optional[Dict[str, Any]], optional): Additional model configurations.
            Defaults to None. The caller's dict is never modified. Supported keys include:
            - provider (`str`, defaults to `None`, which means to select one provider automatically):
                ONNX Runtime provider to use for loading the model. See
                https://onnxruntime.ai/docs/execution-providers/ for possible providers.
                For the ONNX backend, when omitted, the first provider returned by
                `get_default_ort_providers()` is used.
            - session_options (`Optional[onnxruntime.SessionOptions]`, defaults to `None`):
                ONNX Runtime session options to use for loading the model.
            - provider_options (`Optional[Dict[str, Any]]`, defaults to `None`):
                Provider option dictionaries corresponding to the provider used. See available options
                for each provider: https://onnxruntime.ai/docs/api/c/group___global.html .
            - ...: see more information here: optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained()
        **kwargs: Additional arguments, currently not used.

    Raises:
        RuntimeError: If the ONNX backend is selected, no `provider` is given,
            and no ONNX Runtime provider is available.
    """
    if context is not None:
        logger.warning('`context` is deprecated, please use `device` instead')
    if device is None and context is not None:
        device = context
    self.device = select_device(device)

    model_info = AVAILABLE_MODELS.get_info(model_name, model_backend)
    if model_dir is None:
        model_dir = self._prepare_model_files(root, model_backend, model_info)
    logger.info(f'Use model dir for LatexOCR: {model_dir}')

    # Work on a shallow copy: filling in a default provider below must not
    # leak into the dict the caller passed (it may be reused elsewhere).
    more_model_configs = dict(more_model_configs or {})
    if model_backend == 'onnx' and 'provider' not in more_model_configs:
        available_providers = get_default_ort_providers()
        if not available_providers:
            raise RuntimeError(
                'No available providers for ONNX Runtime, please install onnxruntime-gpu or onnxruntime.'
            )
        # NOTE(review): the first available provider is taken regardless of
        # `self.device`; confirm this is intended when e.g. device='cpu'.
        more_model_configs['provider'] = available_providers[0]

    self.model, self.processor = self._init_model(
        model_backend,
        model_dir,
        more_processor_config=more_processor_configs,
        more_model_config=more_model_configs,
    )
    logger.info(
        f'Loaded Pix2Text MFR model {model_name} to: backend-{model_backend}, device-{self.device}'
    )
recognize(self, imgs, batch_size=1, use_post_process=True, rec_config=None, **kwargs)
Recognize math formula images and convert them to LaTeX expressions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
imgs |
Union[str, Path, Image.Image, List[str], List[Path], List[Image.Image]] |
The image or list of images |
required |
batch_size |
int |
The batch size |
1 |
use_post_process |
bool |
Whether to use post process. Defaults to True |
True |
rec_config |
Optional[dict] |
The generation config |
None |
**kwargs |
Other keyword arguments, forwarded to the per-batch inference call. |
{} |
Source code in pix2text/latex_ocr.py
def recognize(
    self,
    imgs: Union[str, Path, Image.Image, List[str], List[Path], List[Image.Image]],
    batch_size: int = 1,
    use_post_process: bool = True,
    rec_config: Optional[dict] = None,
    **kwargs,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Recognize math formula images and convert them to LaTeX expressions.

    Args:
        imgs (Union[str, Path, Image.Image, List[str], List[Path], List[Image.Image]]): The image or list of images.
        batch_size (int): Number of images passed to the model per inference call; must be >= 1.
        use_post_process (bool): Whether to run `self.post_process` on each recognized text. Defaults to True.
        rec_config (Optional[dict]): The generation config, passed to each batch inference call.
        **kwargs: Other keyword arguments, forwarded unchanged to `self._one_batch`.

    Returns:
        The LaTeX result for a single input image, or a list of LaTeX results for a list of images;
        each result is a dict with `text` and `score` fields.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """
    # A non-positive step would make `range` raise (0) or silently yield
    # nothing (negative), returning no results at all.
    if batch_size < 1:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    is_single_image = isinstance(imgs, (str, Path, Image.Image))
    if is_single_image:
        imgs = [imgs]
    input_imgs = prepare_imgs(imgs)

    # Inference batch by batch.
    results = []
    for start in tqdm.tqdm(range(0, len(input_imgs), batch_size)):
        batch = input_imgs[start : start + batch_size]
        results.extend(self._one_batch(batch, rec_config, **kwargs))

    if use_post_process:
        for info in results:
            info['text'] = self.post_process(info['text'])

    return results[0] if is_single_image else results
Functions
fix_latex(latex)
Further post-process and correct the recognition result (对识别结果做进一步处理和修正).
Source code in pix2text/latex_ocr.py
def _blank_unmatched_commands(latex: str, bracket_infos: list, keyword: str) -> str:
    r"""Overwrite the ``\<keyword>`` command of every unmatched bracket with spaces.

    For each info without a truthy ``'matched'`` flag, ``len(keyword) + 1``
    characters starting at its ``'start'`` offset are replaced by the same
    number of spaces. This presumably covers the backslash plus the keyword,
    so the delimiter that follows is kept (assumes ``'start'`` points at the
    backslash — confirm against ``find_all_left_or_right``). Because the
    replacement has the same length, the other infos' offsets stay valid.
    """
    cmd_len = len(keyword) + 1
    for info in bracket_infos:
        if info.get('matched', False):
            continue
        start = info['start']
        end = start + cmd_len
        latex = latex[:start] + ' ' * (end - start) + latex[end:]
    return latex


def fix_latex(latex: str) -> str:
    r"""Further post-process and correct a recognized LaTeX string.

    Each ``\left`` command is paired with a compatible ``\right`` command that
    appears after it. The command word of any ``\left`` / ``\right`` left
    unpaired is blanked out, runs of whitespace are collapsed to a single
    space, and the result is stripped.
    """
    left_bracket_infos = find_all_left_or_right(latex, left_or_right='left')
    right_bracket_infos = find_all_left_or_right(latex, left_or_right='right')

    # Each \left pairs with the first not-yet-matched \right (in the order
    # returned by find_all_left_or_right) that starts after it and whose
    # delimiter is compatible with it.
    for left_info in left_bracket_infos:
        for right_info in right_bracket_infos:
            if (
                not right_info.get('matched', False)
                and right_info['start'] > left_info['start']
                # match_left_right expects (left, right) in that order.
                and match_left_right(left_info['str'], right_info['str'])
            ):
                left_info['matched'] = True
                right_info['matched'] = True
                break

    latex = _blank_unmatched_commands(latex, left_bracket_infos, 'left')
    latex = _blank_unmatched_commands(latex, right_bracket_infos, 'right')

    # Collapse consecutive whitespace (including the blanks inserted above).
    latex = re.sub(r'\s+', ' ', latex)
    return latex.strip()
match_left_right(left_str, right_str)
Check whether a left and a right delimiter match, e.g. `\left(` and `\right)`.
Source code in pix2text/latex_ocr.py
# Delimiter pairs accepted by ``match_left_right`` once the ``\left`` /
# ``\right`` keywords and any shared leading characters have been removed
# (e.g. ``\langle``/``\rangle`` -> ``langle``/``rangle``, ``\{``/``\}`` -> ``{``/``}``).
_LEFT_RIGHT_PAIRS = frozenset(
    [
        ('', ''),  # identical delimiters, e.g. \left| ... \right|
        ('(', ')'),
        (r'\{', '.'),  # brace constructs such as cases: \left\{ ... \right.
        ('⟮', '⟯'),
        ('[', ']'),
        ('⟨', '⟩'),
        ('{', '}'),
        ('⌈', '⌉'),
        ('┌', '┐'),
        ('⌊', '⌋'),
        ('└', '┘'),
        ('⎰', '⎱'),
        ('lt', 'gt'),
        ('lang', 'rang'),
        ('langle', 'rangle'),
        ('lbrace', 'rbrace'),
        ('lBrace', 'rBrace'),
        ('lbrack', 'rbrack'),
        ('lbracket', 'rbracket'),
        ('lceil', 'rceil'),
        ('lcorner', 'rcorner'),
        ('lfloor', 'rfloor'),
        ('lgroup', 'rgroup'),
        ('lmoustache', 'rmoustache'),
        ('lparen', 'rparen'),
        ('lvert', 'rvert'),
        ('lVert', 'rVert'),
    ]
)


def match_left_right(left_str, right_str):
    r"""Return whether a left and a right delimiter match, e.g. `\left(` and `\right)`.

    Args:
        left_str: A ``\left`` command with its delimiter, e.g. ``\left(``.
        right_str: A ``\right`` command with its delimiter, e.g. ``\right)``.

    Returns:
        True if the two delimiters form an accepted pair. ``.`` (LaTeX's null
        delimiter) is accepted against any non-empty delimiter.
    """
    # Drop surrounding/inner spaces and the leading "\left" / "\right".
    left_delim = left_str.strip().replace(' ', '')[len('left') + 1 :]
    right_delim = right_str.strip().replace(' ', '')[len('right') + 1 :]

    # `\left.` / `\right.` are invisible and may pair with any delimiter.
    if left_delim and right_delim and '.' in (left_delim, right_delim):
        return True

    # Remove the common leading part so e.g. "\lvert"/"\rvert" -> "lvert"/"rvert"
    # and identical delimiters collapse to ("", "").
    while left_delim and right_delim and left_delim[0] == right_delim[0]:
        left_delim = left_delim[1:]
        right_delim = right_delim[1:]

    return (left_delim, right_delim) in _LEFT_RIGHT_PAIRS