# Source code for churro_ocr.providers.builder

"""Public OCR backend builder."""

from __future__ import annotations

from churro_ocr._internal.litellm import LiteLLMTransport
from churro_ocr.errors import ConfigurationError
from churro_ocr.ocr import OCRBackend
from churro_ocr.providers.hf import (
    DotsOCR15OCRBackend,
    HuggingFaceVisionOCRBackend,
    _default_dots_ocr_1_5_model_kwargs,
)
from churro_ocr.providers.ocr import (
    AzureDocumentIntelligenceOCRBackend,
    LiteLLMVisionOCRBackend,
    MistralOCRBackend,
    OpenAICompatibleOCRBackend,
)
from churro_ocr.providers.specs import (
    AzureDocumentIntelligenceOptions,
    HuggingFaceOptions,
    LiteLLMTransportConfig,
    MistralOptions,
    OCRBackendSpec,
    OCRModelProfile,
    OpenAICompatibleOptions,
    VLLMOptions,
    resolve_ocr_profile,
)
from churro_ocr.providers.vllm import VLLMVisionOCRBackend


def _merge_mapping(
    base: dict[str, object],
    override: dict[str, object],
) -> dict[str, object]:
    merged = dict(base)
    for key, value in override.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            merged[key] = {**existing, **value}
            continue
        merged[key] = value
    return merged


def _merge_transport_config(
    base: LiteLLMTransportConfig,
    override: LiteLLMTransportConfig | None,
) -> LiteLLMTransportConfig:
    """Overlay *override* onto *base*, yielding a fresh transport config.

    ``None`` fields in *override* fall back to the corresponding base
    value; ``completion_kwargs`` are merged one level deep. When there is
    no override, an independent copy of *base* is returned so callers may
    mutate the result safely.
    """
    if override is None:
        return LiteLLMTransportConfig(
            api_base=base.api_base,
            api_key=base.api_key,
            api_version=base.api_version,
            image_detail=base.image_detail,
            completion_kwargs=dict(base.completion_kwargs),
            cache_dir=base.cache_dir,
        )

    def pick(new, old):
        # Treat None as "not set"; anything else wins over the base value.
        return old if new is None else new

    return LiteLLMTransportConfig(
        api_base=pick(override.api_base, base.api_base),
        api_key=pick(override.api_key, base.api_key),
        api_version=pick(override.api_version, base.api_version),
        image_detail=pick(override.image_detail, base.image_detail),
        completion_kwargs=_merge_mapping(base.completion_kwargs, override.completion_kwargs),
        cache_dir=pick(override.cache_dir, base.cache_dir),
    )


def _merge_huggingface_options(
    base: HuggingFaceOptions,
    override: HuggingFaceOptions | None,
) -> HuggingFaceOptions:
    """Combine profile-level Hugging Face options with spec-level overrides.

    kwargs mappings are merged one level deep via :func:`_merge_mapping`;
    scalar fields fall back to the base value when the override leaves them
    unset. Returns a fresh object either way.
    """
    if override is None:
        return HuggingFaceOptions(
            trust_remote_code=base.trust_remote_code,
            processor_kwargs=dict(base.processor_kwargs),
            model_kwargs=dict(base.model_kwargs),
            generation_kwargs=dict(base.generation_kwargs),
            vision_input_builder=base.vision_input_builder,
            backend_variant=base.backend_variant,
        )
    # Only an explicit (non-None) override for trust_remote_code wins.
    trust = base.trust_remote_code if override.trust_remote_code is None else override.trust_remote_code
    return HuggingFaceOptions(
        trust_remote_code=trust,
        processor_kwargs=_merge_mapping(base.processor_kwargs, override.processor_kwargs),
        model_kwargs=_merge_mapping(base.model_kwargs, override.model_kwargs),
        generation_kwargs=_merge_mapping(base.generation_kwargs, override.generation_kwargs),
        vision_input_builder=override.vision_input_builder or base.vision_input_builder,
        backend_variant=override.backend_variant or base.backend_variant,
    )


