Skip to content

LatexOCR

Classes

LatexOCR

Get a prediction of a math formula image in the easiest way

Methods

__init__(self, *, model_name='mfr', model_backend='onnx', device=None, context=None, model_dir=None, root='/home/docs/.pix2text', more_processor_configs=None, more_model_configs=None, **kwargs) special

Initialize a LatexOCR model.

Parameters:

Name Type Description Default
model_name str

The name of the model. Defaults to 'mfr'.

'mfr'
model_backend str

The model backend, either 'onnx' or 'pytorch'. Defaults to 'onnx'.

'onnx'
device str

What device to use for computation, supports ['cpu', 'cuda', 'gpu']; defaults to None, which selects the device automatically.

None
context str

Deprecated, use device instead. What device to use for computation, supports ['cpu', 'cuda', 'gpu']; defaults to None, which selects the device automatically.

None
model_dir Optional[Union[str, Path]]

The model file directory. Defaults to None.

None
root Union[str, Path]

The model root directory. Defaults to data_dir().

'/home/docs/.pix2text'
more_processor_configs Optional[Dict[str, Any]]

Additional processor configurations. Defaults to None.

None
more_model_configs Optional[Dict[str, Any]]

Additional model configurations. Defaults to None.

  • provider (str, defaults to None, which means to select one provider automatically): ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ for possible providers.
  • session_options (Optional[onnxruntime.SessionOptions], defaults to None): ONNX Runtime session options to use for loading the model.
  • provider_options (Optional[Dict[str, Any]], defaults to None): Provider option dictionaries corresponding to the provider used. See available options for each provider: https://onnxruntime.ai/docs/api/c/group___global.html .
  • ...: see more information here: optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained()
None
**kwargs

Additional arguments, currently not used.

