PyFlink 1.20+vvr.11.7.dev0 documentation

pyflink.dataframe.ai.providers.TritonProvider

class TritonProvider(*, endpoint: str, model_name: str | None = None, model_version: str = 'latest', timeout: str = '30s', flatten_batch_dim: bool = False, priority: int | None = None, sequence_id: str | None = None, sequence_start: bool = False, sequence_end: bool = False, compression: Literal['gzip'] | None = None, auth_token: str | None = None, custom_headers: str | None = None, max_retries: int = 0, retry_initial_backoff: str = '100ms', retry_max_backoff: str = '30s', default_value: str | int | float | bool | List[Any] | Tuple[Any, ...] | Mapping[str, Any] | None = None, health_check_enabled: bool = False, health_check_interval: str = '30s', circuit_breaker_enabled: bool = False, circuit_breaker_failure_threshold: float = 0.5, circuit_breaker_timeout: str = '60s', circuit_breaker_half_open_requests: int = 3, **extra_options: Any)

Provider for NVIDIA Triton Inference Server (triton).

Parameters:
  • endpoint – Full URL of the Triton Inference Server endpoint.

  • model_name – Name of the model to invoke. This can also be provided through df.llm.predict(..., model="..."), which maps to the Java-side model-name option for Triton.

  • model_version – Version of the model to use. Defaults to "latest".

  • timeout – HTTP request timeout, for example "10s" or "30000ms". Defaults to "30s".

  • flatten_batch_dim – Whether to flatten the leading batch dimension for array inputs. For ARRAY<T> inputs, the default shape is [1, N], where N is the array length. Set this to True when the Triton model expects [N] instead. Defaults to False.

  • priority – Triton request priority level.

  • sequence_id – Sequence ID for stateful models.

  • sequence_start – Whether this request starts a stateful sequence. Defaults to False.

  • sequence_end – Whether this request ends a stateful sequence. Defaults to False.

  • compression – Compression algorithm for the request body. Currently, the Triton provider supports only "gzip". Defaults to None (no compression).

  • auth_token – Authentication token for secured Triton servers. The Java provider sends it as a Bearer token.

  • custom_headers – Custom HTTP headers as a Flink map string with comma-separated key:value pairs (e.g. "X-Trace-Id:abc,X-Other:val").

  • max_retries – Maximum number of retries for failed inference requests. Defaults to 0.

  • retry_initial_backoff – Initial backoff duration between retry attempts. Defaults to "100ms".

  • retry_max_backoff – Upper bound on the delay between retry attempts. Defaults to "30s".

  • default_value – Fallback value to return when inference fails after retries or with a non-retryable error:

    • If not specified, inference failures are propagated as exceptions.

    • For STRING outputs, pass plain text such as "FAILED".

    • For numeric outputs, pass the numeric value or its string representation, such as -1 or "-1".

    • For ARRAY or structured outputs, pass a JSON string or the corresponding Python list, tuple, or mapping; Python containers are serialized as JSON.

    • To emit SQL NULL, pass the lower-case literal "null". For string outputs, "null" is therefore not usable as a literal string sentinel; use values such as "NULL", "FAILED", or "<null>" instead.

  • health_check_enabled – Whether to enable periodic health checks for the Triton server. Defaults to False.

  • health_check_interval – Interval between health check requests. Defaults to "30s".

  • circuit_breaker_enabled – Whether to enable circuit breaker protection for Triton inference requests. Defaults to False.

  • circuit_breaker_failure_threshold – Failure rate threshold that opens the circuit breaker. Must be in (0.0, 1.0]. Defaults to 0.5.

  • circuit_breaker_timeout – Duration to keep the circuit breaker open before probing recovery. Defaults to "60s".

  • circuit_breaker_half_open_requests – Number of successful half-open probe requests required to close the circuit. Defaults to 3.

  • **extra_options – Additional options passed through as-is.
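
The string formats described above for custom_headers and default_value can be previewed with plain Python. The following is a minimal sketch, not part of PyFlink: format_custom_headers and encode_default_value are hypothetical helper names. The "key:value,key:value" layout and the "Python containers are serialized as JSON" rule come from the parameter descriptions. The rejection of "," and ":" inside header names and values is a conservative assumption, because no escaping rules are documented here. The exact JSON spacing the provider produces is also an assumption.

```python
import json
from collections.abc import Mapping


def format_custom_headers(headers: Mapping) -> str:
    """Join headers into the "key:value,key:value" form described for custom_headers.

    Hypothetical helper, not a PyFlink API.
    """
    for key, value in headers.items():
        # Assumption: no escaping is documented, so refuse the separator characters.
        for text in (key, value):
            if "," in text or ":" in text:
                raise ValueError(f"separator character in header entry {key!r}")
    return ",".join(f"{key}:{value}" for key, value in headers.items())


def encode_default_value(value):
    """Preview how a default_value might look once containers are serialized as JSON.

    Hypothetical helper: lists, tuples, and dicts become JSON strings,
    while scalars (str, int, float, bool) pass through unchanged.
    """
    if isinstance(value, (list, tuple, dict)):
        return json.dumps(value)
    return value


headers = format_custom_headers({"X-Trace-Id": "abc", "X-Other": "val"})
fallback = encode_default_value({"class_id": -1})
```

TritonProvider accepts Python containers for default_value directly, so a helper like encode_default_value is only useful for checking what the JSON form of a fallback looks like before a job is submitted.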

Examples:

>>> import pyflink.dataframe as pf
>>> from pyflink.dataframe.ai.providers import TritonProvider
>>>
>>> # Classifier with ARRAY<FLOAT> features and BIGINT class output.
>>> provider = TritonProvider(
...     endpoint="<Your Triton endpoint>",
...     auth_token="<Your authentication token>",
...     model_name="classifier",
...     compression="gzip",
... )
>>> pf.set_provider(provider)
>>> df = pf.from_records(
...     [([5.1, 3.5, 1.4, 0.2],), ...],
...     schema=["features"])
>>> result = df.llm.predict(
...     "features",
...     output_type={"class_id": "BIGINT"})
>>>
>>> # Stateful conversation model with a fixed Triton sequence.
>>> provider = TritonProvider(
...     endpoint="<Your Triton endpoint>",
...     auth_token="<Your authentication token>",
...     model_name="chatbot_lstm",
...     sequence_id="conv-001",
...     sequence_start=True,
...     sequence_end=False,
... )
>>> pf.set_provider(provider)
>>> chat_messages = pf.from_records(
...     [("hello",), ...],
...     schema=["message_text"])
>>> result = chat_messages.llm.predict(
...     "message_text",
...     output_type={"bot_response": "STRING"})
>>>
>>> # Vector transform model with ARRAY<FLOAT> input and output.
>>> provider = TritonProvider(
...     endpoint="<Your Triton endpoint>",
...     auth_token="<Your authentication token>",
...     model_name="vector-transform",
...     flatten_batch_dim=True,  # Use when the Triton model expects one-dimensional [N] input
... )
>>> pf.set_provider(provider)
>>> vector_input = pf.from_records(
...     [([0.1, 0.2, 0.3, ...],), ...],
...     schema=["features"])
>>> result = vector_input.llm.predict(
...     "features",
...     output_type={"output_vector": "ARRAY<FLOAT>"})
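
The effect of flatten_batch_dim on the tensor shape can be illustrated without a Triton server. The sketch below is not the provider's implementation: triton_input is a hypothetical function, and the (shape, data) pair it returns is only an illustration, not the actual request payload. Only the two shapes, [1, N] by default and [N] when flattened, are taken from the parameter description.

```python
from typing import Any, List, Sequence, Tuple


def triton_input(values: Sequence[Any], flatten_batch_dim: bool = False) -> Tuple[List[int], list]:
    """Return an illustrative (shape, data) pair for one ARRAY<T> value.

    Hypothetical helper mirroring the documented shapes only.
    """
    n = len(values)
    if flatten_batch_dim:
        # Model expects a one-dimensional tensor: shape [N].
        return [n], list(values)
    # Default: keep a leading batch dimension of 1, giving shape [1, N].
    return [1, n], [list(values)]


features = [5.1, 3.5, 1.4, 0.2]
default_shape, default_data = triton_input(features)
flat_shape, flat_data = triton_input(features, flatten_batch_dim=True)
```

Here default_shape is [1, 4] and flat_shape is [4]. A model whose configuration declares a one-dimensional input would need the second form.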

Methods

model_option_key()

Return the Java-side option key used for a per-call model name.

provider_name()

Return the provider identifier recognized by Flink's Java runtime.

to_options()

Return all configured options as a dict with Java-side key names.

Created using Sphinx 7.4.7.

Built with the PyData Sphinx Theme 0.16.1.