Source code for churro_ocr.providers.specs

"""Public OCR provider specs, options, and model profile resolution."""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Literal

from PIL import Image

from churro_ocr._internal.image import prepare_ocr_image
from churro_ocr.prompts import DEFAULT_OCR_OUTPUT_TAG, strip_ocr_output_tag
from churro_ocr.templates import (
    CHURRO_3B_MODEL_ID,
    CHURRO_3B_XML_TEMPLATE,
    DEFAULT_OCR_TEMPLATE,
    DOTS_OCR_1_5_MODEL_ID,
    DOTS_OCR_1_5_OCR_TEMPLATE,
    OCRConversation,
    OCRPromptTemplateLike,
)

if TYPE_CHECKING:
    pass


OCRProvider = Literal["litellm", "openai-compatible", "azure", "mistral", "hf", "vllm"]
ImagePreprocessor = Callable[[Image.Image], Image.Image]
TextPostprocessor = Callable[[str], str]
VisionInputBuilder = Callable[[OCRConversation], object]
DEFAULT_OCR_MAX_TOKENS = 20_000


[docs] def identity_text_postprocessor(text: str) -> str: """Return OCR text unchanged. :param text: OCR text to return. :returns: The original ``text`` value. """ return text
[docs] def default_ocr_image_preprocessor(image: Image.Image) -> Image.Image: """Apply the default OCR image preprocessing. :param image: Source page image. :returns: Preprocessed image ready for OCR. """ return prepare_ocr_image(image)
[docs] def default_ocr_text_postprocessor(text: str) -> str: """Strip the default OCR output tag wrapper. :param text: Raw OCR response text. :returns: OCR text with the default wrapper removed when present. """ return strip_ocr_output_tag(text, output_tag=DEFAULT_OCR_OUTPUT_TAG)
[docs] @dataclass(slots=True, frozen=True) class LiteLLMTransportConfig: """Shared transport config for LiteLLM-based multimodal requests. :param api_base: Optional API base URL override. :param api_key: Optional API key forwarded to LiteLLM. :param api_version: Optional API version string for providers that need one. :param image_detail: Optional image-detail hint supported by some providers. :param completion_kwargs: Extra completion kwargs merged into LiteLLM calls. :param cache_dir: Optional disk-cache directory for LiteLLM request caching. """ api_base: str | None = None api_key: str | None = None api_version: str | None = None image_detail: str | None = None completion_kwargs: dict[str, object] = field(default_factory=dict) cache_dir: str | Path | None = None
[docs] @dataclass(slots=True, frozen=True) class OpenAICompatibleOptions: """Provider options for OpenAI-compatible OCR servers. :param model_prefix: Provider prefix prepended to the configured model name. """ model_prefix: str | None = None
[docs] @dataclass(slots=True, frozen=True) class HuggingFaceOptions: """Provider options for local Hugging Face OCR backends. :param trust_remote_code: Whether to allow remote model code execution. :param processor_kwargs: Extra kwargs passed to ``AutoProcessor.from_pretrained``. :param model_kwargs: Extra kwargs passed to model loading. :param generation_kwargs: Extra generation kwargs passed at inference time. :param vision_input_builder: Optional override for building multimodal inputs. :param backend_variant: Optional implementation preset such as ``"dots-ocr-1.5"``. """ trust_remote_code: bool | None = None processor_kwargs: dict[str, object] = field(default_factory=dict) model_kwargs: dict[str, object] = field(default_factory=dict) generation_kwargs: dict[str, object] = field(default_factory=dict) vision_input_builder: VisionInputBuilder | None = None backend_variant: str | None = None
[docs] @dataclass(slots=True, frozen=True) class VLLMOptions: """Provider options for local vLLM OCR backends. :param trust_remote_code: Whether to allow remote model code execution. :param processor_kwargs: Extra kwargs passed to ``AutoProcessor.from_pretrained``. :param llm_kwargs: Extra kwargs passed to the vLLM ``LLM`` constructor. :param sampling_kwargs: Extra kwargs passed to vLLM sampling params. :param limit_mm_per_prompt: Per-request multimodal limits passed to vLLM. """ trust_remote_code: bool | None = None processor_kwargs: dict[str, object] = field(default_factory=dict) llm_kwargs: dict[str, object] = field(default_factory=dict) sampling_kwargs: dict[str, object] = field(default_factory=dict) limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
[docs] @dataclass(slots=True, frozen=True) class AzureDocumentIntelligenceOptions: """Provider options for Azure Document Intelligence OCR. :param endpoint: Azure Document Intelligence endpoint URL. :param api_key: Azure API key for the configured resource. """ endpoint: str | None = None api_key: str | None = None
[docs] @dataclass(slots=True, frozen=True) class MistralOptions: """Provider options for Mistral OCR. :param api_key: Mistral API key used for OCR requests. """ api_key: str | None = None
OCRProviderOptions = ( OpenAICompatibleOptions | HuggingFaceOptions | VLLMOptions | AzureDocumentIntelligenceOptions | MistralOptions )
[docs] @dataclass(slots=True, frozen=True) class OCRModelProfile: """Model-level OCR behavior shared across provider adapters. :param profile_name: Stable profile identifier. :param template: Prompt template used to render OCR input. :param image_preprocessor: Image preprocessor applied before OCR. :param text_postprocessor: Text postprocessor applied after OCR. :param display_name: Optional human-readable model name. :param transport: Default LiteLLM transport settings for this profile. :param huggingface: Default Hugging Face backend options for this profile. :param vllm: Default vLLM backend options for this profile. """ profile_name: str template: OCRPromptTemplateLike = DEFAULT_OCR_TEMPLATE image_preprocessor: ImagePreprocessor = default_ocr_image_preprocessor text_postprocessor: TextPostprocessor = default_ocr_text_postprocessor display_name: str | None = None transport: LiteLLMTransportConfig = field(default_factory=LiteLLMTransportConfig) huggingface: HuggingFaceOptions = field(default_factory=HuggingFaceOptions) vllm: VLLMOptions = field(default_factory=VLLMOptions)
[docs] @dataclass(slots=True, frozen=True) class OCRBackendSpec: """Declarative builder input for OCR backends. :param provider: OCR provider identifier. :param model: Provider-specific model identifier. :param profile: Optional built-in or custom model profile. :param transport: Optional transport settings for LiteLLM-based providers. :param options: Optional provider-specific options dataclass. """ provider: OCRProvider model: str | None = None profile: str | OCRModelProfile | None = None transport: LiteLLMTransportConfig | None = None options: OCRProviderOptions | None = None
[docs] def default_ocr_profile() -> OCRModelProfile: """Return the generic OCR model profile. :returns: Baseline profile used when no more specific profile matches. """ return OCRModelProfile(profile_name="default")
def churro_3b_profile() -> OCRModelProfile: """Return the built-in ``stanford-oval/churro-3B`` OCR profile. :returns: Profile configured for the built-in CHURRO 3B template. """ return OCRModelProfile( profile_name=CHURRO_3B_MODEL_ID, template=CHURRO_3B_XML_TEMPLATE, text_postprocessor=identity_text_postprocessor, display_name="churro-3B", ) def dots_ocr_1_5_profile() -> OCRModelProfile: """Return the built-in ``kristaller486/dots.ocr-1.5`` OCR profile. :returns: Profile configured for the built-in Dots OCR 1.5 template. """ return OCRModelProfile( profile_name=DOTS_OCR_1_5_MODEL_ID, template=DOTS_OCR_1_5_OCR_TEMPLATE, text_postprocessor=identity_text_postprocessor, display_name="dots.ocr-1.5", huggingface=HuggingFaceOptions( trust_remote_code=True, backend_variant="dots-ocr-1.5", ), vllm=VLLMOptions( trust_remote_code=True, ), ) def _profile_registry() -> dict[str, OCRModelProfile]: default_profile = default_ocr_profile() churro_profile = churro_3b_profile() dots_profile = dots_ocr_1_5_profile() return { default_profile.profile_name: default_profile, churro_profile.profile_name: churro_profile, dots_profile.profile_name: dots_profile, }
[docs] def resolve_ocr_profile( *, model_id: str | None, profile: str | OCRModelProfile | None = None, ) -> OCRModelProfile: """Resolve the OCR model profile for a model or explicit profile. :param model_id: Model identifier that may map to a built-in profile. :param profile: Explicit profile name or profile object to use. :returns: Resolved OCR model profile. :raises ValueError: If ``profile`` is a string that does not match a known profile. """ if isinstance(profile, OCRModelProfile): return profile registry = _profile_registry() if isinstance(profile, str): try: return registry[profile] except KeyError as exc: raise ValueError(f"Unknown OCR profile '{profile}'.") from exc if model_id is not None and model_id in registry: return registry[model_id] return registry["default"]
__all__ = [ "AzureDocumentIntelligenceOptions", "DEFAULT_OCR_MAX_TOKENS", "default_ocr_image_preprocessor", "default_ocr_profile", "default_ocr_text_postprocessor", "HuggingFaceOptions", "identity_text_postprocessor", "ImagePreprocessor", "LiteLLMTransportConfig", "MistralOptions", "OCRBackendSpec", "OCRModelProfile", "OCRProvider", "OpenAICompatibleOptions", "resolve_ocr_profile", "TextPostprocessor", "VisionInputBuilder", "VLLMOptions", ]