# Source code for pyflink.dataframe.ai.llm

################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

"""
LLM / AI functions for the DataFrame API.

Example::

    import pyflink.dataframe as pf

    # Option 1: Use a provider
    pf.set_provider(pf.OpenAICompatProvider(
        endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        api_key="sk-..."))

    df = pf.from_dict({"text": ["hello", "world"]})
    df = df.llm.predict("text", model="qwen-plus")

    # Option 2: Use a catalog model
    df = df.llm.predict("text", model="my_catalog_model")
    df = df.llm.ai_classify("text", ["pos", "neg"], model="my_catalog_model")
"""

from functools import lru_cache
import uuid
from typing import Dict, List, Mapping, Optional, Union, TYPE_CHECKING

from pyflink.dataframe.ai.providers import Provider, GenericProvider

if TYPE_CHECKING:
    from pyflink.dataframe.dataframe import DataFrame
    from pyflink.dataframe.datatype import DataType
    from pyflink.table import Expression

# ---------------------------------------------------------------------------
# Global provider registry
# ---------------------------------------------------------------------------

# Registered providers keyed by registration name; written by set_provider().
_provider_registry: Dict[str, Provider] = {}
# Name of the provider used when a call does not pass ``provider=``; written by
# set_default_provider(). ``None`` means no default has been chosen.
_default_provider: Optional[str] = None


def set_provider(name_or_provider: Union[str, Provider],
                 provider: Optional[Provider] = None,
                 **options) -> None:
    """
    Register a global provider configuration.

    Can be called in three ways:

    1. ``set_provider(Provider_instance)`` — register under the provider's
       default name (e.g. ``"openai-compat"``).
    2. ``set_provider("name", Provider_instance)`` — register under a custom
       name. This allows registering multiple instances of the same provider
       type (e.g. one for chat, one for embeddings).
    3. ``set_provider("name", **options)`` — create a
       :class:`GenericProvider` with the given options.

    Registering under a name that is already present replaces the previous
    entry.

    Args:
        name_or_provider: Either a :class:`Provider` instance (form 1) or a
            custom name string (forms 2 and 3). A name must be non-empty.
        provider: A :class:`Provider` instance to register under the custom
            name (form 2 only).
        **options: Provider options (form 3 only, wraps in
            :class:`GenericProvider`).

    Raises:
        TypeError: If the first argument is neither a ``str`` nor a
            :class:`Provider`, or if ``provider`` is given but is not a
            :class:`Provider`.
        ValueError: If the name is empty, or if the arguments mix the call
            forms (a Provider first argument with a second argument or
            options; a Provider second argument with options).

    Example::

        >>> import pyflink.dataframe as pf
        >>> # Form 1: Provider instance (registered as "openai-compat")
        >>> pf.set_provider(pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-..."))

        >>> # Form 2: Custom name + Provider instance
        >>> pf.set_provider("chat", pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-..."))
        >>> pf.set_provider("embedding", pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings",
        ...     api_key="sk-..."))

        >>> # Form 3: Generic string API
        >>> pf.set_provider(
        ...     "openai-compat",
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode"
        ...              "/v1/chat/completions",
        ...     api_key="sk-...")
    """
    # Form 1: the provider supplies its own registration name.
    if isinstance(name_or_provider, Provider):
        if provider is not None:
            raise ValueError(
                "When passing a Provider as the first argument, "
                "the second argument must not be a Provider.")
        if options:
            raise ValueError(
                "When passing a Provider as the first argument, "
                "**options must not be provided.")
        _provider_registry[name_or_provider.provider_name()] = name_or_provider
        return

    if not isinstance(name_or_provider, str):
        raise TypeError(
            f"First argument must be a str or Provider, got "
            f"{type(name_or_provider).__name__}.")

    # An empty name could never be selected explicitly (provider resolution
    # rejects empty names) and would be ignored as a default because the
    # default is tested for truthiness, so refuse it at registration time.
    if not name_or_provider:
        raise ValueError("Provider name cannot be empty.")

    # Form 3: no Provider instance, wrap the options in a GenericProvider.
    if provider is None:
        _provider_registry[name_or_provider] = GenericProvider(
            name_or_provider, **options)
        return

    # Form 2: custom name + Provider instance.
    if not isinstance(provider, Provider):
        raise TypeError(
            f"The second argument must be a Provider instance, "
            f"got {type(provider).__name__}.")
    if options:
        raise ValueError(
            "Cannot pass both a Provider instance and **options.")
    _provider_registry[name_or_provider] = provider
