Source code for pyflink.dataframe.ai.providers

################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

"""
Model provider classes for the DataFrame LLM API.

Example::

    import pyflink.dataframe as pf

    # Typed provider with IDE auto-complete and validation
    provider = pf.OpenAICompatProvider(
        endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        api_key="sk-...",
        temperature=0.7,
    )
    pf.set_provider(provider)

    # Generic provider for unknown/custom providers
    pf.set_provider("my-custom-provider", endpoint="https://...", api_key="sk-...")
"""

from abc import ABC, abstractmethod
from collections.abc import Mapping as MappingABC
import json
import math
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union


# Python types accepted for ``TritonProvider(default_value=...)``.
_TritonDefaultValue = Union[
    str,
    int,
    float,
    bool,
    List[Any],
    Tuple[Any, ...],
    Mapping[str, Any],
]

# Values accepted for ``task`` by the OpenAI-compatible provider.
_OPENAI_COMPAT_TASKS = frozenset({"chat/completions", "embeddings"})

# DashScope accepts every OpenAI-compatible task plus multi-modal embedding.
_DASHSCOPE_TASKS = frozenset({*_OPENAI_COMPAT_TASKS, "multimodal-embedding"})

# Values accepted for ``context_overflow_action``: each base action, with
# and without the ``-log`` suffix.
_CONTEXT_OVERFLOW_ACTIONS = frozenset(
    f"{action}{suffix}"
    for action in ("truncated-tail", "truncated-head", "skipped")
    for suffix in ("", "-log")
)


class Provider(ABC):
    """Abstract base shared by every model provider.

    A concrete provider reports the identifier Flink's Java runtime knows
    it by, and converts its Python-style configuration into the option
    keys expected by Flink's ModelDescriptor.
    """

    def model_option_key(self) -> str:
        """Name of the Java-side option that carries a per-call model name.

        The base implementation answers ``"model"``; subclasses may
        override it when their runtime uses a different key.
        """
        return "model"

    @abstractmethod
    def provider_name(self) -> str:
        """Identifier of this provider as recognized by the Java runtime."""

    @abstractmethod
    def to_options(self) -> Dict[str, str]:
        """Every configured option, keyed by its Java-side name."""
