pyflink.dataframe.ai.providers.TritonProvider
- class TritonProvider(*, endpoint: str, model_name: str | None = None, model_version: str = 'latest', timeout: str = '30s', flatten_batch_dim: bool = False, priority: int | None = None, sequence_id: str | None = None, sequence_start: bool = False, sequence_end: bool = False, compression: Literal['gzip'] | None = None, auth_token: str | None = None, custom_headers: str | None = None, max_retries: int = 0, retry_initial_backoff: str = '100ms', retry_max_backoff: str = '30s', default_value: str | int | float | bool | List[Any] | Tuple[Any, ...] | Mapping[str, Any] | None = None, health_check_enabled: bool = False, health_check_interval: str = '30s', circuit_breaker_enabled: bool = False, circuit_breaker_failure_threshold: float = 0.5, circuit_breaker_timeout: str = '60s', circuit_breaker_half_open_requests: int = 3, **extra_options: Any)
Provider for NVIDIA Triton Inference Server (triton).
- Parameters:
endpoint – Full URL of the Triton Inference Server endpoint.
model_name – Name of the model to invoke. This can also be provided through df.llm.predict(..., model="..."), which maps to the Java-side model-name option for Triton.
model_version – Version of the model to use. Defaults to "latest".
timeout – HTTP request timeout, for example "10s" or "30000ms". Defaults to "30s".
flatten_batch_dim – Whether to flatten the leading batch dimension for array inputs. For ARRAY<T> inputs, the default shape is [1, N], where N is the array length. Set this to True when the Triton model expects [N] instead. Defaults to False.
priority – Triton request priority level.
sequence_id – Sequence ID for stateful models.
sequence_start – Whether this request starts a stateful sequence. Defaults to False.
sequence_end – Whether this request ends a stateful sequence. Defaults to False.
compression – Compression algorithm for the request body. Currently, the Triton provider supports only "gzip".
auth_token – Authentication token for secured Triton servers. The Java provider sends it as a Bearer token.
custom_headers – Custom HTTP headers as a Flink map string of comma-separated key:value pairs (e.g. "X-Trace-Id:abc,X-Other:val").
max_retries – Maximum number of retries for failed inference requests. Defaults to 0.
retry_initial_backoff – Initial backoff duration between retry attempts. Defaults to "100ms".
retry_max_backoff – Upper bound on the delay between retry attempts. Defaults to "30s".
default_value – Fallback value to return when inference fails after retries or with a non-retryable error:
  - If not specified, inference failures are propagated as exceptions.
  - For STRING outputs, pass plain text such as "FAILED".
  - For numeric outputs, pass the numeric value or its string representation, such as -1 or "-1".
  - For ARRAY or structured outputs, pass a JSON string or the corresponding Python list, tuple, or mapping; Python containers are serialized as JSON.
  - To emit SQL NULL, pass the lower-case literal "null". For string outputs, "null" is therefore not usable as a literal string sentinel; use values such as "NULL", "FAILED", or "<null>" instead.
health_check_enabled – Whether to enable periodic health checks for the Triton server. Defaults to False.
health_check_interval – Interval between health check requests. Defaults to "30s".
circuit_breaker_enabled – Whether to enable circuit breaker protection for Triton inference requests. Defaults to False.
circuit_breaker_failure_threshold – Failure rate threshold that opens the circuit breaker. Must be in (0.0, 1.0]. Defaults to 0.5.
circuit_breaker_timeout – Duration to keep the circuit breaker open before probing for recovery. Defaults to "60s".
circuit_breaker_half_open_requests – Number of successful half-open probe requests required to close the circuit. Defaults to 3.
**extra_options – Additional options passed through as-is.
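The custom_headers option takes a single string, not a Python dict. The following is a minimal sketch of building that string from a dict and parsing it back. to_flink_map_string and parse_flink_map_string are hypothetical helpers and are not part of pyflink. The sketch does not model any quoting or escaping rules Flink may apply, so it rejects keys or values that contain ',' or ':'.

```python
from typing import Dict, Mapping


def to_flink_map_string(headers: Mapping[str, str]) -> str:
    """Join headers into the comma-separated key:value form used by custom_headers."""
    for key, value in headers.items():
        # No quoting is modeled, so separators inside keys or values are rejected.
        if any(sep in part for sep in (",", ":") for part in (key, value)):
            raise ValueError(f"header {key!r} contains ',' or ':'")
    return ",".join(f"{key}:{value}" for key, value in headers.items())


def parse_flink_map_string(text: str) -> Dict[str, str]:
    """Inverse of to_flink_map_string: split pairs on ',' and each pair on ':'."""
    headers: Dict[str, str] = {}
    for pair in text.split(","):
        key, sep, value = pair.partition(":")
        if not sep:
            raise ValueError(f"missing ':' in {pair!r}")
        headers[key.strip()] = value.strip()
    return headers


header_option = to_flink_map_string({"X-Trace-Id": "abc", "X-Other": "val"})
print(header_option)  # X-Trace-Id:abc,X-Other:val
```

The resulting string is the value you would pass as custom_headers=header_option.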
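The three retry parameters work together. max_retries limits how many times a request is retried, and retry_initial_backoff and retry_max_backoff bound the wait between attempts. This page does not say how the delay grows between those bounds. The sketch below therefore assumes it doubles after each retry, with no jitter. parse_duration_ms and backoff_schedule_ms are hypothetical helpers, and the duration parser accepts only a small subset of units.

```python
import re
from typing import List

# Subset of duration units; Flink accepts more spellings than this sketch does.
_UNIT_MS = {"ms": 1, "s": 1_000, "min": 60_000, "h": 3_600_000}


def parse_duration_ms(text: str) -> int:
    """Parse durations such as "100ms", "10s", or "30000ms" into milliseconds."""
    match = re.fullmatch(r"\s*(\d+)\s*(ms|s|min|h)\s*", text)
    if match is None:
        raise ValueError(f"unsupported duration: {text!r}")
    amount, unit = match.groups()
    return int(amount) * _UNIT_MS[unit]


def backoff_schedule_ms(max_retries: int, initial: str = "100ms",
                        maximum: str = "30s") -> List[int]:
    """Delay before each retry, ASSUMING exponential doubling capped at `maximum`."""
    delay, cap = parse_duration_ms(initial), parse_duration_ms(maximum)
    schedule = []
    for _ in range(max_retries):
        schedule.append(min(delay, cap))
        delay *= 2  # assumed growth factor; the docs only give the start and the cap
    return schedule


print(backoff_schedule_ms(10))
# [100, 200, 400, 800, 1600, 3200, 6400, 12800, 25600, 30000]
```

With the default max_retries=0, the schedule is empty and a failed request is not retried.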
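The default_value rules above can be summarized as a small conversion function. encode_default_value is a hypothetical helper that approximates how a Python value might become the option string. The description above confirms that containers are serialized as JSON. How bools and plain scalars are converted is an assumption.

```python
import json
from collections.abc import Mapping
from typing import Any, Optional


def encode_default_value(value: Any) -> Optional[str]:
    """Approximate the default_value rules described above (hypothetical helper)."""
    if value is None:
        return None  # unset: inference failures propagate as exceptions
    if isinstance(value, bool):
        return "true" if value else "false"  # assumption: JSON-style booleans
    if isinstance(value, (list, tuple, Mapping)):
        return json.dumps(value)  # containers are serialized as JSON
    # Strings pass through unchanged, so "null" still means SQL NULL;
    # numbers use their string form, e.g. -1 -> "-1".
    return str(value)


for fallback in ["FAILED", -1, [0.0, 0.0], {"class_id": -1}, "null"]:
    print(repr(encode_default_value(fallback)))
```

A tuple such as (1, 2) is encoded the same way as the list [1, 2], because JSON has no tuple type.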
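The circuit-breaker parameters follow the usual closed, open, and half-open pattern. The sketch below is a generic Python illustration of that pattern, not the Java provider's implementation. Three constructor arguments mirror provider options: failure_threshold matches circuit_breaker_failure_threshold, open_timeout_s matches circuit_breaker_timeout, and half_open_requests matches circuit_breaker_half_open_requests. This page does not say how the failure rate is measured. The sliding window of recent outcomes (window_size), the >= comparison, and reopening on any half-open failure are therefore all assumptions.

```python
import time
from collections import deque
from typing import Callable, Deque


class CircuitBreaker:
    """Minimal closed/open/half-open breaker (generic sketch, hypothetical class)."""

    def __init__(self, failure_threshold: float = 0.5, open_timeout_s: float = 60.0,
                 half_open_requests: int = 3, window_size: int = 10,
                 clock: Callable[[], float] = time.monotonic) -> None:
        if not 0.0 < failure_threshold <= 1.0:
            raise ValueError("failure_threshold must be in (0.0, 1.0]")
        self.failure_threshold = failure_threshold
        self.open_timeout_s = open_timeout_s
        self.half_open_requests = half_open_requests
        self.clock = clock
        # Assumed measurement: failure rate over the last `window_size` calls.
        self.outcomes: Deque[bool] = deque(maxlen=window_size)  # True = failure
        self.state = "closed"
        self.opened_at = 0.0
        self.probe_successes = 0

    def allow_request(self) -> bool:
        # After the open timeout elapses, let probe requests through (half-open).
        if self.state == "open" and self.clock() - self.opened_at >= self.open_timeout_s:
            self.state, self.probe_successes = "half_open", 0
        return self.state != "open"

    def record(self, failed: bool) -> None:
        if self.state == "open":
            return  # requests are rejected while open, so nothing to record
        if self.state == "half_open":
            if failed:
                self._open()  # assumption: any probe failure reopens the circuit
            else:
                self.probe_successes += 1
                if self.probe_successes >= self.half_open_requests:
                    self.state = "closed"
                    self.outcomes.clear()
            return
        self.outcomes.append(failed)
        if len(self.outcomes) == self.outcomes.maxlen:
            if sum(self.outcomes) / len(self.outcomes) >= self.failure_threshold:
                self._open()

    def _open(self) -> None:
        self.state, self.opened_at = "open", self.clock()


now = [0.0]  # fake clock so the example runs instantly
breaker = CircuitBreaker(window_size=4, clock=lambda: now[0])
for failed in (False, True, False, True):  # 2 of 4 fail: rate 0.5 >= 0.5
    if breaker.allow_request():
        breaker.record(failed)
print(breaker.state)  # open
now[0] = 60.0  # the default "60s" circuit_breaker_timeout has elapsed
print(breaker.allow_request(), breaker.state)  # True half_open
for _ in range(3):
    breaker.record(False)
print(breaker.state)  # closed
```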
Examples:
>>> import pyflink.dataframe as pf
>>> from pyflink.dataframe.ai.providers import TritonProvider
>>>
>>> # Classifier with ARRAY<FLOAT> features and BIGINT class output.
>>> provider = TritonProvider(
...     endpoint="<Your Triton endpoint>",
...     auth_token="<Your authentication token>",
...     model_name="classifier",
...     compression="gzip",
... )
>>> pf.set_provider(provider)
>>> df = pf.from_records(
...     [([5.1, 3.5, 1.4, 0.2],), ...],
...     schema=["features"])
>>> result = df.llm.predict(
...     "features",
...     output_type={"class_id": "BIGINT"})
>>>
>>> # Stateful conversation model with a fixed Triton sequence.
>>> provider = TritonProvider(
...     endpoint="<Your Triton endpoint>",
...     auth_token="<Your authentication token>",
...     model_name="chatbot_lstm",
...     sequence_id="conv-001",
...     sequence_start=True,
...     sequence_end=False,
... )
>>> pf.set_provider(provider)
>>> chat_messages = pf.from_records(
...     [("hello",), ...],
...     schema=["message_text"])
>>> result = chat_messages.llm.predict(
...     "message_text",
...     output_type={"bot_response": "STRING"})
>>>
>>> # Vector transform model with ARRAY<FLOAT> input and output.
>>> provider = TritonProvider(
...     endpoint="<Your Triton endpoint>",
...     auth_token="<Your authentication token>",
...     model_name="vector-transform",
...     flatten_batch_dim=True,  # the Triton model expects one-dimensional [N] input
... )
>>> pf.set_provider(provider)
>>> vector_input = pf.from_records(
...     [([0.1, 0.2, 0.3, ...],), ...],
...     schema=["features"])
>>> result = vector_input.llm.predict(
...     "features",
...     output_type={"output_vector": "ARRAY<FLOAT>"})
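The stateful example sets sequence_id, sequence_start, and sequence_end once, when the provider is constructed. For comparison, Triton's sequence batcher expects the following:

- every request in a sequence carries the same sequence ID;
- the start flag is set only on the first request;
- the end flag is set only on the last request.

sequence_flags is a hypothetical helper that lists those values for an n-step sequence. It illustrates Triton's convention only and does not describe how TritonProvider itself schedules requests.

```python
from typing import Dict, List, Union

Flags = Dict[str, Union[str, bool]]


def sequence_flags(sequence_id: str, num_requests: int) -> List[Flags]:
    """Sequence parameters for each request of one Triton sequence, in order."""
    if num_requests < 1:
        raise ValueError("a sequence needs at least one request")
    return [
        {
            "sequence_id": sequence_id,
            "sequence_start": step == 0,               # only the first request opens it
            "sequence_end": step == num_requests - 1,  # only the last request closes it
        }
        for step in range(num_requests)
    ]


for flags in sequence_flags("conv-001", 3):
    print(flags)
```

A one-request sequence has both flags set on that single request.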
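flatten_batch_dim changes only the shape declared for an ARRAY<T> input. This page does not confirm the wire format. Assuming requests use Triton's KServe v2 JSON format over HTTP, the sketch below shows the input tensor that the single row [0.1, 0.2, 0.3] would produce with and without flattening. v2_input_tensor, the tensor name "features", and the FP32 datatype for ARRAY<FLOAT> are all illustrative assumptions.

```python
from typing import Any, Dict, Sequence


def v2_input_tensor(name: str, values: Sequence[float],
                    flatten_batch_dim: bool = False) -> Dict[str, Any]:
    """One KServe v2-style input tensor for a single ARRAY<FLOAT> row (sketch)."""
    n = len(values)
    shape = [n] if flatten_batch_dim else [1, n]  # [1, N] by default, [N] when flattened
    return {
        "name": name,
        "shape": shape,
        "datatype": "FP32",
        "data": list(values),  # row-major data is the same in both cases
    }


row = [0.1, 0.2, 0.3]
print(v2_input_tensor("features", row)["shape"])                          # [1, 3]
print(v2_input_tensor("features", row, flatten_batch_dim=True)["shape"])  # [3]
```

In Triton terms, the two settings typically match two kinds of model configuration. A model with max_batch_size greater than 0 expects the leading batch dimension. A model with max_batch_size of 0 and dims [N] expects the flattened [N] shape.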
Methods
model_option_key() – Return the Java-side option key used for a per-call model name.
provider_name() – Return the provider identifier recognized by Flink's Java runtime.
to_options() – Return all configured options as a dict with Java-side key names.