def _merge_vllm_options(
    base: VLLMOptions,
    override: VLLMOptions | None,
) -> VLLMOptions:
    """Combine profile-level vLLM options with spec-level overrides.

    kwargs mappings are merged one level deep; ``limit_mm_per_prompt`` is a
    plain key-by-key overlay. Returns a fresh object so neither input is
    shared with the result.
    """
    if override is None:
        return VLLMOptions(
            trust_remote_code=base.trust_remote_code,
            processor_kwargs=dict(base.processor_kwargs),
            llm_kwargs=dict(base.llm_kwargs),
            sampling_kwargs=dict(base.sampling_kwargs),
            limit_mm_per_prompt=dict(base.limit_mm_per_prompt),
        )
    # Shallow overlay — values here are scalar limits, not nested dicts.
    mm_limits = dict(base.limit_mm_per_prompt)
    mm_limits.update(override.limit_mm_per_prompt)
    trust = base.trust_remote_code if override.trust_remote_code is None else override.trust_remote_code
    return VLLMOptions(
        trust_remote_code=trust,
        processor_kwargs=_merge_mapping(base.processor_kwargs, override.processor_kwargs),
        llm_kwargs=_merge_mapping(base.llm_kwargs, override.llm_kwargs),
        sampling_kwargs=_merge_mapping(base.sampling_kwargs, override.sampling_kwargs),
        limit_mm_per_prompt=mm_limits,
    )


def _merge_openai_options(
    override: OpenAICompatibleOptions | None,
) -> OpenAICompatibleOptions:
    """Normalize optional OpenAI-compatible options into a concrete value.

    A missing override becomes default options; otherwise only the
    ``model_prefix`` field is carried over into a fresh object.
    """
    if override is not None:
        return OpenAICompatibleOptions(model_prefix=override.model_prefix)
    return OpenAICompatibleOptions()


def _ensure_options_type[T](options: object | None, expected: type[T], *, provider: str) -> T | None:
    if options is None:
        return None
    if not isinstance(options, expected):
        raise ConfigurationError(
            f"OCR provider '{provider}' requires options of type {expected.__name__}, "
            f"got {type(options).__name__}."
        )
    return options


def _resolve_model_name(profile: OCRModelProfile, model: str | None, *, fallback: str) -> str:
    if profile.display_name is not None:
        return profile.display_name
    if model is not None:
        return model
    return fallback


def _build_litellm_backend(spec: OCRBackendSpec, profile: OCRModelProfile) -> OCRBackend:
    """Construct a LiteLLM-routed vision OCR backend from *spec* and *profile*.

    :raises ConfigurationError: If the spec does not name a model.
    """
    model = spec.model
    if model is None:
        raise ConfigurationError("OCR provider 'litellm' requires `model`.")
    # Spec transport settings override the profile's defaults field-by-field.
    transport = LiteLLMTransport(_merge_transport_config(profile.transport, spec.transport))
    return LiteLLMVisionOCRBackend(
        model=model,
        template=profile.template,
        model_name=_resolve_model_name(profile, model, fallback=model),
        transport=transport,
        image_preprocessor=profile.image_preprocessor,
        text_postprocessor=profile.text_postprocessor,
    )


def _build_openai_compatible_backend(spec: OCRBackendSpec, profile: OCRModelProfile) -> OCRBackend:
    """Construct a backend for an OpenAI-compatible HTTP endpoint.

    :raises ConfigurationError: If the spec lacks a model, or the merged
        transport is missing ``api_base``/``api_key``.
    """
    model = spec.model
    if model is None:
        raise ConfigurationError("OCR provider 'openai-compatible' requires `model`.")
    options = _merge_openai_options(
        _ensure_options_type(spec.options, OpenAICompatibleOptions, provider=spec.provider)
    )
    transport_config = _merge_transport_config(profile.transport, spec.transport)
    # Unlike litellm, an explicit endpoint and key are mandatory here.
    if not (transport_config.api_base and transport_config.api_key):
        raise ConfigurationError(
            "OCR provider 'openai-compatible' requires `transport.api_base` and `transport.api_key`."
        )
    return OpenAICompatibleOCRBackend(
        model=model,
        model_prefix=options.model_prefix or "openai",
        model_name=_resolve_model_name(profile, model, fallback=model),
        template=profile.template,
        transport=LiteLLMTransport(transport_config),
        image_preprocessor=profile.image_preprocessor,
        text_postprocessor=profile.text_postprocessor,
    )


def _build_huggingface_backend(spec: OCRBackendSpec, profile: OCRModelProfile) -> OCRBackend:
    """Construct a local Hugging Face transformers OCR backend.

    The ``dots-ocr-1.5`` backend variant swaps in a specialized backend
    class and layers the spec's model kwargs over that variant's defaults.

    :raises ConfigurationError: If the spec does not name a model.
    """
    model_id = spec.model
    if model_id is None:
        raise ConfigurationError("OCR provider 'hf' requires `model`.")
    options = _merge_huggingface_options(
        profile.huggingface,
        _ensure_options_type(spec.options, HuggingFaceOptions, provider=spec.provider),
    )
    if options.backend_variant == "dots-ocr-1.5":
        backend_cls: type[HuggingFaceVisionOCRBackend] = DotsOCR15OCRBackend
        model_kwargs = _merge_mapping(_default_dots_ocr_1_5_model_kwargs(), dict(options.model_kwargs))
    else:
        backend_cls = HuggingFaceVisionOCRBackend
        model_kwargs = dict(options.model_kwargs)
    return backend_cls(
        model_id=model_id,
        template=profile.template,
        model_name=_resolve_model_name(profile, model_id, fallback=model_id),
        trust_remote_code=bool(options.trust_remote_code),
        processor_kwargs=dict(options.processor_kwargs),
        model_kwargs=model_kwargs,
        generation_kwargs=dict(options.generation_kwargs),
        vision_input_builder=options.vision_input_builder,
        image_preprocessor=profile.image_preprocessor,
        text_postprocessor=profile.text_postprocessor,
    )