{}
Source code in pix2text/latex_ocr.py
def __init__(
    self,
    *,
    model_name: str = 'mfr',
    model_backend: str = 'onnx',
    device: Optional[str] = None,
    context: Optional[str] = None,  # deprecated, use `device` instead
    model_dir: Optional[Union[str, Path]] = None,
    root: Union[str, Path] = data_dir(),
    more_processor_configs: Optional[Dict[str, Any]] = None,
    more_model_configs: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """Initialize a LatexOCR model.

    Args:
        model_name (str, optional): The name of the model. Defaults to 'mfr'.
        model_backend (str, optional): The model backend, either 'onnx' or 'pytorch'. Defaults to 'onnx'.
        device (str, optional): What device to use for computation, supports `['cpu', 'cuda', 'gpu']`;
            defaults to None, which selects the device automatically.
        context (str, optional): Deprecated, use `device` instead. Only consulted when `device` is None;
            passing it always logs a deprecation warning.
        model_dir (Optional[Union[str, Path]], optional): The model file directory. Defaults to None,
            in which case the directory is obtained from `self._prepare_model_files(root, model_backend, model_info)`.
        root (Union[str, Path], optional): The model root directory. Defaults to data_dir().
        more_processor_configs (Optional[Dict[str, Any]], optional): Additional processor configurations. Defaults to None.
        more_model_configs (Optional[Dict[str, Any]], optional): Additional model configurations. Defaults to None.
            The dict passed in is never modified; a shallow copy is used internally.

            - provider (`str`, defaults to `None`, which means to select one provider automatically):
                ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ for
                possible providers. When omitted with the 'onnx' backend, the first entry returned by
                `get_default_ort_providers()` is used.
            - session_options (`Optional[onnxruntime.SessionOptions]`, defaults to `None`):
                ONNX Runtime session options to use for loading the model.
            - provider_options (`Optional[Dict[str, Any]]`, defaults to `None`):
                Provider option dictionaries corresponding to the provider used. See available options
                for each provider: https://onnxruntime.ai/docs/api/c/group___global.html .
            - ...: see more information here: optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained()
        **kwargs: Additional arguments, currently not used.

    Raises:
        RuntimeError: If the backend is 'onnx', no 'provider' was supplied in `more_model_configs`,
            and `get_default_ort_providers()` returns no provider.
    """

    if context is not None:
        logger.warning('`context` is deprecated, please use `device` instead')
        # An explicit `device` always wins over the deprecated `context`.
        if device is None:
            device = context
    self.device = select_device(device)

    model_info = AVAILABLE_MODELS.get_info(model_name, model_backend)

    if model_dir is None:
        model_dir = self._prepare_model_files(root, model_backend, model_info)
    logger.info(f'Use model dir for LatexOCR: {model_dir}')

    # Work on a shallow copy: a default 'provider' may be injected below and the
    # caller's dict must not be modified as a side effect of constructing the model.
    more_model_configs = dict(more_model_configs or {})
    if model_backend == 'onnx' and 'provider' not in more_model_configs:
        available_providers = get_default_ort_providers()
        if not available_providers:
            raise RuntimeError(
                'No available providers for ONNX Runtime, please install onnxruntime-gpu or onnxruntime.'
            )
        more_model_configs['provider'] = available_providers[0]
    self.model, self.processor = self._init_model(
        model_backend,
        model_dir,
        more_processor_config=more_processor_configs,
        more_model_config=more_model_configs,
    )
    logger.info(
        f'Loaded Pix2Text MFR model {model_name} to: backend-{model_backend}, device-{self.device}'
    )
recognize(self, imgs, batch_size=1, use_post_process=True, rec_config=None, **kwargs)

Recognize Math Formula images to LaTeX Expressions

Parameters:

Name Type Description Default
imgs Union[str, Path, Image.Image, List[str], List[Path], List[Image.Image]]

The image or list of images

required
batch_size int

The batch size

1
use_post_process bool

Whether to use post process. Defaults to True

True
rec_config Optional[dict]

The generation config

None
**kwargs

Other arguments. Not used for now

{}
Source code in pix2text/latex_ocr.py
def recognize(
    self,
    imgs: Union[str, Path, Image.Image, List[str], List[Path], List[Image.Image]],
    batch_size: int = 1,
    use_post_process: bool = True,
    rec_config: Optional[dict] = None,
    **kwargs,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """Recognize math formula images as LaTeX expressions.

    Args:
        imgs: A single image (path string, `Path` or `Image.Image`) or a list of them.
        batch_size: Number of prepared images handed to `self._one_batch` per call.
        use_post_process: Whether to run `self.post_process` on each result's
            `text` field. Defaults to True.
        rec_config: The generation config, forwarded to `self._one_batch`.
        **kwargs: Extra keyword arguments, forwarded to `self._one_batch`.

    Returns:
        One result when a single image was passed, otherwise a list of results.
        Each result is a dict with `text` and `score` fields.
    """
    # A lone image is wrapped into a list here and unwrapped again on return.
    single_input = isinstance(imgs, (str, Path, Image.Image))
    if single_input:
        imgs = [imgs]

    prepared = prepare_imgs(imgs)

    # Run inference one batch at a time, with a progress bar over batch offsets.
    outputs = []
    for offset in tqdm.tqdm(range(0, len(prepared), batch_size)):
        batch = prepared[offset : offset + batch_size]
        outputs += self._one_batch(batch, rec_config, **kwargs)

    if use_post_process:
        for item in outputs:
            item['text'] = self.post_process(item['text'])

    return outputs[0] if single_input else outputs

Functions

fix_latex(latex)

对识别结果做进一步处理和修正。

Source code in pix2text/latex_ocr.py
def fix_latex(latex: str) -> str:
    r"""Post-process a recognized LaTeX string.

    Every `\left` / `\right` occurrence reported by `find_all_left_or_right`
    that cannot be paired is blanked out, then runs of whitespace are collapsed
    to a single space and the result is stripped.

    Args:
        latex: The recognized LaTeX text.

    Returns:
        The cleaned-up LaTeX text.
    """
    lefts = find_all_left_or_right(latex, left_or_right='left')
    rights = find_all_left_or_right(latex, left_or_right='right')

    # Pair each `\left` with the first not-yet-matched `\right` (in list order)
    # that starts after it and is accepted by `match_left_right`.
    # NOTE(review): the right-hand string is passed as the first argument and the
    # left-hand string as the second, which is the reverse of the parameter names
    # of `match_left_right` -- confirm against what `find_all_left_or_right`
    # stores under 'str' before changing.
    for left in lefts:
        for right in rights:
            if right.get('matched', False):
                continue
            if not (right['start'] > left['start']):
                continue
            if not match_left_right(right['str'], left['str']):
                continue
            left['matched'] = True
            right['matched'] = True
            break

    # Overwrite each unmatched keyword (backslash included) with the same number
    # of spaces, so the 'start' offsets recorded for the other entries stay valid.
    for infos, span_len in (
        (lefts, len('left') + 1),
        (rights, len('right') + 1),
    ):
        for info in infos:
            if info.get('matched', False):
                continue
            begin = info['start']
            stop = begin + span_len
            latex = latex[:begin] + ' ' * span_len + latex[stop:]

    # Collapse consecutive whitespace into a single space.
    collapsed = re.sub(r'\s+', ' ', latex)
    return collapsed.strip()

match_left_right(left_str, right_str)

匹配左右括号,如匹配 `\left(` 和 `\right)`。

Source code in pix2text/latex_ocr.py
# Delimiter pairs accepted by `match_left_right`, expressed as what remains of
# each side once the `\left` / `\right` keyword and any shared leading
# characters (such as a common backslash) have been removed.
_LEFT_RIGHT_PAIRS = frozenset(
    {
        ('', ''),
        ('(', ')'),
        (r'\{', '.'),  # big-brace case: `\left\{` closed by `\right.`
        ('⟮', '⟯'),
        ('[', ']'),
        ('⟨', '⟩'),
        ('{', '}'),
        ('⌈', '⌉'),
        ('┌', '┐'),
        ('⌊', '⌋'),
        ('└', '┘'),
        ('⎰', '⎱'),
        ('lt', 'gt'),
        ('lang', 'rang'),
        ('langle', 'rangle'),
        ('lbrace', 'rbrace'),
        ('lBrace', 'rBrace'),
        ('lbracket', 'rbracket'),
        ('lceil', 'rceil'),
        ('lcorner', 'rcorner'),
        ('lfloor', 'rfloor'),
        ('lgroup', 'rgroup'),
        ('lmoustache', 'rmoustache'),
        ('lparen', 'rparen'),
        ('lvert', 'rvert'),
        ('lVert', 'rVert'),
    }
)


def match_left_right(left_str, right_str):
    r"""Check whether a left and a right delimiter belong together, e.g. `\left(` and `\right)`.

    Both inputs are stripped and have their spaces removed. The first
    `len('left') + 1` characters of `left_str` and the first `len('right') + 1`
    characters of `right_str` are then dropped, leading characters common to
    both remainders are removed, and the resulting pair is looked up in the
    table of known delimiter pairs.

    Args:
        left_str: The left delimiter text, e.g. `\left(`.
        right_str: The right delimiter text, e.g. `\right)`.

    Returns:
        True if the remaining (left, right) pair is a known delimiter pair,
        otherwise False.
    """
    left_str = left_str.strip().replace(' ', '')[len('left') + 1 :]
    right_str = right_str.strip().replace(' ', '')[len('right') + 1 :]
    # Drop the shared leading part (e.g. the backslash of `\langle` / `\rangle`).
    while left_str and right_str and left_str[0] == right_str[0]:
        left_str = left_str[1:]
        right_str = right_str[1:]

    return (left_str, right_str) in _LEFT_RIGHT_PAIRS