################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
User-Defined Functions for DataFrame API.
This module provides a modern UDF interface for DataFrame operations,
supporting return_dtype parameter and type hint inference.
"""
import copy
import functools
import inspect
from typing import (
Any,
Callable,
Optional,
Type,
Union,
get_type_hints,
overload,
)
from pyflink.common import Row
from pyflink.table.expression import Expression
from pyflink.table.expressions import call as table_call
from pyflink.table.udf import (
udf as table_udf,
AsyncScalarFunction,
ScalarFunction,
UserDefinedFunctionWrapper,
)
from pyflink.dataframe.datatype import DataType, DataTypeLike
__all__ = ["udf"]
class DataFrameUDFWrapper:
    """
    Callable handle produced by ``udf``.

    The handle keeps the Table API UDF wrapper that does the real work next to
    the DataType the UDF returns. Calling it with Expression arguments builds
    an Expression that represents the UDF invocation.
    """

    def __init__(
        self,
        table_udf_wrapper: UserDefinedFunctionWrapper,
        return_dtype: DataType,
    ):
        """
        Store the Table API wrapper and the return type.

        Args:
            table_udf_wrapper: The Table API UDF wrapper that is invoked.
            return_dtype: The DataType produced by the UDF.
        """
        self._table_udf_wrapper = table_udf_wrapper
        self._return_dtype = return_dtype
        # Mirror the wrapper's ``_name`` and the wrapped ``_func``'s docstring
        # so the handle introspects like an ordinary function; fall back to
        # 'udf' / None when those attributes are absent.
        udf_name = getattr(table_udf_wrapper, '_name', 'udf')
        wrapped_callable = getattr(table_udf_wrapper, '_func', None)
        self.__name__ = udf_name
        self.__doc__ = getattr(wrapped_callable, '__doc__', None)

    def __call__(self, *args: Any) -> Expression:
        """
        Build the Expression for calling this UDF on ``args``.

        Args:
            *args: Expression arguments.

        Returns:
            An Expression representing the UDF call.
        """
        return table_call(self._table_udf_wrapper, *args)

    @property
    def return_dtype(self) -> DataType:
        """The DataType produced by this UDF."""
        return self._return_dtype
# Overload 1: the function/class to wrap is supplied at the call site, so
# ``udf`` returns the wrapped UDF (a DataFrameUDFWrapper) directly. Covers:
#     @udf
#     def fn(...): ...               # bare decorator, evaluated as udf(fn)
#     udf(MyScalarFunction)          # direct call with a class
#     udf(MyScalarFunction(),        # direct call with an instance and kwargs
#         return_dtype=...)
# Every option after ``func`` is keyword-only (note the bare ``*``). This stub
# is only consulted by static type checkers; it has no runtime behavior.
@overload
def udf(
    func: Union[Callable, ScalarFunction, AsyncScalarFunction, Type],
    *,
    return_dtype: Optional[DataTypeLike] = ...,
    deterministic: bool = ...,
    name: Optional[str] = ...,
    func_type: Optional[str] = ...,
    concurrency: Optional[int] = ...,
    batch_size: Optional[int] = ...,
    num_gpus: Optional[float] = ...,
    gpu_type: Optional[str] = ...,
) -> "DataFrameUDFWrapper":
    ...
# Overload 2: no function supplied, so ``udf`` is being used as a decorator
# factory and returns a decorator that receives the function/class next.
# Covers the parameterized decorator form:
#     @udf(concurrency=32)
#     async def fn(...): ...         # evaluated as udf(concurrency=32)(fn)
# This stub is only consulted by static type checkers; it has no runtime
# behavior.
@overload
def udf(
    func: None = ...,
    *,
    return_dtype: Optional[DataTypeLike] = ...,
    deterministic: bool = ...,
    name: Optional[str] = ...,
    func_type: Optional[str] = ...,
    concurrency: Optional[int] = ...,
    batch_size: Optional[int] = ...,
    num_gpus: Optional[float] = ...,
    gpu_type: Optional[str] = ...,
) -> Callable[
    [Union[Callable, ScalarFunction, AsyncScalarFunction, Type]],
    "DataFrameUDFWrapper",
]:
    ...
def udf(
    func: Optional[Union[Callable, ScalarFunction, AsyncScalarFunction, Type]] = None,
    *,
    return_dtype: Optional[DataTypeLike] = None,
    deterministic: bool = True,
    name: Optional[str] = None,
    func_type: Optional[str] = None,
    concurrency: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_gpus: Optional[float] = None,
    gpu_type: Optional[str] = None,
):
    """
    Create a user-defined function for DataFrame operations.

    This decorator supports:

    - Plain Python functions and lambdas
    - Async functions defined with ``async def`` (executed asynchronously)
    - ScalarFunction subclasses (class type or instance), using eval method
    - AsyncScalarFunction subclasses (class type or instance), using async eval method
    - Plain callable classes with __call__ method (class type or instance)
    - Automatic type inference from Python type hints

    When a class (rather than an instance) is given, it is instantiated with
    no arguments.

    Args:
        func: The Python function, ScalarFunction instance/subclass, or
            callable class instance/type to wrap as a UDF. Omit it to use
            ``udf(...)`` as a parameterized decorator.
        return_dtype: Optional return type. Can be:

            - DataType instance (e.g., DataType.int64())
            - Python type (e.g., int, str, float)
            - String (e.g., 'INT', 'BIGINT')

            If not specified, inferred from function's return type hint
            (eval for ScalarFunction, __call__ for plain classes).
        deterministic: Whether the function is deterministic (default: True).
        name: Optional name for the UDF.
        func_type: Optional execution format. Supported values are
            ``"general"``, ``"pandas"``, and ``"arrow"``. If omitted,
            pandas and arrow formats are detected from pandas/pyarrow
            type hints when possible, otherwise ``"general"`` is used.
        concurrency: Optional concurrency (parallelism) for the UDF operator.
            If specified, the operator running this UDF will use this
            parallelism. UDFs with different concurrency values will be
            split into separate operators.
        batch_size: Optional maximum number of elements per batch.
            Only applies to batch-wise UDFs. If not specified,
            the global configuration 'python.fn-execution.arrow.batch.size'
            is used (default 1000). When multiple UDFs are fused,
            the maximum batch_size among them is used.
        num_gpus: Optional number of GPUs requested for this UDF (e.g., 0.5, 1.0).
            Each GPU UDF will run in its own operator and will not be fused
            with other UDFs (including other GPU UDFs).
        gpu_type: Optional GPU type (e.g., 'A10', 'V100').

    Returns:
        If ``func`` is given, a DataFrameUDFWrapper that can be called with
        Expressions. If ``func`` is omitted (``@udf(...)`` form), a decorator
        that takes the function/class and returns such a wrapper.

    Raises:
        ValueError: If ``concurrency`` or ``batch_size`` is not a positive
            integer, ``num_gpus`` is not a positive number, ``num_gpus`` is
            set without ``gpu_type``, ``batch_size`` is set for a general
            (non batch-wise) UDF, or no return type can be determined.
        TypeError: If the object to wrap is not callable and not a
            class-based UDF.

    Example::

        >>> from pyflink.dataframe import udf, DataType, col
        >>> from pyflink.table.udf import ScalarFunction
        >>>
        >>> # 1. Plain function with explicit return_dtype
        >>> @udf(return_dtype=DataType.string())
        ... def to_string(x):
        ...     return str(x)
        >>>
        >>> df.with_columns(s=to_string(col("a")))
        >>>
        >>> # 2. Plain function with type hint inference
        >>> @udf
        ... def add(x: int, y: int) -> int:
        ...     return x + y
        >>>
        >>> df.with_columns(sum=add(col("a"), col("b")))
        >>>
        >>> # 3. ScalarFunction with explicit return_dtype (instance or class)
        >>> class AddOne(ScalarFunction):
        ...     def eval(self, x):
        ...         return x + 1
        >>>
        >>> add_one = udf(AddOne(), return_dtype=DataType.int64())
        >>> add_one = udf(AddOne, return_dtype=DataType.int64())  # auto-instantiated
        >>> df.with_columns(b=add_one(col("a")))
        >>>
        >>> # 4. ScalarFunction instance with type hint inference
        >>> class Double(ScalarFunction):
        ...     def eval(self, x: int) -> int:
        ...         return x * 2
        >>>
        >>> double = udf(Double())
        >>> df.with_columns(doubled=double(col("a")))
        >>>
        >>> # 5. @udf decorator on ScalarFunction class (with return_dtype or type hints)
        >>> @udf(return_dtype=DataType.int64())
        ... class AddTwo(ScalarFunction):
        ...     def eval(self, x):
        ...         return x + 2
        >>>
        >>> df.with_columns(b=AddTwo(col("a")))
        >>>
        >>> @udf
        ... class Triple(ScalarFunction):
        ...     def eval(self, x: int) -> int:
        ...         return x * 3
        >>>
        >>> df.with_columns(tripled=Triple(col("a")))
        >>>
        >>> # 6. Plain callable class instance
        >>> class MultiplyBy:
        ...     def __init__(self, factor):
        ...         self._factor = factor
        ...
        ...     def __call__(self, x: int) -> int:
        ...         return x * self._factor
        >>>
        >>> times_three = udf(MultiplyBy(3))
        >>> df.with_columns(result=times_three(col("a")))
        >>>
        >>> # 7. @udf decorator on plain callable class
        >>> @udf(return_dtype=DataType.int64())
        ... class AddOne:
        ...     def __call__(self, x):
        ...         return x + 1
        >>>
        >>> df.with_columns(b=AddOne(col("a")))
        >>>
        >>> @udf
        ... class DoubleIt:
        ...     def __call__(self, x: int) -> int:
        ...         return x * 2
        >>>
        >>> df.with_columns(doubled=DoubleIt(col("a")))
        >>>
        >>> # 8. Async function (executed asynchronously, e.g. for I/O-bound work)
        >>> import asyncio
        >>> @udf
        ... async def async_lookup(key: str) -> str:
        ...     await asyncio.sleep(0.01)
        ...     return f"value_for_{key}"
        >>>
        >>> df.with_columns(v=async_lookup(col("a")))
        >>>
        >>> # 9. AsyncScalarFunction subclass
        >>> from pyflink.table.udf import AsyncScalarFunction
        >>> class AsyncLookup(AsyncScalarFunction):
        ...     async def eval(self, key: str) -> str:
        ...         await asyncio.sleep(0.01)
        ...         return f"value_for_{key}"
        >>>
        >>> async_lookup = udf(AsyncLookup())
        >>> df.with_columns(v=async_lookup(col("a")))
        >>>
        >>> # 10. Pandas scalar function (executed on pandas.Series batches)
        >>> import pandas as pd
        >>> @udf(return_dtype=DataType.int64(), func_type="pandas")
        ... def add_one_pandas(x: pd.Series) -> pd.Series:
        ...     return x + 1
        >>>
        >>> df.with_columns(v=add_one_pandas(col("a")))
        >>>
        >>> # 11. Pandas async scalar function (executed on pandas.Series batches)
        >>> import pandas as pd
        >>> @udf(return_dtype=DataType.string(), batch_size=128)
        ... async def async_batch_lookup(keys: pd.Series) -> pd.Series:
        ...     await asyncio.sleep(0.01)
        ...     return keys.map(lambda key: f"value_for_{key}")
        >>>
        >>> df.with_columns(v=async_batch_lookup(col("a")))
        >>>
        >>> # 12. Arrow scalar function (executed on pyarrow.Array batches)
        >>> import pyarrow as pa
        >>> @udf(return_dtype=DataType.int64(), func_type="arrow")
        ... def add_one_arrow(x):
        ...     return pa.array([v + 1 for v in x.to_pylist()])
        >>>
        >>> df.with_columns(v=add_one_arrow(col("a")))
        >>>
        >>> # 13. Arrow async scalar function (executed on pyarrow.Array batches)
        >>> import pyarrow as pa
        >>> @udf(return_dtype=DataType.string(), func_type="arrow", batch_size=128)
        ... async def async_arrow_lookup(keys):
        ...     await asyncio.sleep(0.01)
        ...     return pa.array([f"value_for_{k}" for k in keys.to_pylist()])
        >>>
        >>> df.with_columns(v=async_arrow_lookup(col("a")))
    """
    # Validate options eagerly, when udf(...) itself is called, so that a
    # parameterized decorator with bad options fails before any function is
    # supplied. bool is a subclass of int and is rejected explicitly:
    # ``concurrency=True`` is a mistake, not a request for 1.
    if concurrency is not None and not _is_positive_int(concurrency):
        raise ValueError("concurrency must be a positive integer, got: {}".format(concurrency))
    if batch_size is not None and not _is_positive_int(batch_size):
        raise ValueError("batch_size must be a positive integer, got: {}".format(batch_size))
    if num_gpus is not None and not _is_positive_number(num_gpus):
        raise ValueError("num_gpus must be a positive number, got: {}".format(num_gpus))
    if num_gpus is not None and gpu_type is None:
        raise ValueError("gpu_type must be specified when num_gpus is set")

    def decorator(f) -> DataFrameUDFWrapper:
        # Class-based UDFs (ScalarFunction / AsyncScalarFunction class or
        # instance, plain callable class or instance). This check must come
        # before the callable() check: ScalarFunction instances are handled
        # here regardless of whether they define __call__.
        if _is_class_udf(f):
            return _create_class_udf(f, return_dtype, deterministic, name, concurrency,
                                     batch_size, num_gpus, gpu_type, func_type)
        if not callable(f):
            raise TypeError(
                "udf() expects a function, ScalarFunction, or callable class, "
                "got: {!r}".format(f))

        # Resolve the return type first (explicit return_dtype or return hint).
        actual_return_dtype = _infer_return_dtype(f, return_dtype)
        table_result_type = actual_return_dtype._table_type

        # An explicit func_type wins; otherwise detect it from pandas/pyarrow hints.
        actual_func_type = func_type or _detect_func_type(f)
        _validate_scalar_udf_options(f, actual_func_type, batch_size)

        table_wrapper = table_udf(
            f,
            result_type=table_result_type,
            deterministic=deterministic,
            name=name,
            func_type=actual_func_type,
            concurrency=concurrency,
            batch_size=batch_size,
            num_gpus=num_gpus,
            gpu_type=gpu_type,
        )
        return DataFrameUDFWrapper(table_wrapper, actual_return_dtype)

    # ``@udf`` / ``udf(fn, ...)``: wrap now. ``@udf(...)``: return the decorator.
    if func is not None:
        return decorator(func)
    return decorator


def _is_positive_int(value) -> bool:
    """Return True if ``value`` is an ``int`` greater than zero (``bool`` excluded)."""
    return isinstance(value, int) and not isinstance(value, bool) and value > 0


def _is_positive_number(value) -> bool:
    """Return True if ``value`` is an int/float greater than zero (``bool`` and NaN excluded)."""
    # ``nan > 0`` is False, so NaN is rejected by the comparison itself.
    return isinstance(value, (int, float)) and not isinstance(value, bool) and value > 0
def _validate_scalar_udf_options(func, func_type: str,
                                 batch_size: Optional[int]) -> None:
    """
    Reject scalar-UDF options that do not fit the resolved execution format.

    Args:
        func: The wrapped function, class, or instance (only used for the
            error message).
        func_type: The resolved execution format.
        batch_size: The requested batch size, or None.

    Raises:
        ValueError: If ``batch_size`` is set while ``func_type`` is neither
            ``"pandas"`` nor ``"arrow"``.
    """
    # Nothing to check unless a batch size was requested for a non-batch UDF.
    if batch_size is None or func_type in ("pandas", "arrow"):
        return
    display_name = getattr(func, '__name__', type(func).__name__)
    raise ValueError(
        "batch_size is only supported for batch-wise UDFs. "
        "The function '{}' is detected as a general UDF. "
        "Use it with map_batches() or annotate parameters with pandas/pyarrow types."
        .format(display_name))
def _has_custom_call(cls) -> bool:
    """Return True if ``cls`` or one of its bases other than ``object`` defines ``__call__``."""
    for klass in cls.__mro__:
        if klass is object:
            continue
        if '__call__' in vars(klass):
            return True
    return False
def _is_class_udf(func) -> bool:
    """
    Check if func is a class-based UDF (ScalarFunction, AsyncScalarFunction, or callable class).

    Class-based UDFs are handled by ``_create_class_udf``, which reads type
    hints from ``eval`` (ScalarFunction/AsyncScalarFunction) or ``__call__``
    (plain callable classes). Everything else is wrapped as a function.

    Args:
        func: The object passed to ``udf``.

    Returns:
        True for ScalarFunction/AsyncScalarFunction instances and subclasses,
        classes defining (or inheriting) ``__call__``, and instances of such
        classes. False for functions, lambdas, bound methods, builtins and
        non-callables.
    """
    # Case 1: ScalarFunction / AsyncScalarFunction instance
    if isinstance(func, (ScalarFunction, AsyncScalarFunction)):
        return True
    if inspect.isclass(func):
        # Case 2: ScalarFunction / AsyncScalarFunction class (will be auto-instantiated)
        # Case 3: Plain class with (possibly inherited) __call__ (will be
        #         auto-instantiated), e.g. udf(MyCallable, return_dtype=...)
        #         or @udf class MyCallable
        return (issubclass(func, (ScalarFunction, AsyncScalarFunction))
                or _has_custom_call(func))
    # Functions, bound methods and builtins are routines, not callable-class
    # instances. The types of bound methods and builtins define __call__
    # themselves, so without this guard they would match Case 4 below. Hint
    # inference would then read their ``__call__`` method-wrapper, which has
    # no annotations, instead of the method's own hints.
    if inspect.isfunction(func) or inspect.ismethod(func) or inspect.isbuiltin(func):
        return False
    # Case 4: Instance of a plain callable class, e.g. udf(MultiplyBy(3), return_dtype=...)
    # Note: callable() alone is not used since it is also true for routines.
    return hasattr(func, '__call__') and _has_custom_call(type(func))
def _create_class_udf(
    func: Union[ScalarFunction, AsyncScalarFunction, Type],
    return_dtype: Optional[DataTypeLike],
    deterministic: bool,
    name: Optional[str],
    concurrency: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_gpus: Optional[float] = None,
    gpu_type: Optional[str] = None,
    func_type: Optional[str] = None,
) -> DataFrameUDFWrapper:
    """
    Build a DataFrameUDFWrapper for a class-based UDF.

    Args:
        func: A ScalarFunction/AsyncScalarFunction or plain callable class
            (instantiated here with no arguments) or an instance of one.
        return_dtype: Explicit return type, or None to infer it from hints.
        deterministic: Forwarded to the Table API ``udf``.
        name: Optional UDF name, forwarded to the Table API ``udf``.
        concurrency, batch_size, num_gpus, gpu_type: Forwarded unchanged.
        func_type: Explicit execution format, or None to detect it.

    Returns:
        The DataFrameUDFWrapper around the Table API UDF wrapper.
    """
    # A class is instantiated with no arguments; an instance is used as-is.
    instance = func() if inspect.isclass(func) else func

    # ScalarFunction/AsyncScalarFunction declare their signature on ``eval``;
    # plain callable classes declare it on ``__call__``.
    if isinstance(instance, (ScalarFunction, AsyncScalarFunction)):
        signature_source = instance.eval
    else:
        signature_source = instance.__call__

    resolved_dtype = _infer_return_dtype(signature_source, return_dtype)
    result_type = resolved_dtype._table_type

    resolved_func_type = func_type or _detect_func_type(signature_source)
    _validate_scalar_udf_options(func, resolved_func_type, batch_size)

    # The instance itself is handed to the Table API for both kinds of class.
    table_wrapper = table_udf(
        instance,
        result_type=result_type,
        deterministic=deterministic,
        name=name,
        func_type=resolved_func_type,
        concurrency=concurrency,
        batch_size=batch_size,
        num_gpus=num_gpus,
        gpu_type=gpu_type,
    )
    return DataFrameUDFWrapper(table_wrapper, resolved_dtype)
def _infer_return_dtype(func: Callable, return_dtype: Optional[DataTypeLike]) -> DataType:
    """
    Infer return DataType from function or explicit specification.

    An explicit ``return_dtype`` always wins. Otherwise the ``return`` type
    hint of ``func`` is converted with ``DataType._infer_from_type``.

    Args:
        func: The callable whose return hint is inspected (a function, or the
            ``eval``/``__call__`` method of a class-based UDF).
        return_dtype: Explicitly specified return type (can be None).

    Returns:
        The inferred or specified DataType.

    Raises:
        ValueError: If no return type hint can be resolved and no explicit
            return_dtype is given. If the hints exist but cannot be resolved
            (e.g. an undefined forward reference), the message says so and
            the resolution error is chained as ``__cause__``.
        TypeError: If the type cannot be inferred.
    """
    if return_dtype is not None:
        return _convert_to_dtype(return_dtype)

    # Not every callable has __name__ (e.g. functools.partial objects); use
    # the same fallback as the other helpers in this module.
    func_name = getattr(func, '__name__', type(func).__name__)

    resolve_error = None
    try:
        hints = get_type_hints(func)
    except Exception as e:  # pylint: disable=broad-except
        # Keep the error so the message below does not claim the hint is missing.
        hints = {}
        resolve_error = e

    if "return" not in hints:
        message = (
            f"Function '{func_name}' must have a return type hint "
            f"or specify return_dtype parameter explicitly.\n"
            f"Example:\n"
            f"  @udf\n"
            f"  def {func_name}(x: int) -> int:\n"
            f"      return x + 1\n\n"
            f"Or:\n"
            f"  @udf(return_dtype=DataType.int64())\n"
            f"  def {func_name}(x):\n"
            f"      return x + 1"
        )
        if resolve_error is not None:
            message = (
                f"Could not resolve the type hints of function '{func_name}': "
                f"{resolve_error}\n" + message
            )
        raise ValueError(message) from resolve_error
    return DataType._infer_from_type(hints["return"])
def _convert_to_dtype(dtype_like: DataTypeLike) -> DataType:
    """
    Normalize a DataTypeLike value into a DataType.

    Args:
        dtype_like: A DataType (returned unchanged), a SQL type string such
            as 'INT' or 'BIGINT', or a Python type.

    Returns:
        A DataType instance.
    """
    if isinstance(dtype_like, DataType):
        return dtype_like
    if isinstance(dtype_like, str):
        return DataType._from_sql(dtype_like)
    # Anything else is treated as a Python type to infer from.
    return DataType._infer_from_type(dtype_like)
def _detect_func_type(func: Callable) -> str:
    """
    Detect the function type (general, pandas, or arrow) from type hints.

    Returns ``'pandas'`` if any parameter or return hint is
    ``pandas.Series``/``pandas.DataFrame``. Otherwise returns ``'arrow'`` if
    any hint is ``pyarrow.Array``/``pyarrow.ChunkedArray``. Otherwise returns
    ``'general'``. pandas is checked first.

    Detection is best-effort. A library that is not installed, or hints that
    cannot be resolved (undefined names, bad attribute paths, objects that
    carry no annotations), make that format undetectable rather than raising.

    Args:
        func: The Python callable whose hints are inspected.

    Returns:
        'general', 'pandas', or 'arrow'.
    """
    func_globals = getattr(func, '__globals__', {})

    def _resolve_hints(extra_ns):
        # Inject the library under its usual names so string annotations such
        # as "pd.Series" resolve even if that alias is not a module global of
        # ``func``.
        try:
            return get_type_hints(func, globalns={**func_globals, **extra_ns})
        except Exception:  # pylint: disable=broad-except
            return {}

    try:
        import pandas as pd
    except ImportError:
        pass
    else:
        pandas_types = (pd.Series, pd.DataFrame)
        pandas_hints = _resolve_hints({"pandas": pd, "pd": pd})
        if any(hint in pandas_types for hint in pandas_hints.values()):
            return "pandas"

    try:
        import pyarrow as pa
    except ImportError:
        pass
    else:
        arrow_types = (pa.Array, pa.ChunkedArray)
        arrow_hints = _resolve_hints({"pyarrow": pa, "pa": pa})
        if any(hint in arrow_types for hint in arrow_hints.values()):
            return "arrow"

    return "general"
def _is_typeddict(tp) -> bool:
    """
    Check if tp is a TypedDict class.

    ``typing.is_typeddict`` (Python 3.10+) is consulted first. The structural
    check runs whenever it says no, not only when it is unavailable, because
    ``typing.is_typeddict`` may not recognise TypedDict classes created via
    ``typing_extensions`` on some Python versions.

    Args:
        tp: Any object (typically a return type hint).

    Returns:
        True if ``tp`` is a TypedDict class, False otherwise.
    """
    try:
        from typing import is_typeddict
    except ImportError:
        is_typeddict = None
    if is_typeddict is not None and is_typeddict(tp):
        return True
    # Structural fallback: TypedDict classes are dict subclasses that carry
    # __required_keys__ (Python 3.9+).
    return (
        isinstance(tp, type)
        and issubclass(tp, dict)
        and hasattr(tp, '__required_keys__')
    )
def _infer_struct_from_typeddict(td) -> DataType:
    """Build a struct DataType with one field per annotated key of TypedDict ``td``."""
    field_types = {}
    for field_name, annotation in get_type_hints(td).items():
        field_types[field_name] = DataType._infer_from_type(annotation)
    return DataType.struct(field_types)
def _resolve_map_udf(func, return_dtype, func_type, input_columns,
                     concurrency: Optional[int] = None, batch_size: Optional[int] = None,
                     num_gpus: Optional[float] = None, gpu_type: Optional[str] = None):
    """
    Turn ``func`` into a Table API UDF wrapper for map/map_batches.

    User functions work on dicts keyed by column name:

    - map (general):        func(dict[str, scalar]) -> dict[str, scalar]
    - map_batches (pandas): func(dict[str, pd.Series]) -> dict[str, pd.Series]
    - map_batches (arrow):  func(dict[str, pa.Array]) -> dict[str, pa.Array]

    A DataFrameUDFWrapper is not re-wrapped. Its underlying Table API wrapper
    is returned, shallow-copied first if any resource option is overridden.
    For raw callables, a TypedDict return hint can stand in for
    ``return_dtype``, except for arrow, which always requires it.

    Args:
        func: A DataFrameUDFWrapper or raw callable.
        return_dtype: The return DataType (can be None if TypedDict hint is present).
        func_type: 'general' for map, 'pandas' or 'arrow' for map_batches.
            Any value other than 'general' or 'arrow' is wrapped as pandas.
        input_columns: List of input column names from the source DataFrame.
        concurrency: Optional concurrency (parallelism) for the UDF operator.
        batch_size: Optional maximum number of elements per batch.
        num_gpus: Optional number of GPUs for the UDF operator.
        gpu_type: Optional GPU type for the UDF operator.

    Returns:
        A UserDefinedScalarFunctionWrapper for Table.map().

    Raises:
        ValueError: If ``return_dtype`` is missing and cannot be inferred
            (always the case for arrow).
    """
    if isinstance(func, DataFrameUDFWrapper):
        table_wrapper = func._table_udf_wrapper
        overrides = [
            ('_concurrency', concurrency),
            ('_batch_size', batch_size),
            ('_num_gpus', num_gpus),
            ('_gpu_type', gpu_type),
        ]
        explicit = [(attr, value) for attr, value in overrides if value is not None]
        if not explicit:
            return table_wrapper
        # Options given to map/map_batches take precedence. Patch a shallow
        # copy so the user's UDF object is not mutated, and reset
        # _judf_placeholder on the copy (presumably a cached Java-side function
        # built from the old settings -- TODO confirm).
        table_wrapper = copy.copy(table_wrapper)
        table_wrapper._judf_placeholder = None
        for attr, value in explicit:
            setattr(table_wrapper, attr, value)
        return table_wrapper

    if return_dtype is None:
        if func_type == "arrow":
            raise ValueError(
                "return_dtype is required for arrow map_batches because the "
                "output Arrow schema cannot be inferred from type hints.")
        # Fall back to a TypedDict return hint.
        try:
            hints = get_type_hints(func)
        except NameError as e:
            import warnings
            warnings.warn(
                f"Failed to resolve type hints for {func}: {e}. "
                "Please specify return_dtype explicitly.",
                stacklevel=2,
            )
            hints = {}
        except Exception:
            hints = {}
        declared_return = hints.get("return")
        if declared_return is None or not _is_typeddict(declared_return):
            raise ValueError(
                "return_dtype is required when passing a raw function to "
                "map/map_batches.\n"
                "You can either specify return_dtype explicitly:\n"
                "  df.map(func, return_dtype=DataType.struct("
                "{'a': DataType.int64()}))\n"
                "Or use a TypedDict return type hint:\n"
                "  class Output(TypedDict):\n"
                "      a: int\n"
                "  def func(row: ...) -> Output: ..."
            )
        return_dtype = _infer_struct_from_typeddict(declared_return)

    dtype = _convert_to_dtype(return_dtype)
    if func_type == "general":
        adapter = _wrap_map_general
    elif func_type == "arrow":
        adapter = _wrap_map_arrow
    else:
        adapter = _wrap_map_pandas
    adapted = adapter(func, input_columns, dtype)

    return table_udf(
        adapted,
        result_type=dtype._table_type,
        func_type=func_type,
        name=getattr(func, '__name__', type(func).__name__),
        concurrency=concurrency,
        batch_size=batch_size,
        num_gpus=num_gpus,
        gpu_type=gpu_type,
    )
def _wrap_map_general(func, input_columns, return_dtype):
    """
    Adapt a dict-based map function to a general (Row -> Row) UDF.

    The returned callable maps each positional row value to its name in
    ``input_columns``, calls ``func`` with that dict, and builds a Row from
    ``func``'s result in the order of the return type's field names.
    """
    field_names = return_dtype._table_type.field_names()

    @functools.wraps(func)
    def wrapper(row):
        record = {}
        for position, column in enumerate(input_columns):
            record[column] = row[position]
        outputs = func(record)
        return Row(*(outputs[field] for field in field_names))

    return wrapper
def _wrap_map_pandas(func, input_columns, return_dtype):
    """
    Adapt a dict-based map_batches function to a pandas UDF.

    The returned callable receives a pandas DataFrame whose columns are, by
    position, ``input_columns``. It calls ``func`` with a dict of column name
    to Series. The Series ``func`` returns are joined side by side with
    ``pd.concat(axis=1)``, in the order of the return type's field names.
    """
    import pandas as pd

    field_names = return_dtype._table_type.field_names()

    @functools.wraps(func)
    def wrapper(pdf):
        batch = {column: pdf.iloc[:, position]
                 for position, column in enumerate(input_columns)}
        produced = func(batch)
        ordered = [produced[field] for field in field_names]
        return pd.concat(ordered, axis=1)

    return wrapper
def _wrap_map_arrow(func, input_columns, return_dtype):
    """
    Adapt a dict-based map_batches function to an Arrow UDF.

    The returned callable reads column ``i`` of the incoming batch as
    ``input_columns[i]`` and calls ``func`` with that dict. It then assembles
    ``func``'s result into a record batch ordered by the return type's field
    names.
    """
    field_names = return_dtype._table_type.field_names()

    def wrapper(batch):
        columns_by_name = {}
        for position, column in enumerate(input_columns):
            columns_by_name[column] = batch.column(position)
        return _record_batch_from_column_dict(func(columns_by_name), field_names)

    return wrapper
def _record_batch_from_column_dict(result_dict, output_names):
    """
    Assemble an Arrow record batch from a dict of output columns.

    Columns are taken from ``result_dict`` in the order of ``output_names``.
    Chunked arrays are flattened to a single Array first.

    Raises:
        KeyError: If a name in ``output_names`` is missing from ``result_dict``.
    """
    import pyarrow as pa
    import pyarrow_hotfix  # noqa # pylint: disable=unused-import

    def _as_array(column):
        # A record batch holds one contiguous Array per column.
        if not isinstance(column, pa.ChunkedArray):
            return column
        if column.num_chunks == 1:
            return column.chunk(0)
        return column.combine_chunks()

    arrays = []
    for output_name in output_names:
        if output_name not in result_dict:
            raise KeyError(
                f"Arrow batch UDF result is missing output column {output_name!r}.")
        arrays.append(_as_array(result_dict[output_name]))
    return pa.record_batch(arrays, names=output_names)