class OpenAICompatProvider(Provider):
    """Provider for all OpenAI-compatible endpoints (``openai-compat``).

    Covers OpenAI, DeepSeek, Bailian, and any other service that implements
    the OpenAI chat/completions or embeddings API.

    Args:
        endpoint: The endpoint to connect to. Required with ``api_key``.
        api_key: The key used to authorize access to the endpoint. Required
            when using BYOK models. Cannot be combined with ``task``.
        task: The model task. Supported values are ``"chat/completions"``
            and ``"embeddings"``. Required when using Flink AI Model
            Service.
        model: The version of the model to use.
        system_prompt: The system message of a chat. Can be disabled by
            setting it to an empty string. Defaults to
            ``"You are a helpful assistant."``.
        user_prompt: The prompt of a chat, passed to the model service
            through the user's role. Can be disabled by setting it to an
            empty string.
        temperature: Controls the randomness or "creativity" of the output.
            Typical values are between 0.0 and 1.0.
        top_p: The probability cutoff for token selection. Usually either
            temperature or top_p is specified, but not both.
        max_tokens: The maximum number of tokens that can be generated in
            the chat completion.
        stop: A CSV list of strings to pass as stop sequences to the model.
        presence_penalty: Number between -2.0 and 2.0. Positive values
            penalize new tokens based on whether they appear in the text so
            far, increasing the model's likelihood to talk about new topics.
        n: How many chat completion choices to generate for each input
            message. Keep n as 1 to minimize costs.
        seed: If specified, the model platform will make a best effort to
            sample deterministically. Determinism is not guaranteed.
        content_type: Content type of the input string. Supported types:
            ``"TEXT"`` (default), ``"IMAGE_URL"``.
        response_format: The format of the response (``"text"`` or
            ``"json_object"``).
        dimension: The size of the embedding result array.
        max_context_size: Max number of tokens for context.
            ``context_overflow_action`` is triggered if this threshold is
            exceeded.
        context_overflow_action: Action to handle context overflows. One of
            ``"truncated-tail"``, ``"truncated-tail-log"``,
            ``"truncated-head"``, ``"truncated-head-log"``, ``"skipped"``,
            or ``"skipped-log"`` (case-insensitive; the lower-case spelling
            is what gets sent to Flink). Defaults to ``"truncated-tail"``.
        error_handling_strategy: Strategy for handling errors during model
            requests. ``"RETRY"`` retries the request (limited by
            retry_num, retry_fallback_strategy, etc.); ``"FAILOVER"`` throws
            exceptions and fails the job; ``"IGNORE"`` skips the error input
            and continues. Defaults to ``"RETRY"``.
        retry_num: Number of retries for client requests. Defaults to
            ``100``.
        retry_backoff_strategy: The strategy to use for retry backoff.
            ``"FIXED"`` or ``"EXPONENTIAL"``. Defaults to ``"FIXED"``.
        retry_backoff_base_interval: The base interval for retry backoff,
            used as the initial delay before the first retry and as the base
            for calculating subsequent retry delays. Defaults to ``"1s"``.
        retry_fallback_strategy: Fallback strategy to employ if the retry
            attempts are exhausted. ``"FAILOVER"`` or ``"IGNORE"``.
            Defaults to ``"FAILOVER"``.
        extra_header: Additional headers for the requests. Should be a
            JSON-format string whose values are strings or lists of strings.
        extra_body: Additional parameters to pass through the requests'
            body. Should be a JSON-format string.
        **extra_options: Additional options passed through as-is (keys are
            not translated). An extra option whose key equals a translated
            Java-side key overrides that option's value.

    Raises:
        ValueError: If ``api_key`` is given without a non-blank ``endpoint``
            or together with ``task``; if ``task`` is not a supported value;
            or if ``context_overflow_action`` is not a supported action.

    Example::

        >>> provider = OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-...",
        ...     temperature=0.7,
        ... )
    """

    # Explicit mapping from Python parameter names to Java-side option keys.
    _OPTION_MAP = {
        "endpoint": "endpoint",
        "api_key": "api-key",
        "task": "task",
        "model": "model",
        "system_prompt": "system-prompt",
        "user_prompt": "user-prompt",
        "temperature": "temperature",
        "top_p": "top-p",
        "max_tokens": "max-tokens",
        "stop": "stop",
        "presence_penalty": "presence-penalty",
        "n": "n",
        "seed": "seed",
        "content_type": "content-type",
        "response_format": "response-format",
        "dimension": "dimension",
        "max_context_size": "max-context-size",
        "context_overflow_action": "context-overflow-action",
        "error_handling_strategy": "error-handling-strategy",
        "retry_num": "retry-num",
        "retry_backoff_strategy": "retry-backoff-strategy",
        "retry_backoff_base_interval": "retry-backoff-base-interval",
        "retry_fallback_strategy": "retry-fallback-strategy",
        "extra_header": "extra-header",
        "extra_body": "extra-body",
    }

    # Values accepted for ``task``; subclasses may widen this set.
    _SUPPORTED_TASKS = _OPENAI_COMPAT_TASKS

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        task: Optional[str] = None,
        model: Optional[str] = None,
        system_prompt: str = "You are a helpful assistant.",
        user_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[str] = None,
        presence_penalty: Optional[float] = None,
        n: Optional[int] = None,
        seed: Optional[int] = None,
        content_type: str = "TEXT",
        response_format: Optional[str] = None,
        dimension: Optional[int] = None,
        max_context_size: Optional[int] = None,
        context_overflow_action: str = "truncated-tail",
        error_handling_strategy: str = "RETRY",
        retry_num: int = 100,
        retry_backoff_strategy: str = "FIXED",
        retry_backoff_base_interval: str = "1s",
        retry_fallback_strategy: str = "FAILOVER",
        extra_header: Optional[str] = None,
        extra_body: Optional[str] = None,
        **extra_options: Any,
    ):
        self._validate_service_options(endpoint, api_key, task)
        # The option is documented as case-insensitive, so accept any case
        # and keep the canonical lower-case spelling.
        context_overflow_action = self._validate_context_overflow_action(
            context_overflow_action)

        candidates = {
            "endpoint": endpoint,
            "api_key": api_key,
            "task": task,
            "model": model,
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "stop": stop,
            "presence_penalty": presence_penalty,
            "n": n,
            "seed": seed,
            "content_type": content_type,
            "response_format": response_format,
            "dimension": dimension,
            "max_context_size": max_context_size,
            "context_overflow_action": context_overflow_action,
            "error_handling_strategy": error_handling_strategy,
            "retry_num": retry_num,
            "retry_backoff_strategy": retry_backoff_strategy,
            "retry_backoff_base_interval": retry_backoff_base_interval,
            "retry_fallback_strategy": retry_fallback_strategy,
            "extra_header": extra_header,
            "extra_body": extra_body,
        }
        # Only explicitly set (non-None) parameters become options, so unset
        # ones fall back to the Java-side defaults.
        self._params: Dict[str, Any] = {
            key: value for key, value in candidates.items() if value is not None
        }
        self._extra_options = extra_options

    def provider_name(self) -> str:
        return "openai-compat"

    @classmethod
    def _validate_service_options(
        cls,
        endpoint: Optional[str],
        api_key: Optional[str],
        task: Optional[str],
    ) -> None:
        """Check the endpoint/api_key/task combination and the task value.

        Raises:
            ValueError: If ``api_key`` lacks a non-blank ``endpoint``, is
                combined with ``task``, or ``task`` is unsupported.
        """
        if api_key is not None:
            if endpoint is None or not endpoint.strip():
                raise ValueError("'api_key' requires 'endpoint'.")
            if task is not None:
                raise ValueError("'api_key' cannot be combined with 'task'.")
        if task is not None and task not in cls._SUPPORTED_TASKS:
            supported = ", ".join(sorted(cls._SUPPORTED_TASKS))
            raise ValueError(
                f"Unsupported task {task!r}. Supported values: {supported}.")

    @staticmethod
    def _validate_context_overflow_action(value: str) -> str:
        """Validate ``value`` case-insensitively and return it lower-cased.

        Raises:
            ValueError: If ``value`` is not a string naming one of the
                supported overflow actions.
        """
        normalized = value.lower() if isinstance(value, str) else None
        if normalized not in _CONTEXT_OVERFLOW_ACTIONS:
            supported = ", ".join(sorted(_CONTEXT_OVERFLOW_ACTIONS))
            raise ValueError(
                f"Unsupported context_overflow_action {value!r}. "
                f"Supported values: {supported}.")
        return normalized

    def to_options(self) -> Dict[str, str]:
        """Return translated parameters followed by untranslated extras."""
        options: Dict[str, str] = {
            self._OPTION_MAP[py_key]: str(value)
            for py_key, value in self._params.items()
        }
        # Extra options are passed through as-is; on a key clash they win.
        options.update(
            (key, str(value)) for key, value in self._extra_options.items())
        return options