def set_default_provider(name: str) -> None:
    """
    Set the default provider for AI functions.

    When multiple providers are registered, this determines which one is
    used when ``provider=`` is not specified in a function call.

    Args:
        name: Provider name (must already be registered via set_provider).

    Raises:
        ValueError: If ``name`` has not been registered.

    Example::

        >>> pf.set_provider("openai", api_key="sk-...")
        >>> pf.set_provider("deepseek", api_key="dk-...")
        >>> pf.set_default_provider("openai")
    """
    global _default_provider

    if name in _provider_registry:
        _default_provider = name
        return

    raise ValueError(
        f"Provider {name!r} is not registered. "
        "Call pf.set_provider(...) first.")
def list_providers() -> List[str]:
    """
    List all registered provider names.

    Returns:
        A new list holding the registered provider names, in registration
        order.

    Example::

        >>> pf.set_provider("openai", api_key="sk-...")
        >>> pf.set_provider("deepseek", api_key="dk-...")
        >>> pf.list_providers()
        ['openai', 'deepseek']
    """
    # Iterating a dict yields its keys; copy them so callers cannot mutate
    # the registry through the returned list.
    return list(_provider_registry)
def _resolve_provider(explicit: Optional[str] = None) -> Optional[str]:
    """Resolve which provider to use.

    Resolution order: the explicit name, then the configured default, then
    the only registered provider if exactly one is registered.

    Args:
        explicit: Provider name passed by the caller, or None to auto-resolve.

    Returns:
        The provider name, or None if no provider is available.

    Raises:
        ValueError: If ``explicit`` is an empty string, or if several
            providers are registered and no default has been set.
    """
    if explicit is not None:
        if not explicit:
            raise ValueError("Provider name cannot be empty.")
        return explicit
    if _default_provider:
        return _default_provider
    if not _provider_registry:
        return None
    if len(_provider_registry) == 1:
        return next(iter(_provider_registry))
    raise ValueError(
        "Multiple providers registered. Either call "
        "pf.set_default_provider(...) or pass provider='...' explicitly.")


# ---------------------------------------------------------------------------
# LLMAccessor — accessed via df.llm
# ---------------------------------------------------------------------------

