################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
LLM / AI functions for the DataFrame API.
Example::
import pyflink.dataframe as pf
# Option 1: Use a provider
pf.set_provider(pf.OpenAICompatProvider(
endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
api_key="sk-..."))
df = pf.from_dict({"text": ["hello", "world"]})
df = df.llm.predict("text", model="qwen-plus")
# Option 2: Use a catalog model
df = df.llm.predict("text", model="my_catalog_model")
df = df.llm.ai_classify("text", ["pos", "neg"], model="my_catalog_model")
"""
from functools import lru_cache
import uuid
from typing import Dict, List, Mapping, Optional, Union, TYPE_CHECKING
from pyflink.dataframe.ai.providers import Provider, GenericProvider
if TYPE_CHECKING:
from pyflink.dataframe.dataframe import DataFrame
from pyflink.dataframe.datatype import DataType
from pyflink.table import Expression
# ---------------------------------------------------------------------------
# Global provider registry
# ---------------------------------------------------------------------------
# Registered providers keyed by registration name; written by set_provider().
_provider_registry: Dict[str, Provider] = {}
# Name selected via set_default_provider(); stays None until one is chosen.
_default_provider: Optional[str] = None
[docs]
def set_provider(name_or_provider: Union[str, Provider],
                 provider: Optional[Provider] = None,
                 **options) -> None:
    """
    Register a provider configuration in the global registry.

    Three call forms are supported:

    1. ``set_provider(Provider_instance)`` — register under the name
       returned by the instance's ``provider_name()``
       (e.g. ``"openai-compat"``).
    2. ``set_provider("name", Provider_instance)`` — register under a custom
       name. This allows registering multiple instances of the same provider
       type (e.g. one for chat, one for embeddings).
    3. ``set_provider("name", **options)`` — wrap the options in a
       :class:`GenericProvider` registered under ``"name"``.

    Registering under a name that already exists replaces the earlier entry.

    Args:
        name_or_provider: Either a :class:`Provider` instance (form 1) or a
            custom name string (forms 2 and 3).
        provider: A :class:`Provider` instance to register under the custom
            name (form 2 only).
        **options: Provider options (form 3 only, wrapped in
            :class:`GenericProvider`).

    Raises:
        ValueError: If form 1 is combined with a second argument or with
            ``**options``, or if form 2 is combined with ``**options``.
        TypeError: If the first argument is neither a ``str`` nor a
            :class:`Provider`, or if the second argument is given but is
            not a :class:`Provider`.

    Example::

        >>> import pyflink.dataframe as pf
        >>> # Form 1: Provider instance (registered as "openai-compat")
        >>> pf.set_provider(pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-..."))
        >>> # Form 2: Custom name + Provider instance
        >>> pf.set_provider("chat", pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-..."))
        >>> pf.set_provider("embedding", pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings",
        ...     api_key="sk-..."))
        >>> # Form 3: Generic string API
        >>> pf.set_provider(
        ...     "openai-compat",
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode"
        ...              "/v1/chat/completions",
        ...     api_key="sk-...")
    """
    # Form 1: the provider instance supplies its own registration name.
    if isinstance(name_or_provider, Provider):
        if provider is not None:
            raise ValueError(
                "When passing a Provider as the first argument, "
                "the second argument must not be a Provider.")
        if options:
            raise ValueError(
                "When passing a Provider as the first argument, "
                "**options must not be provided.")
        _provider_registry[name_or_provider.provider_name()] = name_or_provider
        return

    if not isinstance(name_or_provider, str):
        raise TypeError(
            f"First argument must be a str or Provider, got "
            f"{type(name_or_provider).__name__}.")

    registration_name = name_or_provider

    # Form 3: no instance given, so the keyword options define the provider.
    if provider is None:
        _provider_registry[registration_name] = GenericProvider(
            registration_name, **options)
        return

    # Form 2: explicit name plus a ready-made provider instance.
    if not isinstance(provider, Provider):
        raise TypeError(
            f"The second argument must be a Provider instance, "
            f"got {type(provider).__name__}.")
    if options:
        raise ValueError(
            "Cannot pass both a Provider instance and **options.")
    _provider_registry[registration_name] = provider
[docs]
def set_default_provider(name: str) -> None:
    """
    Choose which registered provider AI functions use by default.

    The default is consulted whenever a function call does not pass
    ``provider=`` explicitly, which matters once several providers have
    been registered.

    Args:
        name: Provider name (must already be registered via set_provider).

    Raises:
        ValueError: If ``name`` is not present in the provider registry.

    Example::

        >>> pf.set_provider("openai", api_key="sk-...")
        >>> pf.set_provider("deepseek", api_key="dk-...")
        >>> pf.set_default_provider("openai")
    """
    global _default_provider

    is_registered = name in _provider_registry
    if not is_registered:
        raise ValueError(
            f"Provider {name!r} is not registered. "
            "Call pf.set_provider(...) first.")
    _default_provider = name
[docs]
def list_providers() -> List[str]:
    """
    Return the names of all registered providers.

    Returns:
        A new list of provider names, in registry order. Mutating it does
        not affect the registry.

    Example::

        >>> pf.set_provider("openai", api_key="sk-...")
        >>> pf.set_provider("deepseek", api_key="dk-...")
        >>> pf.list_providers()
        ['openai', 'deepseek']
    """
    # Iterating a dict yields its keys, so this is a snapshot of the names.
    return [registered_name for registered_name in _provider_registry]
def _resolve_provider(explicit: Optional[str] = None) -> Optional[str]:
    """Pick the provider name to use, or None when nothing is configured.

    Resolution order:

    1. ``explicit`` when given (returned as-is, without a registry lookup).
    2. The default set through ``set_default_provider``.
    3. The sole registered provider, when exactly one exists.

    Raises:
        ValueError: If ``explicit`` is an empty value, or if several
            providers are registered and no default has been chosen.
    """
    if explicit is not None:
        if explicit:
            return explicit
        raise ValueError("Provider name cannot be empty.")

    if _default_provider:
        return _default_provider

    registered_count = len(_provider_registry)
    if registered_count > 1:
        raise ValueError(
            "Multiple providers registered. Either call "
            "pf.set_default_provider(...) or pass provider='...' explicitly.")
    if registered_count == 1:
        # Exactly one candidate, so there is no ambiguity.
        return next(iter(_provider_registry))
    return None
# ---------------------------------------------------------------------------
# LLMAccessor — accessed via df.llm
# ---------------------------------------------------------------------------
@lru_cache(maxsize=1)
def _ai_function_schemas():
    """Schema specs for AI functions: {func_name: (input_cols, output_cols)}.

    Each col is a ``(name, DataType)`` tuple. The table is built lazily
    (and cached after the first call) because DataTypes requires JVM.
    """
    from pyflink.table import DataTypes

    # These functions share one shape: a STRING "input" column in and a
    # VARIANT "content" column out. Order here is the resulting key order.
    variant_output_functions = (
        "ai_classify",
        "ai_sentiment",
        "ai_extract",
        "ai_translate",
        "ai_summarize",
        "ai_mask",
    )
    schemas = {
        func_name: ([("input", DataTypes.STRING())],
                    [("content", DataTypes.VARIANT())])
        for func_name in variant_output_functions
    }
    # Embeddings are the exception: the output is an array of floats.
    schemas["ai_embed"] = (
        [("input", DataTypes.STRING())],
        [("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))],
    )
    return schemas
def _build_schema(columns):
    """Assemble a Schema from ``(name, DataType)`` tuples, in the given order."""
    from pyflink.table import Schema

    schema_builder = Schema.new_builder()
    for field_name, field_type in columns:
        # column() is called for its effect on the builder; its return
        # value is not used.
        schema_builder.column(field_name, field_type)
    return schema_builder.build()
def _create_model_descriptor(t_env, provider_name, model_name, input_columns,
                             output_columns):
    """Create a ModelDescriptor from a provider registry entry.

    When ``provider_name`` is registered, the registered provider supplies
    the options, the provider identifier and the option key under which the
    model name is stored. When it is not registered, ``provider_name`` itself
    is used as the identifier, no provider options are added, and the model
    name is stored under ``"model"``.

    Args:
        t_env: The TableEnvironment. Used to auto-load the built-in
            model-provider jar before the descriptor is built so
            FactoryUtil can resolve the provider class.
        provider_name: Provider name.
        model_name: Model name. Skipped when empty or None.
        input_columns: List of (name, DataType) tuples for input schema.
        output_columns: List of (name, DataType) tuples for output schema.

    Returns:
        The built ModelDescriptor.
    """
    from pyflink.dataframe.util.artifacts import add_built_in_model
    from pyflink.table import ModelDescriptor

    provider = _provider_registry.get(provider_name)
    # Use an identity check everywhere: a registered provider object that
    # happens to evaluate falsy must still contribute its options and name,
    # consistently with the model-option-key lookup below.
    is_registered = provider is not None

    provider_options = provider.to_options() if is_registered else {}
    java_provider_name = (
        provider.provider_name() if is_registered else provider_name)

    add_built_in_model(t_env, java_provider_name)

    builder = ModelDescriptor.for_provider(java_provider_name)
    for option_key, option_value in provider_options.items():
        builder = builder.option(option_key, option_value)

    if model_name:
        model_option_key = (
            provider.model_option_key() if is_registered else "model")
        builder = builder.option(model_option_key, model_name)

    builder = builder.input_schema(_build_schema(input_columns))
    builder = builder.output_schema(_build_schema(output_columns))
    return builder.build()
def _create_model(t_env, provider_name, model_name, input_columns,
                  output_columns):
    """Create a temporary model and return the Model object.

    The model is registered under a freshly generated ``_pf_llm_<uuid>``
    name and then looked up again through ``t_env.from_model``.

    Args:
        t_env: The TableEnvironment.
        provider_name: Provider name.
        model_name: Model name.
        input_columns: List of (name, DataType) tuples for input schema.
        output_columns: List of (name, DataType) tuples for output schema.

    Returns:
        The Model object returned by ``t_env.from_model``.
    """
    # Use the full 128-bit UUID rather than an 8-character prefix. The
    # create call below passes True as its third argument (presumably
    # ignore-if-exists -- confirm against TableEnvironment), in which case a
    # name collision would silently resolve to a previously created model.
    temp_name = f"_pf_llm_{uuid.uuid4().hex}"
    descriptor = _create_model_descriptor(
        t_env, provider_name, model_name, input_columns, output_columns)
    t_env.create_temporary_model(temp_name, descriptor, True)
    return t_env.from_model(temp_name)
def _resolve_model(t_env, model, provider=None, input_columns=None,
                   output_columns=None):
    """Resolve a Model, either from the provider registry or the catalog.

    If a provider is available (explicitly given or resolved from the
    registry), a temporary model is created for it. Otherwise ``model`` is
    treated as a catalog model name.

    Args:
        t_env: The TableEnvironment.
        model: Model name — either a provider model id (e.g. "qwen-plus")
            or a catalog model name (e.g. "my_catalog_model").
        provider: Explicit provider name, or None to auto-resolve.
        input_columns: List of (name, DataType) tuples for input schema.
        output_columns: List of (name, DataType) tuples for output schema.

    Returns:
        A Model object.

    Raises:
        ValueError: If no provider can be resolved and ``model`` is None.
    """
    resolved_provider = _resolve_provider(provider)

    if resolved_provider is None:
        # No provider anywhere: the only remaining option is a catalog model.
        if model is None:
            raise ValueError(
                "Either configure a provider via pf.set_provider(...) "
                "or pass a catalog model name via model='...'.")
        return t_env.from_model(model)

    return _create_model(
        t_env, resolved_provider, model, input_columns, output_columns)
class LLMAccessor:
    """
    Accessor for LLM / AI functions on a DataFrame.

    Accessed via ``df.llm``. Groups all AI-related operations together.

    Example::

        >>> import pyflink.dataframe as pf
        >>> pf.set_provider(pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-..."))
        >>>
        >>> df = pf.from_dict({"text": ["hello", "world"]})
        >>> df = df.llm.predict("text", model="qwen-plus")
        >>> df = df.llm.ai_classify("text", ["pos", "neg"], model="qwen-plus")
    """

    def __init__(self, df: "DataFrame"):
        # The wrapped DataFrame; every operation reads its ``_table``.
        self._df = df

    @staticmethod
    def _to_col_expr(input_col: Union[str, "Expression"]) -> "Expression":
        """Convert a column name string to an Expression if needed.

        Non-string inputs are returned unchanged.
        """
        if isinstance(input_col, str):
            from pyflink.dataframe.dataframe import col
            return col(input_col)
        return input_col

    def _get_column_types(self, col_names):
        """Look up several column types with a single schema resolution.

        Args:
            col_names: Column names to look up, in the order wanted.

        Returns:
            A list of ``(name, DataType)`` tuples, one per requested name.

        Raises:
            ValueError: If a requested column is not in the schema. The
                first missing name (in ``col_names`` order) is reported.
        """
        if not col_names:
            return []
        # Resolve the schema once for all names instead of once per name.
        columns = self._df._table.get_resolved_schema().get_columns()
        resolved = []
        for col_name in col_names:
            for column in columns:
                if column.get_name() == col_name:
                    resolved.append((col_name, column.get_data_type()))
                    break
            else:
                available = [c.get_name() for c in columns]
                raise ValueError(
                    f"Column {col_name!r} not found. "
                    f"Available columns: {available}")
        return resolved

    def _get_column_type(self, col_name: str):
        """Get the DataType of a column from the DataFrame's schema.

        Raises:
            ValueError: If the column does not exist.
        """
        return self._get_column_types((col_name,))[0][1]

    def predict(self, *input_cols: str, provider: Optional[str] = None,
                model: Optional[str] = None,
                output_type: Optional[
                    Mapping[str, Union[str, "DataType"]]] = None,
                config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Perform prediction using a model.

        This is the general-purpose prediction method.
        When a provider is configured, a temporary model is created with
        the given input/output schema. When using a catalog model
        (no provider), the model's registered schema is used and
        ``output_type`` is ignored.

        Args:
            *input_cols: Column names to use as input.
            provider: Provider name. Uses default if not specified.
                If no provider is configured, ``model`` is treated as
                a catalog model name.
            model: Model name (e.g. "qwen-plus") or catalog model
                name.
            output_type: Output column schema as a dict
                ``{name: type}``. Type can be a SQL type string
                (e.g. ``"STRING"``) or a ``DataType`` object.
                Defaults to ``{"output": "STRING"}``.
                Only used when a provider is configured. Ignored
                for catalog models.
            config: Optional dict of runtime config options.

        Returns:
            A new DataFrame with the model output columns appended.

        Raises:
            ValueError: If an input column does not exist in the DataFrame.

        Example::

            >>> # Single output column (default)
            >>> df.llm.predict("question", model="qwen-plus")
            >>> # JSON structured output
            >>> df.llm.predict("question", model="qwen-plus",
            ...                output_type={"output": "VARIANT"})
            >>> # Multiple output columns
            >>> df.llm.predict("question", model="qwen-plus",
            ...                output_type={"answer": DataType.string(),
            ...                             "score": DataType.float64()})
        """
        from pyflink.dataframe.dataframe import DataFrame as DF
        from pyflink.dataframe.datatype import DataType

        if output_type is None:
            output_type = {"output": "STRING"}

        # Build output columns from output_type; SQL type strings are
        # parsed into DataType objects first.
        output_columns = []
        for name, dtype in output_type.items():
            if isinstance(dtype, str):
                dtype = DataType._from_sql(dtype)
            output_columns.append((name, dtype._table_type))

        # Build input schema from actual column types (one schema lookup).
        input_columns = self._get_column_types(input_cols)

        flink_model = _resolve_model(
            self._df._table._t_env, model, provider,
            input_columns, output_columns)
        result = flink_model.predict(
            self._df._table, list(input_cols), config)
        return DF(result)

    def _ai_func(self, func_name: str, input_col, model, provider, config,
                 ai_func, *extra_args):
        """Common logic for ai_* functions.

        Looks up the schema for ``func_name``, resolves the model, builds
        the expression via ``ai_func(model, column, *extra_args, config)``
        and lateral-joins it onto the DataFrame's table.
        """
        from pyflink.dataframe.dataframe import DataFrame as DF

        input_cols, output_cols = _ai_function_schemas()[func_name]
        flink_model = _resolve_model(
            self._df._table._t_env, model, provider,
            input_cols, output_cols)
        expr = ai_func(
            flink_model, self._to_col_expr(input_col), *extra_args, config)
        return DF(self._df._table.join_lateral(expr))

    def ai_classify(self, input_col: Union[str, "Expression"],
                    labels: List[str], *,
                    provider: Optional[str] = None,
                    model: Optional[str] = None,
                    config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Classify text into one of the provided labels.

        Args:
            input_col: Column name (str) or Expression for the input text.
            labels: List of label strings.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``category`` (STRING): the predicted label.
            - ``confidence`` (DOUBLE): confidence score.

        Example::

            >>> df.llm.ai_classify("text", ["positive", "negative"],
            ...                    model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_classify as _ai_classify
        return self._ai_func("ai_classify", input_col, model, provider,
                             config, _ai_classify, labels)

    def ai_sentiment(self, input_col: Union[str, "Expression"], *,
                     provider: Optional[str] = None,
                     model: Optional[str] = None,
                     config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Analyze the sentiment of input text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``score`` (DOUBLE): sentiment score from -1.0 to 1.0.
            - ``label`` (STRING): one of "positive", "negative", "neutral".
            - ``confidence`` (DOUBLE): confidence score.

        Example::

            >>> df.llm.ai_sentiment("review", model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_sentiment as _ai_sentiment
        return self._ai_func("ai_sentiment", input_col, model, provider,
                             config, _ai_sentiment)

    def ai_translate(self, input_col: Union[str, "Expression"],
                     source_lang: str, target_lang: str, *,
                     provider: Optional[str] = None,
                     model: Optional[str] = None,
                     config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Translate text from one language to another.

        Args:
            input_col: Column name (str) or Expression for the input text.
            source_lang: Source language code (e.g. ``"zh"``, ``"en"``,
                ``"auto"`` for auto-detection). Supported: ``auto``, ``zh``,
                ``en``, ``ja``, ``ko``, ``fr``, ``de``, ``es``, ``ru``,
                ``ar``, ``pt``.
            target_lang: Target language code (cannot be ``"auto"``).
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``translated_text`` (STRING): the translated text.
            - ``detected_language`` (STRING): detected source language code.

        Example::

            >>> df.llm.ai_translate("text", "zh", "en", model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_translate as _ai_translate
        return self._ai_func("ai_translate", input_col, model, provider,
                             config, _ai_translate, source_lang, target_lang)

    def ai_summarize(self, input_col: Union[str, "Expression"],
                     max_length: int, *, provider: Optional[str] = None,
                     model: Optional[str] = None,
                     config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Summarize text to a maximum length.

        Args:
            input_col: Column name (str) or Expression for the input text.
            max_length: Maximum length of the summary in characters.
                Must be greater than 0.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with a column appended:

            - ``summary`` (STRING): the summarized text.

        Example::

            >>> df.llm.ai_summarize("article", 200, model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_summarize as _ai_summarize
        return self._ai_func("ai_summarize", input_col, model, provider,
                             config, _ai_summarize, max_length)

    def ai_mask(self, input_col: Union[str, "Expression"],
                entities: List[str], *, provider: Optional[str] = None,
                model: Optional[str] = None,
                config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Mask sensitive information in text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            entities: List of entity types to mask
                (e.g. ``["name", "phone"]``).
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``masked_text`` (STRING): text with sensitive info replaced.
            - ``detected_entities`` (ARRAY<STRING>): list of detected entity
              types.

        Example::

            >>> df.llm.ai_mask("text", ["name", "phone"], model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_mask as _ai_mask
        return self._ai_func("ai_mask", input_col, model, provider,
                             config, _ai_mask, entities)

    def ai_embed(self, input_col: Union[str, "Expression"],
                 dimension: int = 1024, *, provider: Optional[str] = None,
                 model: Optional[str] = None,
                 config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Generate embedding vectors for text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            dimension: Dimension of the embedding vector (default 1024).
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with a column appended:

            - ``embedding`` (ARRAY<FLOAT>): the embedding vector.

        Example::

            >>> df.llm.ai_embed("text", 512, model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_embed as _ai_embed
        return self._ai_func("ai_embed", input_col, model, provider,
                             config, _ai_embed, dimension)