class DashScopeProvider(OpenAICompatProvider):
    """Provider for Alibaba Cloud DashScope (``dashscope``).

    DashScope reuses the OpenAI-compatible Java provider for normal chat and
    embedding requests, and adds DashScope-specific multi-modal embedding
    support. The constructor is inherited from ``OpenAICompatProvider``;
    only the accepted ``task`` values and the provider name differ.

    Args:
        endpoint: The endpoint to connect to. Required with ``api_key``.
        api_key: The key used to authorize access to the endpoint. Required
            when using BYOK models. Cannot be combined with ``task``.
        task: The model task. Supported values are ``"chat/completions"``,
            ``"embeddings"``, and ``"multimodal-embedding"``. Required when
            using Flink AI Model Service.
        model: The version of the model to use.
        system_prompt: The system message of a chat. Can be disabled by
            setting it to an empty string. Defaults to
            ``"You are a helpful assistant."``.
        user_prompt: The prompt of a chat, passed to the model service
            through the user's role. Can be disabled by setting it to an
            empty string.
        temperature: Controls the randomness or "creativity" of the output.
            Typical values are between 0.0 and 1.0.
        top_p: The probability cutoff for token selection. Usually either
            temperature or top_p is specified, but not both.
        max_tokens: The maximum number of tokens that can be generated in
            the chat completion.
        stop: A CSV list of strings to pass as stop sequences to the model.
        presence_penalty: Number between -2.0 and 2.0. Positive values
            penalize new tokens based on whether they appear in the text so
            far, increasing the model's likelihood to talk about new topics.
        n: How many chat completion choices to generate for each input
            message. Keep n as 1 to minimize costs.
        seed: If specified, the model platform will make a best effort to
            sample deterministically. Determinism is not guaranteed.
        content_type: Content type of the input string. Supported types:
            ``"TEXT"`` (default), ``"IMAGE_URL"``. For multi-modal
            embedding, set this to ``"IMAGE_URL"``.
        response_format: The format of the response (``"text"`` or
            ``"json_object"``).
        dimension: The size of the embedding result array.
        max_context_size: Max number of tokens for context.
            ``context_overflow_action`` is triggered if this threshold is
            exceeded.
        context_overflow_action: Action to handle context overflows. One of
            ``"truncated-tail"``, ``"truncated-tail-log"``,
            ``"truncated-head"``, ``"truncated-head-log"``, ``"skipped"``,
            or ``"skipped-log"`` (case-insensitive). Defaults to
            ``"truncated-tail"``.
        error_handling_strategy: Strategy for handling errors during model
            requests. ``"RETRY"`` retries the request (limited by
            retry_num, retry_fallback_strategy, etc.); ``"FAILOVER"`` throws
            exceptions and fails the job; ``"IGNORE"`` skips the error input
            and continues. Defaults to ``"RETRY"``.
        retry_num: Number of retries for client requests. Defaults to
            ``100``.
        retry_backoff_strategy: The strategy to use for retry backoff.
            ``"FIXED"`` or ``"EXPONENTIAL"``. Defaults to ``"FIXED"``.
        retry_backoff_base_interval: The base interval for retry backoff,
            used as the initial delay before the first retry and as the base
            for calculating subsequent retry delays. Defaults to ``"1s"``.
        retry_fallback_strategy: Fallback strategy to employ if the retry
            attempts are exhausted. ``"FAILOVER"`` or ``"IGNORE"``.
            Defaults to ``"FAILOVER"``.
        extra_header: Additional headers for the requests. Should be a
            JSON-format string whose values are strings or lists of strings.
        extra_body: Additional parameters to pass through the requests'
            body. Should be a JSON-format string. Multi-modal embedding
            tasks do not support this option.
        **extra_options: Additional options passed through as-is (keys are
            not translated).

    Examples::

        >>> import pyflink.dataframe as pf
        >>> provider = DashScopeProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-...",
        ...     model="qwen-plus",
        ...     temperature=0.7,
        ... )
        >>> provider = DashScopeProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/api/v1/services/"
        ...     "embeddings/multimodal-embedding/multimodal-embedding",
        ...     api_key="sk-...",
        ...     model="tongyi-embedding-vision-plus",
        ...     content_type="IMAGE_URL",
        ... )
        >>> pf.set_provider(provider)
        >>> df = pf.from_dict({"image_url": [
        ...     "https://dashscope.oss-cn-beijing.aliyuncs.com/images/tiger.png"
        ... ]})
        >>> embeddings = df.llm.ai_embed("image_url", dimension=512)
    """

    # Widen the accepted ``task`` values with DashScope's multi-modal task.
    _SUPPORTED_TASKS = _DASHSCOPE_TASKS

    def provider_name(self) -> str:
        return "dashscope"