@lru_cache(maxsize=1)
def _ai_function_schemas():
    """Schema specs for AI functions: {func_name: (input_cols, output_cols)}.

    Each col is (name, DataType). Built lazily because DataTypes requires JVM.
    """
    from pyflink.table import DataTypes

    # NOTE(review): the ai_* method docstrings describe output columns such as
    # ``category`` / ``confidence``, while the schemas declared here expose a
    # single ``content`` VARIANT column — confirm which is authoritative.
    return {
        "ai_classify": ([("input", DataTypes.STRING())],
                        [("content", DataTypes.VARIANT())]),
        "ai_sentiment": ([("input", DataTypes.STRING())],
                         [("content", DataTypes.VARIANT())]),
        "ai_extract": ([("input", DataTypes.STRING())],
                       [("content", DataTypes.VARIANT())]),
        "ai_translate": ([("input", DataTypes.STRING())],
                         [("content", DataTypes.VARIANT())]),
        "ai_summarize": ([("input", DataTypes.STRING())],
                         [("content", DataTypes.VARIANT())]),
        "ai_mask": ([("input", DataTypes.STRING())],
                    [("content", DataTypes.VARIANT())]),
        "ai_embed": ([("input", DataTypes.STRING())],
                     [("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))]),
    }


def _build_schema(columns):
    """Build a Schema from a list of (name, DataType) tuples."""
    from pyflink.table import Schema

    builder = Schema.new_builder()
    for col_name, data_type in columns:
        builder.column(col_name, data_type)
    return builder.build()


def _create_model_descriptor(t_env, provider_name, model_name,
                             input_columns, output_columns):
    """Create a ModelDescriptor from a provider registry entry.

    If ``provider_name`` is not in the registry, it is used directly as the
    provider identifier with no extra options and ``"model"`` as the model
    option key.

    Args:
        t_env: The TableEnvironment. Used to auto-load the built-in
            model-provider jar before the descriptor is built so FactoryUtil
            can resolve the provider class.
        provider_name: Provider name.
        model_name: Model name. When falsy, no model option is set.
        input_columns: List of (name, DataType) tuples for input schema.
        output_columns: List of (name, DataType) tuples for output schema.

    Returns:
        The built ModelDescriptor.
    """
    from pyflink.dataframe.util.artifacts import add_built_in_model
    from pyflink.table import ModelDescriptor

    provider = _provider_registry.get(provider_name)
    # Test "registered or not" the same way everywhere (``is not None``) so a
    # registered provider is never treated as missing because of truthiness.
    if provider is not None:
        provider_options = provider.to_options()
        java_provider_name = provider.provider_name()
        model_option_key = provider.model_option_key() if model_name else None
    else:
        provider_options = {}
        java_provider_name = provider_name
        model_option_key = "model" if model_name else None

    add_built_in_model(t_env, java_provider_name)

    builder = ModelDescriptor.for_provider(java_provider_name)
    for key, value in provider_options.items():
        builder = builder.option(key, value)
    if model_option_key is not None:
        builder = builder.option(model_option_key, model_name)
    builder = builder.input_schema(_build_schema(input_columns))
    builder = builder.output_schema(_build_schema(output_columns))
    return builder.build()


def _create_model(t_env, provider_name, model_name,
                  input_columns, output_columns):
    """Create a temporary model and return the Model object.

    Args:
        t_env: The TableEnvironment.
        provider_name: Provider name.
        model_name: Model name.
        input_columns: List of (name, DataType) tuples for input schema.
        output_columns: List of (name, DataType) tuples for output schema.

    Returns:
        The Model obtained from ``t_env.from_model`` for the new temporary
        model.
    """
    # Use the full UUID rather than a short prefix: the model is created with
    # the third argument set to True (presumably ignore_if_exists — confirm
    # against TableEnvironment.create_temporary_model), so a name collision
    # would silently reuse an unrelated, previously created model.
    temp_name = f"_pf_llm_{uuid.uuid4().hex}"
    descriptor = _create_model_descriptor(
        t_env, provider_name, model_name, input_columns, output_columns)
    t_env.create_temporary_model(temp_name, descriptor, True)
    return t_env.from_model(temp_name)


def _resolve_model(t_env, model, provider=None,
                   input_columns=None, output_columns=None):
    """Resolve a Model, either from provider registry or catalog.

    If a provider is available (explicitly given or resolved from registry),
    creates a temporary model. Otherwise, treats ``model`` as a catalog model
    name.

    Args:
        t_env: The TableEnvironment.
        model: Model name — either a provider model id (e.g. "qwen-plus")
            or a catalog model name (e.g. "my_catalog_model").
        provider: Explicit provider name, or None to auto-resolve.
        input_columns: List of (name, DataType) tuples for input schema.
        output_columns: List of (name, DataType) tuples for output schema.

    Returns:
        A Model object.

    Raises:
        ValueError: If no provider is available and ``model`` is None, or if
            provider resolution itself fails.
    """
    provider_name = _resolve_provider(provider)
    if provider_name is not None:
        return _create_model(t_env, provider_name, model,
                             input_columns, output_columns)
    if model is None:
        raise ValueError(
            "Either configure a provider via pf.set_provider(...) "
            "or pass a catalog model name via model='...'.")
    return t_env.from_model(model)


class LLMAccessor:
    """
    Accessor for LLM / AI functions on a DataFrame.

    Accessed via ``df.llm``. Groups all AI-related operations together.

    Example::

        >>> import pyflink.dataframe as pf
        >>> pf.set_provider(pf.OpenAICompatProvider(
        ...     endpoint="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
        ...     api_key="sk-..."))
        >>>
        >>> df = pf.from_dict({"text": ["hello", "world"]})
        >>> df = df.llm.predict("text", model="qwen-plus")
        >>> df = df.llm.ai_classify("text", ["pos", "neg"], model="qwen-plus")
    """

    def __init__(self, df: "DataFrame"):
        self._df = df

    @staticmethod
    def _to_col_expr(input_col: Union[str, "Expression"]) -> "Expression":
        """Convert a column name string to an Expression if needed."""
        if isinstance(input_col, str):
            from pyflink.dataframe.dataframe import col
            return col(input_col)
        return input_col

    def _get_column_type(self, col_name: str):
        """Get the DataType of a column from the DataFrame's schema.

        Raises:
            ValueError: If the DataFrame has no column named ``col_name``.
        """
        schema = self._df._table.get_resolved_schema()
        for column in schema.get_columns():
            if column.get_name() == col_name:
                return column.get_data_type()
        available = [c.get_name() for c in schema.get_columns()]
        raise ValueError(
            f"Column {col_name!r} not found. Available columns: {available}")

    def predict(self,
                *input_cols: str,
                provider: Optional[str] = None,
                model: Optional[str] = None,
                output_type: Optional[Mapping[str, Union[str, "DataType"]]] = None,
                config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Perform prediction using a model.

        This is the general-purpose prediction method. When a provider is
        configured, a temporary model is created with the given input/output
        schema. When using a catalog model (no provider), the model's
        registered schema is used and ``output_type`` is ignored.

        Args:
            *input_cols: Column names to use as input.
            provider: Provider name. Uses default if not specified.
                If no provider is configured, ``model`` is treated as a
                catalog model name.
            model: Model name (e.g. "qwen-plus") or catalog model name.
            output_type: Output column schema as a dict ``{name: type}``.
                Type can be a SQL type string (e.g. ``"STRING"``) or a
                ``DataType`` object. Defaults to ``{"output": "STRING"}``.
                Only used when a provider is configured. Ignored for catalog
                models.
            config: Optional dict of runtime config options.

        Returns:
            A new DataFrame with the model output columns appended.

        Raises:
            ValueError: If an input column does not exist, or if no model can
                be resolved.

        Example::

            >>> # Single output column (default)
            >>> df.llm.predict("question", model="qwen-plus")

            >>> # JSON structured output
            >>> df.llm.predict("question", model="qwen-plus",
            ...                output_type={"output": "VARIANT"})

            >>> # Multiple output columns
            >>> df.llm.predict("question", model="qwen-plus",
            ...                output_type={"answer": DataType.string(),
            ...                             "score": DataType.float64()})
        """
        from pyflink.dataframe.dataframe import DataFrame as DF
        from pyflink.dataframe.datatype import DataType

        if output_type is None:
            output_type = {"output": "STRING"}

        # Build output columns from output_type; SQL type strings are parsed
        # into DataType objects first.
        output_columns = []
        for name, dtype in output_type.items():
            if isinstance(dtype, str):
                dtype = DataType._from_sql(dtype)
            output_columns.append((name, dtype._table_type))

        # Build input schema from actual column types.
        input_columns = [
            (col_name, self._get_column_type(col_name))
            for col_name in input_cols
        ]

        flink_model = _resolve_model(
            self._df._table._t_env, model, provider,
            input_columns, output_columns)
        result = flink_model.predict(
            self._df._table, list(input_cols), config)
        return DF(result)

    def _ai_func(self, func_name: str, input_col, model, provider, config,
                 ai_func, *extra_args):
        """Common logic for ai_* functions.

        Args:
            func_name: Key into the AI function schema table.
            input_col: Column name (str) or Expression for the input text.
            model: Model name or catalog model name.
            provider: Provider name, or None to auto-resolve.
            config: Optional runtime config, passed as the last argument to
                ``ai_func``.
            ai_func: The table-level AI function to invoke.
            *extra_args: Function-specific arguments placed between the input
                expression and ``config``.

        Returns:
            A new DataFrame built from a lateral join with the function call.
        """
        from pyflink.dataframe.dataframe import DataFrame as DF

        input_cols, output_cols = _ai_function_schemas()[func_name]
        flink_model = _resolve_model(
            self._df._table._t_env, model, provider, input_cols, output_cols)
        expr = ai_func(
            flink_model, self._to_col_expr(input_col), *extra_args, config)
        return DF(self._df._table.join_lateral(expr))

    def ai_classify(self,
                    input_col: Union[str, "Expression"],
                    labels: List[str],
                    *,
                    provider: Optional[str] = None,
                    model: Optional[str] = None,
                    config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Classify text into one of the provided labels.

        Args:
            input_col: Column name (str) or Expression for the input text.
            labels: List of label strings.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``category`` (STRING): the predicted label.
            - ``confidence`` (DOUBLE): confidence score.

        Example::

            >>> df.llm.ai_classify("text", ["positive", "negative"],
            ...                    model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_classify as _ai_classify
        return self._ai_func("ai_classify", input_col, model, provider,
                             config, _ai_classify, labels)

    def ai_sentiment(self,
                     input_col: Union[str, "Expression"],
                     *,
                     provider: Optional[str] = None,
                     model: Optional[str] = None,
                     config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Analyze the sentiment of input text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``score`` (DOUBLE): sentiment score from -1.0 to 1.0.
            - ``label`` (STRING): one of "positive", "negative", "neutral".
            - ``confidence`` (DOUBLE): confidence score.

        Example::

            >>> df.llm.ai_sentiment("review", model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_sentiment as _ai_sentiment
        return self._ai_func("ai_sentiment", input_col, model, provider,
                             config, _ai_sentiment)

    def ai_extract(self,
                   input_col: Union[str, "Expression"],
                   schema: str,
                   *,
                   provider: Optional[str] = None,
                   model: Optional[str] = None,
                   config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Extract structured information from text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            schema: JSON schema string describing the fields to extract,
                e.g. ``'{"name":"STRING", "phone":"STRING"}'``.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with a column appended:

            - ``extracted_json`` (STRING): extracted fields as a JSON string.

        Example::

            >>> df.llm.ai_extract("text",
            ...     '{"name":"STRING", "phone":"STRING"}', model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_extract as _ai_extract
        return self._ai_func("ai_extract", input_col, model, provider,
                             config, _ai_extract, schema)

    def ai_translate(self,
                     input_col: Union[str, "Expression"],
                     source_lang: str,
                     target_lang: str,
                     *,
                     provider: Optional[str] = None,
                     model: Optional[str] = None,
                     config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Translate text from one language to another.

        Args:
            input_col: Column name (str) or Expression for the input text.
            source_lang: Source language code (e.g. ``"zh"``, ``"en"``,
                ``"auto"`` for auto-detection). Supported: ``auto``, ``zh``,
                ``en``, ``ja``, ``ko``, ``fr``, ``de``, ``es``, ``ru``,
                ``ar``, ``pt``.
            target_lang: Target language code (cannot be ``"auto"``).
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``translated_text`` (STRING): the translated text.
            - ``detected_language`` (STRING): detected source language code.

        Example::

            >>> df.llm.ai_translate("text", "zh", "en", model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_translate as _ai_translate
        return self._ai_func("ai_translate", input_col, model, provider,
                             config, _ai_translate, source_lang, target_lang)

    def ai_summarize(self,
                     input_col: Union[str, "Expression"],
                     max_length: int,
                     *,
                     provider: Optional[str] = None,
                     model: Optional[str] = None,
                     config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Summarize text to a maximum length.

        Args:
            input_col: Column name (str) or Expression for the input text.
            max_length: Maximum length of the summary in characters.
                Must be greater than 0.
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with a column appended:

            - ``summary`` (STRING): the summarized text.

        Example::

            >>> df.llm.ai_summarize("article", 200, model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_summarize as _ai_summarize
        return self._ai_func("ai_summarize", input_col, model, provider,
                             config, _ai_summarize, max_length)

    def ai_mask(self,
                input_col: Union[str, "Expression"],
                entities: List[str],
                *,
                provider: Optional[str] = None,
                model: Optional[str] = None,
                config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Mask sensitive information in text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            entities: List of entity types to mask
                (e.g. ``["name", "phone"]``).
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with columns appended:

            - ``masked_text`` (STRING): text with sensitive info replaced.
            - ``detected_entities`` (ARRAY<STRING>): list of detected entity
              types.

        Example::

            >>> df.llm.ai_mask("text", ["name", "phone"], model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_mask as _ai_mask
        return self._ai_func("ai_mask", input_col, model, provider,
                             config, _ai_mask, entities)

    def ai_embed(self,
                 input_col: Union[str, "Expression"],
                 dimension: int = 1024,
                 *,
                 provider: Optional[str] = None,
                 model: Optional[str] = None,
                 config: Optional[Dict[str, str]] = None) -> "DataFrame":
        """
        Generate embedding vectors for text.

        Args:
            input_col: Column name (str) or Expression for the input text.
            dimension: Dimension of the embedding vector (default 1024).
            provider: Provider name.
            model: Model name.
            config: Optional runtime config.

        Returns:
            A new DataFrame with a column appended:

            - ``embedding`` (ARRAY<FLOAT>): the embedding vector.

        Example::

            >>> df.llm.ai_embed("text", 512, model="qwen-plus")
        """
        from pyflink.table.ai_functions import ai_embed as _ai_embed
        return self._ai_func("ai_embed", input_col, model, provider,
                             config, _ai_embed, dimension)