################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Model provider classes for the DataFrame LLM API.

Example::

    import pyflink.dataframe as pf

    # Typed provider with IDE auto-complete and validation
    provider = pf.OpenAICompatProvider(
        endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        api_key="sk-...",
        temperature=0.7,
    )
    pf.set_provider(provider)

    # Generic provider for unknown/custom providers
    pf.set_provider("my-custom-provider", endpoint="https://...", api_key="sk-...")
"""
from abc import ABC, abstractmethod
from collections.abc import Mapping as MappingABC
import json
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union
# Python types accepted for the ``default_value`` argument of TritonProvider.
_TritonDefaultValue = Union[
    str,
    int,
    float,
    bool,
    List[Any],
    Tuple[Any, ...],
    Mapping[str, Any],
]

# Task names accepted by OpenAICompatProvider.
_OPENAI_COMPAT_TASKS = frozenset({"chat/completions", "embeddings"})

# DashScopeProvider accepts every OpenAI-compatible task plus one more.
_DASHSCOPE_TASKS = _OPENAI_COMPAT_TASKS.union({"multimodal-embedding"})

# Accepted spellings of the ``context_overflow_action`` option.
_CONTEXT_OVERFLOW_ACTIONS = frozenset({
    "truncated-tail",
    "truncated-tail-log",
    "truncated-head",
    "truncated-head-log",
    "skipped",
    "skipped-log",
})
class Provider(ABC):
    """Abstract description of a model provider.

    A concrete subclass holds one provider configuration and is responsible
    for turning Python-style parameter names into the Java-side option keys
    that Flink's ModelDescriptor expects.
    """

    @abstractmethod
    def provider_name(self) -> str:
        """Identify this provider to Flink's Java runtime.

        Returns:
            The provider identifier string.
        """

    @abstractmethod
    def to_options(self) -> Dict[str, str]:
        """Export every configured option.

        Returns:
            A dict whose keys are the Java-side option names.
        """

    def model_option_key(self) -> str:
        """Name the Java-side option that carries a per-call model name.

        Returns:
            ``"model"`` unless a subclass overrides this method.
        """
        return "model"
class OpenAICompatProvider(Provider):
    """Provider for all OpenAI-compatible endpoints (openai-compat).

    Covers OpenAI, DeepSeek, Bailian, and any other service that
    implements the OpenAI chat/completions or embeddings API.

    Args:
        endpoint: The endpoint to connect to. Required with ``api_key``.
        api_key: The key used to authorize the access to the endpoint.
            Required when using BYOK models.
        task: The model task. Supported values are ``"chat/completions"``
            and ``"embeddings"``.
            Required when using Flink AI Model Service.
        model: The version of the model to use.
        system_prompt: The system message of a chat. Can be disabled by
            setting to empty string.
            Defaults to ``"You are a helpful assistant."``.
        user_prompt: The prompt of a chat, passed to the model service
            through user's role. Can be disabled by setting to empty string.
        temperature: Controls the randomness or "creativity" of the output.
            Typical values are between 0.0 and 1.0.
        top_p: The probability cutoff for token selection. Usually either
            temperature or top_p are specified, but not both.
        max_tokens: The maximum number of tokens that can be generated
            in the chat completion.
        stop: A CSV list of strings to pass as stop sequences to the model.
        presence_penalty: Number between -2.0 and 2.0. Positive values
            penalize new tokens based on whether they appear in the text
            so far, increasing the model's likelihood to talk about new
            topics.
        n: How many chat completion choices to generate for each input
            message. Keep n as 1 to minimize costs.
        seed: If specified, the model platform will make a best effort to
            sample deterministically. Determinism is not guaranteed.
        content_type: Content type of the input string. Supported types:
            ``"TEXT"`` (default), ``"IMAGE_URL"``.
        response_format: The format of the response (``"text"`` or
            ``"json_object"``).
        dimension: The size of the embedding result array.
        max_context_size: Max number of tokens for context.
            ``context_overflow_action`` is triggered if this threshold
            is exceeded.
        context_overflow_action: Action to handle context overflows. One
            of ``"truncated-tail"``, ``"truncated-tail-log"``,
            ``"truncated-head"``, ``"truncated-head-log"``,
            ``"skipped"``, or ``"skipped-log"`` (case-insensitive; the
            value is forwarded in lower case).
            Defaults to ``"truncated-tail"``.
        error_handling_strategy: Strategy for handling errors during model
            requests. ``"RETRY"`` retries the request (limited by
            retry_num, retry_fallback_strategy, etc.); ``"FAILOVER"``
            throws exceptions and fails the job; ``"IGNORE"`` skips the
            error input and continues. Defaults to ``"RETRY"``.
        retry_num: Number of retries for client requests.
            Defaults to ``100``.
        retry_backoff_strategy: The strategy to use for retry backoff.
            ``"FIXED"`` or ``"EXPONENTIAL"``. Defaults to ``"FIXED"``.
        retry_backoff_base_interval: The base interval for retry backoff,
            used as the initial delay before the first retry and as the
            base for calculating subsequent retry delays.
            Defaults to ``"1s"``.
        retry_fallback_strategy: Fallback strategy to employ if the retry
            attempts are exhausted. ``"FAILOVER"`` or ``"IGNORE"``.
            Defaults to ``"FAILOVER"``.
        extra_header: Additional headers for the requests. Should be a
            JSON-format string whose values are strings or list of strings.
        extra_body: Additional parameters to pass through the requests'
            body. Should be a JSON-format string.
        **extra_options: Additional options passed through as-is (keys are
            not translated).

    Raises:
        ValueError: If ``api_key`` is given without a non-blank
            ``endpoint`` or together with ``task``, if ``task`` is not a
            supported task, or if ``context_overflow_action`` is not one
            of the supported actions.

    Example::

        >>> provider = OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-...",
        ...     temperature=0.7,
        ... )
    """

    # Explicit mapping from Python parameter names to Java-side option keys.
    _OPTION_MAP = {
        "endpoint": "endpoint",
        "api_key": "api-key",
        "task": "task",
        "model": "model",
        "system_prompt": "system-prompt",
        "user_prompt": "user-prompt",
        "temperature": "temperature",
        "top_p": "top-p",
        "max_tokens": "max-tokens",
        "stop": "stop",
        "presence_penalty": "presence-penalty",
        "n": "n",
        "seed": "seed",
        "content_type": "content-type",
        "response_format": "response-format",
        "dimension": "dimension",
        "max_context_size": "max-context-size",
        "context_overflow_action": "context-overflow-action",
        "error_handling_strategy": "error-handling-strategy",
        "retry_num": "retry-num",
        "retry_backoff_strategy": "retry-backoff-strategy",
        "retry_backoff_base_interval": "retry-backoff-base-interval",
        "retry_fallback_strategy": "retry-fallback-strategy",
        "extra_header": "extra-header",
        "extra_body": "extra-body",
    }

    # Task names accepted by this provider; subclasses may widen the set.
    _SUPPORTED_TASKS = _OPENAI_COMPAT_TASKS

    def __init__(
        self,
        *,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        task: Optional[str] = None,
        model: Optional[str] = None,
        system_prompt: str = "You are a helpful assistant.",
        user_prompt: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[str] = None,
        presence_penalty: Optional[float] = None,
        n: Optional[int] = None,
        seed: Optional[int] = None,
        content_type: str = "TEXT",
        response_format: Optional[str] = None,
        dimension: Optional[int] = None,
        max_context_size: Optional[int] = None,
        context_overflow_action: str = "truncated-tail",
        error_handling_strategy: str = "RETRY",
        retry_num: int = 100,
        retry_backoff_strategy: str = "FIXED",
        retry_backoff_base_interval: str = "1s",
        retry_fallback_strategy: str = "FAILOVER",
        extra_header: Optional[str] = None,
        extra_body: Optional[str] = None,
        **extra_options: Any,
    ):
        self._validate_service_options(endpoint, api_key, task)
        self._validate_context_overflow_action(context_overflow_action)

        named_params = {
            "endpoint": endpoint,
            "api_key": api_key,
            "task": task,
            "model": model,
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "stop": stop,
            "presence_penalty": presence_penalty,
            "n": n,
            "seed": seed,
            "content_type": content_type,
            "response_format": response_format,
            "dimension": dimension,
            "max_context_size": max_context_size,
            # The action is validated case-insensitively, so forward the
            # canonical lower-case spelling that the validator matched.
            "context_overflow_action": context_overflow_action.lower(),
            "error_handling_strategy": error_handling_strategy,
            "retry_num": retry_num,
            "retry_backoff_strategy": retry_backoff_strategy,
            "retry_backoff_base_interval": retry_backoff_base_interval,
            "retry_fallback_strategy": retry_fallback_strategy,
            "extra_header": extra_header,
            "extra_body": extra_body,
        }
        # Only parameters that carry a value are exported as options.
        self._params: Dict[str, Any] = {
            key: value
            for key, value in named_params.items()
            if value is not None
        }
        self._extra_options = extra_options

    def provider_name(self) -> str:
        """Return the ``openai-compat`` provider identifier."""
        return "openai-compat"

    @classmethod
    def _validate_service_options(
        cls,
        endpoint: Optional[str],
        api_key: Optional[str],
        task: Optional[str],
    ) -> None:
        """Check that ``endpoint``, ``api_key`` and ``task`` are consistent.

        Raises:
            ValueError: If ``api_key`` is given without a non-blank
                ``endpoint`` or together with ``task``, or if ``task`` is
                not in ``cls._SUPPORTED_TASKS``.
        """
        if api_key is not None:
            if endpoint is None or not endpoint.strip():
                raise ValueError("'api_key' requires 'endpoint'.")
            if task is not None:
                raise ValueError("'api_key' cannot be combined with 'task'.")
        if task is not None and task not in cls._SUPPORTED_TASKS:
            supported = ", ".join(sorted(cls._SUPPORTED_TASKS))
            raise ValueError(
                f"Unsupported task {task!r}. Supported values: {supported}.")

    @staticmethod
    def _validate_context_overflow_action(value: str) -> None:
        """Check ``value`` against the supported actions, ignoring case.

        Raises:
            ValueError: If ``value`` is not a string or does not match a
                supported action.
        """
        if isinstance(value, str) and value.lower() in _CONTEXT_OVERFLOW_ACTIONS:
            return
        supported = ", ".join(sorted(_CONTEXT_OVERFLOW_ACTIONS))
        raise ValueError(
            f"Unsupported context_overflow_action {value!r}. "
            f"Supported values: {supported}.")

    def to_options(self) -> Dict[str, str]:
        """Return the configured options keyed by Java-side option names.

        Named parameters are translated through ``_OPTION_MAP``; extra
        options keep their keys and are applied afterwards, so an extra
        option using the same key replaces the named value. All values are
        converted with ``str``.
        """
        options: Dict[str, str] = {}
        for py_key, value in self._params.items():
            java_key = self._OPTION_MAP[py_key]
            options[java_key] = str(value)
        # Extra options are passed through as-is.
        for key, value in self._extra_options.items():
            options[key] = str(value)
        return options
class DashScopeProvider(OpenAICompatProvider):
    """Provider for Alibaba Cloud DashScope (dashscope).

    Normal chat and embedding requests go through the same OpenAI-compatible
    Java provider; on top of that DashScope offers its own multi-modal
    embedding support.

    Args:
        endpoint: Endpoint to connect to. Must be given whenever
            ``api_key`` is given.
        api_key: Key that authorizes access to the endpoint. Required when
            using BYOK models.
        task: Model task, one of ``"chat/completions"``, ``"embeddings"``
            or ``"multimodal-embedding"``. Required when using Flink AI
            Model Service.
        model: Version of the model to use.
        system_prompt: System message of a chat; an empty string disables
            it. Defaults to ``"You are a helpful assistant."``.
        user_prompt: Prompt of a chat, sent to the model service through
            the user's role; an empty string disables it.
        temperature: Controls the randomness or "creativity" of the
            output. Typical values are between 0.0 and 1.0.
        top_p: Probability cutoff for token selection. Usually either
            temperature or top_p is specified, but not both.
        max_tokens: Maximum number of tokens that can be generated in the
            chat completion.
        stop: CSV list of strings passed to the model as stop sequences.
        presence_penalty: Number between -2.0 and 2.0. Positive values
            penalize new tokens based on whether they appear in the text
            so far, increasing the model's likelihood to talk about new
            topics.
        n: Number of chat completion choices to generate for each input
            message. Keep n as 1 to minimize costs.
        seed: If specified, the model platform makes a best effort to
            sample deterministically. Determinism is not guaranteed.
        content_type: Content type of the input string, ``"TEXT"``
            (default) or ``"IMAGE_URL"``. Set it to ``"IMAGE_URL"`` for
            multi-modal embedding.
        response_format: Format of the response (``"text"`` or
            ``"json_object"``).
        dimension: Size of the embedding result array.
        max_context_size: Max number of tokens for context. Exceeding this
            threshold triggers ``context_overflow_action``.
        context_overflow_action: Action to handle context overflows. One
            of ``"truncated-tail"``, ``"truncated-tail-log"``,
            ``"truncated-head"``, ``"truncated-head-log"``,
            ``"skipped"``, or ``"skipped-log"`` (case-insensitive).
            Defaults to ``"truncated-tail"``.
        error_handling_strategy: Strategy for handling errors during model
            requests. ``"RETRY"`` retries the request (limited by
            retry_num, retry_fallback_strategy, etc.); ``"FAILOVER"``
            throws exceptions and fails the job; ``"IGNORE"`` skips the
            error input and continues. Defaults to ``"RETRY"``.
        retry_num: Number of retries for client requests. Defaults to
            ``100``.
        retry_backoff_strategy: Retry backoff strategy, ``"FIXED"`` or
            ``"EXPONENTIAL"``. Defaults to ``"FIXED"``.
        retry_backoff_base_interval: Base interval for retry backoff, used
            as the initial delay before the first retry and as the base
            for calculating subsequent retry delays. Defaults to ``"1s"``.
        retry_fallback_strategy: Fallback strategy once the retry attempts
            are exhausted, ``"FAILOVER"`` or ``"IGNORE"``. Defaults to
            ``"FAILOVER"``.
        extra_header: Additional request headers, as a JSON-format string
            whose values are strings or lists of strings.
        extra_body: Additional parameters to pass through the request
            body, as a JSON-format string. Multi-modal embedding tasks do
            not support this option.
        **extra_options: Additional options passed through as-is (keys are
            not translated).

    Examples::

        >>> import pyflink.dataframe as pf
        >>>
        >>> provider = DashScopeProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-...",
        ...     model="qwen-plus",
        ...     temperature=0.7,
        ... )
        >>>
        >>> provider = DashScopeProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/api/v1/services/"
        ...              "embeddings/multimodal-embedding/multimodal-embedding",
        ...     api_key="sk-...",
        ...     model="tongyi-embedding-vision-plus",
        ...     content_type="IMAGE_URL",
        ... )
        >>> pf.set_provider(provider)
        >>> df = pf.from_dict({"image_url": [
        ...     "https://dashscope.oss-cn-beijing.aliyuncs.com/images/tiger.png"
        ... ]})
        >>> embeddings = df.llm.ai_embed("image_url", dimension=512)
    """

    # Widens the accepted task names with DashScope's multi-modal embedding.
    _SUPPORTED_TASKS = _DASHSCOPE_TASKS

    def provider_name(self) -> str:
        """Return the ``dashscope`` provider identifier."""
        return "dashscope"
class TritonProvider(Provider):
    """Provider for NVIDIA Triton Inference Server (triton).

    Args:
        endpoint: Full URL of the Triton Inference Server endpoint.
        model_name: Name of the model to invoke. It may instead be supplied
            per call through ``df.llm.predict(..., model="...")``, which
            maps to the Java-side ``model-name`` option for Triton.
        model_version: Version of the model to use. Defaults to
            ``"latest"``.
        timeout: HTTP request timeout, for example ``"10s"`` or
            ``"30000ms"``. Defaults to ``"30s"``.
        flatten_batch_dim: Whether to flatten the leading batch dimension
            for array inputs. For ``ARRAY<T>`` inputs, the default shape is
            ``[1, N]``, where ``N`` is the array length. Set this to
            ``True`` when the Triton model expects ``[N]`` instead.
            Defaults to ``False``.
        priority: Triton request priority level. Must be in ``[0, 255]``
            when provided.
        sequence_id: Sequence ID for stateful models.
        sequence_start: Whether this request starts a stateful sequence.
            Defaults to ``False``.
        sequence_end: Whether this request ends a stateful sequence.
            Defaults to ``False``.
        compression: Compression algorithm for the request body. Currently
            Triton provider supports ``"gzip"``.
        auth_token: Authentication token for secured Triton servers. The
            Java provider sends it as a Bearer token.
        custom_headers: Custom HTTP headers as a Flink map string with
            comma-separated ``key:value`` pairs (e.g.
            ``"X-Trace-Id:abc,X-Other:val"``). Line breaks are rejected.
        max_retries: Maximum number of retries for failed inference
            requests. Must be ``>= 0``. Defaults to ``0``.
        retry_initial_backoff: Initial backoff duration between retry
            attempts. Defaults to ``"100ms"``.
        retry_max_backoff: Upper bound on the delay between retry attempts.
            Defaults to ``"30s"``.
        default_value: Fallback value to return when inference fails after
            retries or with a non-retryable error:

            - If not specified, inference failures are propagated as
              exceptions.
            - For ``STRING`` outputs, pass plain text such as ``"FAILED"``.
            - For numeric outputs, pass the numeric value or its string
              representation, such as ``-1`` or ``"-1"``.
            - For ``ARRAY`` or structured outputs, pass a JSON string or
              the corresponding Python list, tuple, or mapping; Python
              containers are serialized as JSON.
            - To emit SQL ``NULL``, pass the lower-case literal ``"null"``.
              For string outputs, ``"null"`` is therefore not usable as a
              literal string sentinel; use values such as ``"NULL"``,
              ``"FAILED"``, or ``"<null>"`` instead.
        health_check_enabled: Whether to enable periodic health checks for
            the Triton server. Defaults to ``False``.
        health_check_interval: Interval between health check requests.
            Defaults to ``"30s"``.
        circuit_breaker_enabled: Whether to enable circuit breaker
            protection for Triton inference requests. Defaults to
            ``False``.
        circuit_breaker_failure_threshold: Failure rate threshold that
            opens the circuit breaker. Must be in ``(0.0, 1.0]``. Defaults
            to ``0.5``.
        circuit_breaker_timeout: Duration to keep the circuit breaker open
            before probing recovery. Defaults to ``"60s"``.
        circuit_breaker_half_open_requests: Number of successful half-open
            probe requests required to close the circuit. Must be
            positive. Defaults to ``3``.
        **extra_options: Additional options passed through as-is.

    Raises:
        TypeError: If ``endpoint`` or a provided ``custom_headers`` is not
            a string.
        ValueError: If a validated argument is out of range, if
            ``custom_headers`` contains a line break, or if a container
            ``default_value`` cannot be serialized as JSON.

    Examples::

        >>> import pyflink.dataframe as pf
        >>>
        >>> # Classifier with ARRAY<FLOAT> features and BIGINT class output.
        >>> provider = TritonProvider(
        ...     endpoint="<Your Triton endpoint>",
        ...     auth_token="<Your authentication token>",
        ...     model_name="classifier",
        ...     compression="gzip",
        ... )
        >>> pf.set_provider(provider)
        >>> df = pf.from_records(
        ...     [([5.1, 3.5, 1.4, 0.2],), ...],
        ...     schema=["features"])
        >>> result = df.llm.predict(
        ...     "features",
        ...     output_type={"class_id": "BIGINT"})
        >>>
        >>> # Stateful conversation model with a fixed Triton sequence.
        >>> provider = TritonProvider(
        ...     endpoint="<Your Triton endpoint>",
        ...     auth_token="<Your authentication token>",
        ...     model_name="chatbot_lstm",
        ...     sequence_id="conv-001",
        ...     sequence_start=True,
        ...     sequence_end=False,
        ... )
        >>> pf.set_provider(provider)
        >>> chat_messages = pf.from_records(
        ...     [("hello",), ...],
        ...     schema=["message_text"])
        >>> result = chat_messages.llm.predict(
        ...     "message_text",
        ...     output_type={"bot_response": "STRING"})
        >>>
        >>> # Vector transform model with ARRAY<FLOAT> input and output.
        >>> provider = TritonProvider(
        ...     endpoint="<Your Triton endpoint>",
        ...     auth_token="<Your authentication token>",
        ...     model_name="vector-transform",
        ...     flatten_batch_dim=True,  # Used when Triton model expects one-dimensional input
        ... )
        >>> pf.set_provider(provider)
        >>> vector_input = pf.from_records(
        ...     [([0.1, 0.2, 0.3, ...],), ...],
        ...     schema=["features"])
        >>> result = vector_input.llm.predict(
        ...     "features",
        ...     output_type={"output_vector": "ARRAY<FLOAT>"})
    """

    # Python parameter name -> Java-side option key.
    _OPTION_MAP = {
        "endpoint": "endpoint",
        "model_name": "model-name",
        "model_version": "model-version",
        "timeout": "timeout",
        "flatten_batch_dim": "flatten-batch-dim",
        "priority": "priority",
        "sequence_id": "sequence-id",
        "sequence_start": "sequence-start",
        "sequence_end": "sequence-end",
        "compression": "compression",
        "auth_token": "auth-token",
        "custom_headers": "custom-headers",
        "max_retries": "max-retries",
        "retry_initial_backoff": "retry-initial-backoff",
        "retry_max_backoff": "retry-max-backoff",
        "default_value": "default-value",
        "health_check_enabled": "health-check-enabled",
        "health_check_interval": "health-check-interval",
        "circuit_breaker_enabled": "circuit-breaker-enabled",
        "circuit_breaker_failure_threshold": "circuit-breaker-failure-threshold",
        "circuit_breaker_timeout": "circuit-breaker-timeout",
        "circuit_breaker_half_open_requests": "circuit-breaker-half-open-requests",
    }

    def __init__(
        self,
        *,
        endpoint: str,
        model_name: Optional[str] = None,
        model_version: str = "latest",
        timeout: str = "30s",
        flatten_batch_dim: bool = False,
        priority: Optional[int] = None,
        sequence_id: Optional[str] = None,
        sequence_start: bool = False,
        sequence_end: bool = False,
        compression: Optional[Literal["gzip"]] = None,
        auth_token: Optional[str] = None,
        custom_headers: Optional[str] = None,
        max_retries: int = 0,
        retry_initial_backoff: str = "100ms",
        retry_max_backoff: str = "30s",
        default_value: Optional[_TritonDefaultValue] = None,
        health_check_enabled: bool = False,
        health_check_interval: str = "30s",
        circuit_breaker_enabled: bool = False,
        circuit_breaker_failure_threshold: float = 0.5,
        circuit_breaker_timeout: str = "60s",
        circuit_breaker_half_open_requests: int = 3,
        **extra_options: Any,
    ):
        # Argument checks run in a fixed order, so the first failing check
        # decides which error the caller sees.
        if not isinstance(endpoint, str):
            raise TypeError(f"endpoint must be a string, got {type(endpoint)}")
        if compression not in (None, "gzip"):
            raise ValueError(
                f"compression must be 'gzip' when provided, got {compression!r}.")
        if priority is not None and not 0 <= priority <= 255:
            raise ValueError("priority must be in range [0, 255].")
        if max_retries < 0:
            raise ValueError("max_retries must be >= 0.")
        if not 0.0 < circuit_breaker_failure_threshold <= 1.0:
            raise ValueError(
                "circuit_breaker_failure_threshold must be in range (0.0, 1.0].")
        if circuit_breaker_half_open_requests <= 0:
            raise ValueError(
                "circuit_breaker_half_open_requests must be positive.")
        if custom_headers is not None and not isinstance(custom_headers, str):
            raise TypeError(
                f"custom_headers must be a string, got {type(custom_headers)}")
        if custom_headers is not None and any(
                line_break in custom_headers for line_break in ("\n", "\r")):
            raise ValueError("custom_headers cannot contain line breaks.")

        candidates = {
            "endpoint": endpoint,
            "model_name": model_name,
            "model_version": model_version,
            "timeout": timeout,
            "flatten_batch_dim": flatten_batch_dim,
            "priority": priority,
            "sequence_id": sequence_id,
            "sequence_start": sequence_start,
            "sequence_end": sequence_end,
            "compression": compression,
            "auth_token": auth_token,
            "custom_headers": custom_headers,
            "max_retries": max_retries,
            "retry_initial_backoff": retry_initial_backoff,
            "retry_max_backoff": retry_max_backoff,
            # Stored already rendered as the option string (or None).
            "default_value": self._format_default_value(default_value),
            "health_check_enabled": health_check_enabled,
            "health_check_interval": health_check_interval,
            "circuit_breaker_enabled": circuit_breaker_enabled,
            "circuit_breaker_failure_threshold": circuit_breaker_failure_threshold,
            "circuit_breaker_timeout": circuit_breaker_timeout,
            "circuit_breaker_half_open_requests": circuit_breaker_half_open_requests,
        }
        # Parameters left as None are not exported as options.
        self._params: Dict[str, Any] = {
            name: setting
            for name, setting in candidates.items()
            if setting is not None
        }
        self._extra_options = extra_options

    def provider_name(self) -> str:
        """Return the ``triton`` provider identifier."""
        return "triton"

    def model_option_key(self) -> str:
        """Return ``model-name``, the option key for a per-call model name."""
        return "model-name"

    def to_options(self) -> Dict[str, str]:
        """Return the configured options keyed by Java-side option names.

        Named parameters are translated through ``_OPTION_MAP``; extra
        options keep their keys and replace a named value that uses the
        same key. Values are rendered with ``_stringify``.
        """
        named = {
            self._OPTION_MAP[name]: self._stringify(setting)
            for name, setting in self._params.items()
        }
        passthrough = {
            key: self._stringify(setting)
            for key, setting in self._extra_options.items()
        }
        return {**named, **passthrough}

    @staticmethod
    def _format_default_value(
        default_value: Optional[_TritonDefaultValue],
    ) -> Optional[str]:
        """Render ``default_value`` as the string sent to the Java side.

        ``None`` stays ``None``; booleans become ``"true"``/``"false"``;
        lists, tuples and mappings are dumped as JSON (NaN and infinity
        are rejected); anything else goes through ``str``.

        Raises:
            ValueError: If a container cannot be serialized as JSON.
        """
        if default_value is None:
            return None
        # bool is tested before the generic fallback so that True renders
        # as "true" rather than "True".
        if isinstance(default_value, bool):
            return str(default_value).lower()

        if isinstance(default_value, MappingABC):
            # Copy arbitrary mappings into a plain dict for json.dumps.
            payload = dict(default_value)
        elif isinstance(default_value, (list, tuple)):
            payload = default_value
        else:
            return str(default_value)

        try:
            return json.dumps(payload, allow_nan=False)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                "default_value contains a non-JSON-serializable "
                f"element: {exc}") from exc

    @staticmethod
    def _stringify(value: Any) -> str:
        """Convert an option value to a string, lower-casing booleans."""
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)
class GenericProvider(Provider):
    """Catch-all provider for unknown or custom model providers.

    The option keys are handed to the Java side exactly as given; no key
    name translation takes place.

    Args:
        name: Provider identifier recognized by Flink's Java runtime.
        **options: Arbitrary key-value options.

    Example::

        >>> provider = GenericProvider("my-provider", endpoint="https://...",
        ...                            **{"api-key": "sk-..."})
    """

    def __init__(self, name: str, **options: Any):
        self._name = name
        self._options = options

    def provider_name(self) -> str:
        """Return the identifier given at construction time."""
        return self._name

    def to_options(self) -> Dict[str, str]:
        """Return the options with unchanged keys and ``str``-converted values."""
        rendered: Dict[str, str] = {}
        for option_key, option_value in self._options.items():
            rendered[option_key] = str(option_value)
        return rendered