class TritonProvider(Provider):
    """Provider for NVIDIA Triton Inference Server (``triton``).

    Args:
        endpoint: Full URL of the Triton Inference Server endpoint.
        model_name: Name of the model to invoke. This can also be provided
            through ``df.llm.predict(..., model="...")``, which maps to the
            Java-side ``model-name`` option for Triton.
        model_version: Version of the model to use. Defaults to
            ``"latest"``.
        timeout: HTTP request timeout, for example ``"10s"`` or
            ``"30000ms"``. Defaults to ``"30s"``.
        flatten_batch_dim: Whether to flatten the leading batch dimension
            for array inputs. For ``ARRAY<T>`` inputs, the default shape is
            ``[1, N]``, where ``N`` is the array length. Set this to
            ``True`` when the Triton model expects ``[N]`` instead. Defaults
            to ``False``.
        priority: Triton request priority level, in ``[0, 255]``.
        sequence_id: Sequence ID for stateful models.
        sequence_start: Whether this request starts a stateful sequence.
            Defaults to ``False``.
        sequence_end: Whether this request ends a stateful sequence.
            Defaults to ``False``.
        compression: Compression algorithm for the request body. Currently
            the Triton provider supports ``"gzip"``.
        auth_token: Authentication token for secured Triton servers. The
            Java provider sends it as a Bearer token.
        custom_headers: Custom HTTP headers as a Flink map string with
            comma-separated ``key:value`` pairs (e.g.
            ``"X-Trace-Id:abc,X-Other:val"``). Must not contain line breaks.
        max_retries: Maximum number of retries for failed inference
            requests. Must be ``>= 0``. Defaults to ``0``.
        retry_initial_backoff: Initial backoff duration between retry
            attempts. Defaults to ``"100ms"``.
        retry_max_backoff: Upper bound on the delay between retry attempts.
            Defaults to ``"30s"``.
        default_value: Fallback value to return when inference fails after
            retries or with a non-retryable error:

            - If not specified, inference failures are propagated as
              exceptions.
            - For ``STRING`` outputs, pass plain text such as ``"FAILED"``.
            - For numeric outputs, pass the numeric value or its string
              representation, such as ``-1`` or ``"-1"``. Non-finite floats
              (NaN, infinity) are rejected, as they are inside containers.
            - For ``ARRAY`` or structured outputs, pass a JSON string or the
              corresponding Python list, tuple, or mapping; Python
              containers are serialized as JSON.
            - To emit SQL ``NULL``, pass the lower-case literal ``"null"``.
              For string outputs, ``"null"`` is therefore not usable as a
              literal string sentinel; use values such as ``"NULL"``,
              ``"FAILED"``, or ``"<null>"`` instead.
        health_check_enabled: Whether to enable periodic health checks for
            the Triton server. Defaults to ``False``.
        health_check_interval: Interval between health check requests.
            Defaults to ``"30s"``.
        circuit_breaker_enabled: Whether to enable circuit breaker
            protection for Triton inference requests. Defaults to ``False``.
        circuit_breaker_failure_threshold: Failure rate threshold that opens
            the circuit breaker. Must be in ``(0.0, 1.0]``. Defaults to
            ``0.5``.
        circuit_breaker_timeout: Duration to keep the circuit breaker open
            before probing recovery. Defaults to ``"60s"``.
        circuit_breaker_half_open_requests: Number of successful half-open
            probe requests required to close the circuit. Must be positive.
            Defaults to ``3``.
        **extra_options: Additional options passed through as-is.

    Raises:
        TypeError: If ``endpoint`` or ``custom_headers`` is not a string.
        ValueError: If ``compression``, ``priority``, ``max_retries``,
            ``circuit_breaker_failure_threshold``,
            ``circuit_breaker_half_open_requests``, ``custom_headers`` or
            ``default_value`` fails validation.

    Examples::

        >>> import pyflink.dataframe as pf
        >>>
        >>> # Classifier with ARRAY<FLOAT> features and BIGINT class output.
        >>> provider = TritonProvider(
        ...     endpoint="<Your Triton endpoint>",
        ...     auth_token="<Your authentication token>",
        ...     model_name="classifier",
        ...     compression="gzip",
        ... )
        >>> pf.set_provider(provider)
        >>> df = pf.from_records(
        ...     [([5.1, 3.5, 1.4, 0.2],), ...],
        ...     schema=["features"])
        >>> result = df.llm.predict(
        ...     "features",
        ...     output_type={"class_id": "BIGINT"})
        >>>
        >>> # Stateful conversation model with a fixed Triton sequence.
        >>> provider = TritonProvider(
        ...     endpoint="<Your Triton endpoint>",
        ...     auth_token="<Your authentication token>",
        ...     model_name="chatbot_lstm",
        ...     sequence_id="conv-001",
        ...     sequence_start=True,
        ...     sequence_end=False,
        ... )
        >>> pf.set_provider(provider)
        >>> chat_messages = pf.from_records(
        ...     [("hello",), ...],
        ...     schema=["message_text"])
        >>> result = chat_messages.llm.predict(
        ...     "message_text",
        ...     output_type={"bot_response": "STRING"})
        >>>
        >>> # Vector transform model with ARRAY<FLOAT> input and output.
        >>> provider = TritonProvider(
        ...     endpoint="<Your Triton endpoint>",
        ...     auth_token="<Your authentication token>",
        ...     model_name="vector-transform",
        ...     flatten_batch_dim=True,  # Used when Triton model expects one-dimensional input
        ... )
        >>> pf.set_provider(provider)
        >>> vector_input = pf.from_records(
        ...     [([0.1, 0.2, 0.3, ...],), ...],
        ...     schema=["features"])
        >>> result = vector_input.llm.predict(
        ...     "features",
        ...     output_type={"output_vector": "ARRAY<FLOAT>"})
    """

    # Explicit mapping from Python parameter names to Java-side option keys.
    _OPTION_MAP = {
        "endpoint": "endpoint",
        "model_name": "model-name",
        "model_version": "model-version",
        "timeout": "timeout",
        "flatten_batch_dim": "flatten-batch-dim",
        "priority": "priority",
        "sequence_id": "sequence-id",
        "sequence_start": "sequence-start",
        "sequence_end": "sequence-end",
        "compression": "compression",
        "auth_token": "auth-token",
        "custom_headers": "custom-headers",
        "max_retries": "max-retries",
        "retry_initial_backoff": "retry-initial-backoff",
        "retry_max_backoff": "retry-max-backoff",
        "default_value": "default-value",
        "health_check_enabled": "health-check-enabled",
        "health_check_interval": "health-check-interval",
        "circuit_breaker_enabled": "circuit-breaker-enabled",
        "circuit_breaker_failure_threshold": (
            "circuit-breaker-failure-threshold"
        ),
        "circuit_breaker_timeout": "circuit-breaker-timeout",
        "circuit_breaker_half_open_requests": (
            "circuit-breaker-half-open-requests"
        ),
    }

    def __init__(
        self,
        *,
        endpoint: str,
        model_name: Optional[str] = None,
        model_version: str = "latest",
        timeout: str = "30s",
        flatten_batch_dim: bool = False,
        priority: Optional[int] = None,
        sequence_id: Optional[str] = None,
        sequence_start: bool = False,
        sequence_end: bool = False,
        compression: Optional[Literal["gzip"]] = None,
        auth_token: Optional[str] = None,
        custom_headers: Optional[str] = None,
        max_retries: int = 0,
        retry_initial_backoff: str = "100ms",
        retry_max_backoff: str = "30s",
        default_value: Optional[_TritonDefaultValue] = None,
        health_check_enabled: bool = False,
        health_check_interval: str = "30s",
        circuit_breaker_enabled: bool = False,
        circuit_breaker_failure_threshold: float = 0.5,
        circuit_breaker_timeout: str = "60s",
        circuit_breaker_half_open_requests: int = 3,
        **extra_options: Any,
    ):
        if not isinstance(endpoint, str):
            raise TypeError(f"endpoint must be a string, got {type(endpoint)}")
        if compression is not None and compression != "gzip":
            raise ValueError(
                "compression must be 'gzip' when provided, got "
                f"{compression!r}.")
        if priority is not None and not 0 <= priority <= 255:
            raise ValueError("priority must be in range [0, 255].")
        if max_retries < 0:
            raise ValueError("max_retries must be >= 0.")
        # Written as a negated range check so that NaN is rejected too.
        if not 0.0 < circuit_breaker_failure_threshold <= 1.0:
            raise ValueError(
                "circuit_breaker_failure_threshold must be in range "
                "(0.0, 1.0].")
        if circuit_breaker_half_open_requests <= 0:
            raise ValueError(
                "circuit_breaker_half_open_requests must be positive.")
        if custom_headers is not None:
            if not isinstance(custom_headers, str):
                raise TypeError(
                    f"custom_headers must be a string, got "
                    f"{type(custom_headers)}")
            # Line breaks in a header value would allow header injection.
            if "\n" in custom_headers or "\r" in custom_headers:
                raise ValueError("custom_headers cannot contain line breaks.")

        formatted_default = self._format_default_value(default_value)

        candidates = {
            "endpoint": endpoint,
            "model_name": model_name,
            "model_version": model_version,
            "timeout": timeout,
            "flatten_batch_dim": flatten_batch_dim,
            "priority": priority,
            "sequence_id": sequence_id,
            "sequence_start": sequence_start,
            "sequence_end": sequence_end,
            "compression": compression,
            "auth_token": auth_token,
            "custom_headers": custom_headers,
            "max_retries": max_retries,
            "retry_initial_backoff": retry_initial_backoff,
            "retry_max_backoff": retry_max_backoff,
            "default_value": formatted_default,
            "health_check_enabled": health_check_enabled,
            "health_check_interval": health_check_interval,
            "circuit_breaker_enabled": circuit_breaker_enabled,
            "circuit_breaker_failure_threshold": (
                circuit_breaker_failure_threshold
            ),
            "circuit_breaker_timeout": circuit_breaker_timeout,
            "circuit_breaker_half_open_requests": (
                circuit_breaker_half_open_requests
            ),
        }
        # Unset (None) parameters are omitted so the Java defaults apply.
        self._params: Dict[str, Any] = {
            key: value for key, value in candidates.items() if value is not None
        }
        self._extra_options = extra_options

    def provider_name(self) -> str:
        return "triton"

    def model_option_key(self) -> str:
        # Triton's Java provider names the per-call model option differently.
        return "model-name"

    def to_options(self) -> Dict[str, str]:
        """Return translated parameters followed by untranslated extras.

        Booleans are rendered in lower case (``"true"``/``"false"``).
        """
        options: Dict[str, str] = {
            self._OPTION_MAP[py_key]: self._stringify(value)
            for py_key, value in self._params.items()
        }
        options.update(
            (key, self._stringify(value))
            for key, value in self._extra_options.items())
        return options

    @staticmethod
    def _format_default_value(
        default_value: Optional[_TritonDefaultValue],
    ) -> Optional[str]:
        """Serialize ``default_value`` into the option string sent to Java.

        Returns ``None`` when no default is configured, lower-case
        ``"true"``/``"false"`` for booleans, JSON for lists, tuples, and
        mappings, and ``str(default_value)`` otherwise.

        Raises:
            ValueError: If ``default_value`` is a non-finite float or a
                container that cannot be serialized as strict JSON.
        """
        if default_value is None:
            return None
        # bool is checked before numbers: str(True) would give "True".
        if isinstance(default_value, bool):
            return "true" if default_value else "false"
        if (isinstance(default_value, float)
                and not math.isfinite(default_value)):
            # Same policy as ``allow_nan=False`` for containers below:
            # str() would yield "nan"/"inf", which are not valid JSON
            # number literals.
            raise ValueError(
                "default_value must be a finite number, got "
                f"{default_value!r}.")
        if isinstance(default_value, (list, tuple, MappingABC)):
            payload = (
                dict(default_value)
                if isinstance(default_value, MappingABC)
                else default_value
            )
            try:
                return json.dumps(payload, allow_nan=False)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    "default_value contains a non-JSON-serializable "
                    f"element: {exc}") from exc
        return str(default_value)

    @staticmethod
    def _stringify(value: Any) -> str:
        """Convert an option value to a string, lower-casing booleans."""
        if isinstance(value, bool):
            return str(value).lower()
        return str(value)
class GenericProvider(Provider):
    """Pass-through provider for custom or not-yet-modelled providers.

    No key translation happens: every keyword option reaches the Java side
    under exactly the name it was given, with its value converted by
    ``str()``.

    Args:
        name: Provider identifier recognized by Flink's Java runtime.
        **options: Arbitrary key-value options.

    Example::

        >>> provider = GenericProvider("my-provider", endpoint="https://...",
        ...                            **{"api-key": "sk-..."})
    """

    def __init__(self, name: str, **options: Any):
        self._name = name
        self._options = options

    def provider_name(self) -> str:
        return self._name

    def to_options(self) -> Dict[str, str]:
        rendered: Dict[str, str] = {}
        for option_key, option_value in self._options.items():
            rendered[option_key] = str(option_value)
        return rendered