def _build_vllm_backend(spec: OCRBackendSpec, profile: OCRModelProfile) -> OCRBackend:
    """Construct a vLLM-hosted vision OCR backend.

    :raises ConfigurationError: If the spec does not name a model.
    """
    model_id = spec.model
    if model_id is None:
        raise ConfigurationError("OCR provider 'vllm' requires `model`.")
    options = _merge_vllm_options(
        profile.vllm,
        _ensure_options_type(spec.options, VLLMOptions, provider=spec.provider),
    )
    # An empty limit mapping defaults to a single image per prompt.
    mm_limits = dict(options.limit_mm_per_prompt)
    if not mm_limits:
        mm_limits = {"image": 1}
    return VLLMVisionOCRBackend(
        model_id=model_id,
        template=profile.template,
        model_name=_resolve_model_name(profile, model_id, fallback=model_id),
        trust_remote_code=bool(options.trust_remote_code),
        processor_kwargs=dict(options.processor_kwargs),
        llm_kwargs=dict(options.llm_kwargs),
        sampling_kwargs=dict(options.sampling_kwargs),
        limit_mm_per_prompt=mm_limits,
        image_preprocessor=profile.image_preprocessor,
        text_postprocessor=profile.text_postprocessor,
    )


def _build_azure_backend(spec: OCRBackendSpec, profile: OCRModelProfile) -> OCRBackend:
    """Construct an Azure Document Intelligence OCR backend.

    :raises ConfigurationError: If endpoint/API-key options are missing.
    """
    options = _ensure_options_type(spec.options, AzureDocumentIntelligenceOptions, provider=spec.provider)
    # Both endpoint and key must be present and non-empty.
    if not (options and options.endpoint and options.api_key):
        raise ConfigurationError(
            "OCR provider 'azure' requires AzureDocumentIntelligenceOptions(endpoint=..., api_key=...)."
        )
    model_id = spec.model or "prebuilt-layout"
    return AzureDocumentIntelligenceOCRBackend(
        endpoint=options.endpoint,
        api_key=options.api_key,
        model_id=model_id,
        model_name=_resolve_model_name(profile, spec.model, fallback=model_id),
        image_preprocessor=profile.image_preprocessor,
        text_postprocessor=profile.text_postprocessor,
    )


def _build_mistral_backend(spec: OCRBackendSpec, profile: OCRModelProfile) -> OCRBackend:
    """Construct a Mistral OCR API backend.

    :raises ConfigurationError: If the API key option is missing.
    """
    options = _ensure_options_type(spec.options, MistralOptions, provider=spec.provider)
    if not (options and options.api_key):
        raise ConfigurationError("OCR provider 'mistral' requires MistralOptions(api_key=...).")
    model = spec.model or "mistral-ocr-latest"
    return MistralOCRBackend(
        api_key=options.api_key,
        model=model,
        model_name=_resolve_model_name(profile, spec.model, fallback=model),
        image_preprocessor=profile.image_preprocessor,
        text_postprocessor=profile.text_postprocessor,
    )


def build_ocr_backend(spec: OCRBackendSpec) -> OCRBackend:
    """Build an OCR backend from a declarative spec.

    :param spec: Declarative backend specification.
    :returns: Configured OCR backend ready for use with ``OCRClient`` or
        ``DocumentOCRPipeline``.
    :raises ConfigurationError: If the provider is unsupported or required
        provider-specific configuration is missing.
    """
    profile = resolve_ocr_profile(model_id=spec.model, profile=spec.profile)
    # Dispatch on the provider tag; each builder validates its own options.
    builders = {
        "litellm": _build_litellm_backend,
        "openai-compatible": _build_openai_compatible_backend,
        "hf": _build_huggingface_backend,
        "vllm": _build_vllm_backend,
        "azure": _build_azure_backend,
        "mistral": _build_mistral_backend,
    }
    builder = builders.get(spec.provider)
    if builder is None:
        raise ConfigurationError(f"Unsupported OCR provider '{spec.provider}'.")
    return builder(spec, profile)


__all__ = ["build_ocr_backend"]