################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import abc
import functools
import inspect
from typing import (
Any, Awaitable, Callable, Generic, Iterable, List, Optional, TYPE_CHECKING, Type,
TypeVar, Union, overload,
)
from pyflink.java_gateway import get_gateway
from pyflink.metrics import MetricGroup
from pyflink.table import Expression
from pyflink.table.types import DataType, _to_java_data_type
from pyflink.util import java_utils
# Public API exported by ``from pyflink.table.udf import *``.
__all__ = [
    'FunctionContext',
    'AggregateFunction',
    'ScalarFunction',
    'TableFunction',
    'TableAggregateFunction',
    'AsyncScalarFunction',
    'udf',
    'udtf',
    'udaf',
    'udtaf',
]
class FunctionContext(object):
    """
    Runtime context handed to a user-defined function when it is opened.

    It gives access to information about the environment the function is executed in, such
    as the metric group and the global job parameters.
    """

    def __init__(self, base_metric_group, job_parameters):
        # base_metric_group is None when metrics are disabled (see get_metric_group).
        self._base_metric_group = base_metric_group
        self._job_parameters = job_parameters

    def get_metric_group(self) -> MetricGroup:
        """
        Returns the metric group of the current parallel subtask.

        :raises RuntimeError: if metrics were not enabled through the
                              ``python.metric.enabled`` configuration.

        .. versionadded:: 1.11.0
        """
        metric_group = self._base_metric_group
        if metric_group is not None:
            return metric_group
        raise RuntimeError("Metric has not been enabled. You can enable "
                           "metric with the 'python.metric.enabled' configuration.")

    def get_job_parameter(self, key: str, default_value: str) -> str:
        """
        Looks up a global job parameter and returns its value as a string.

        :param key: The key pointing to the associated value.
        :param default_value: The value returned when no job parameter is associated with the
                              given key.

        .. versionadded:: 1.17.0
        """
        if key not in self._job_parameters:
            return default_value
        return self._job_parameters[key]

    def get_config_parameter(self, key: str, default_value: str) -> str:
        """
        Looks up a framework configuration value and returns it as a string.

        At the API level this is kept separate from user-defined job parameters, although the
        Python function execution currently ships the selected configuration entries in the
        same serialized parameter map, so both lookups read the same mapping.

        :param key: The key pointing to the associated configuration value.
        :param default_value: The value returned when the key is absent.
        """
        if key not in self._job_parameters:
            return default_value
        return self._job_parameters[key]
class UserDefinedFunction(abc.ABC):
    """
    Root abstraction shared by all Python user-defined functions.

    .. versionadded:: 1.10.0
    """

    def open(self, function_context: FunctionContext):
        """
        Set-up hook, invoked before any of the working methods runs. Override it to perform
        one-off initialization. The default implementation does nothing.

        :param function_context: the context of the function
        :type function_context: FunctionContext
        """
        return None

    def close(self):
        """
        Tear-down hook, invoked after the last call to the working methods. The default
        implementation does nothing.
        """
        return None

    def is_deterministic(self) -> bool:
        """
        Tells whether the function's results are deterministic.

        Return true only if calling the function with the same parameters is guaranteed to
        always produce the same result. Functions that are not purely functional, such as
        random(), date() or now(), must return false. Defaults to true.

        :return: the determinism of the function's results.
        """
        deterministic = True
        return deterministic
[docs]
class ScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined scalar functions. A scalar function maps zero, one or
    several scalar values to a single new scalar value.

    .. versionadded:: 1.10.0
    """

    if not TYPE_CHECKING:
        @abc.abstractmethod
        def eval(self, *args):
            """
            Implements the logic of the scalar function.
            """
    else:
        # For static type checkers, eval may take any signature.
        eval: Callable[..., Any]
class AsyncScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined async scalar functions. Like a ScalarFunction, it maps
    zero, one or several scalar values to a new scalar value, but it produces the result
    asynchronously.

    This is useful when talking to external systems (e.g. databases or REST APIs) whose I/O
    would otherwise block. ``eval`` is expected to be an async coroutine function.

    Example:
    ::

        >>> class AsyncLookupFunction(AsyncScalarFunction):
        ...     async def eval(self, key):
        ...         # Simulate async I/O operation
        ...         await asyncio.sleep(0.1)
        ...         return f"value_for_{key}"

    .. versionadded:: 2.3.0
    """

    if not TYPE_CHECKING:
        @abc.abstractmethod
        async def eval(self, *args):
            """
            Implements the logic of the async scalar function as an async coroutine.
            """
    else:
        # For static type checkers, eval returns an awaitable of any result.
        eval: Callable[..., Awaitable[Any]]
[docs]
class TableFunction(UserDefinedFunction):
    """
    Base interface for user-defined table functions. A table function turns its input into
    zero, one or several output rows.

    .. versionadded:: 1.11.0
    """

    if not TYPE_CHECKING:
        @abc.abstractmethod
        def eval(self, *args):
            """
            Implements the logic of the table function.
            """
    else:
        # For static type checkers, eval yields an iterable of rows/values.
        eval: Callable[..., Iterable[Any]]
T = TypeVar('T')  # type of the aggregation result
ACC = TypeVar('ACC')  # type of the accumulator


class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
    """
    Base interface for user-defined aggregate function and table aggregate function.

    This class is used for unified handling of imperative aggregating functions. Concrete
    implementations should extend from :class:`~pyflink.table.AggregateFunction` or
    :class:`~pyflink.table.TableAggregateFunction`.

    The optional methods ``retract``, ``merge``, ``get_result_type`` and
    ``get_accumulator_type`` raise :class:`NotImplementedError` unless overridden.
    :class:`NotImplementedError` is a subclass of :class:`RuntimeError`, so code catching
    ``RuntimeError`` keeps working.

    .. versionadded:: 1.13.0
    """

    @abc.abstractmethod
    def create_accumulator(self) -> ACC:
        """
        Creates and initializes the accumulator for this AggregateFunction.

        :return: the accumulator with the initial value
        """
        pass

    @abc.abstractmethod
    def accumulate(self, accumulator: ACC, *args):
        """
        Processes the input values and updates the provided accumulator instance.

        :param accumulator: the accumulator which contains the current aggregated results
        :param args: the input value (usually obtained from new arrived data)
        """
        pass

    def retract(self, accumulator: ACC, *args):
        """
        Retracts the input values from the accumulator instance. The current design assumes the
        inputs are the values that have been previously accumulated.

        :param accumulator: the accumulator which contains the current aggregated results
        :param args: the input value (usually obtained from new arrived data).
        :raises NotImplementedError: unless overridden by the concrete function.
        """
        raise NotImplementedError("Method retract is not implemented")

    def merge(self, accumulator: ACC, accumulators):
        """
        Merges a group of accumulator instances into one accumulator instance. This method must be
        implemented for unbounded session window grouping aggregates and bounded grouping
        aggregates.

        :param accumulator: the accumulator which will keep the merged aggregate results. It should
                            be noted that the accumulator may contain the previous aggregated
                            results. Therefore user should not replace or clean this instance in the
                            custom merge method.
        :param accumulators: a group of accumulators that will be merged.
        :raises NotImplementedError: unless overridden by the concrete function.
        """
        raise NotImplementedError("Method merge is not implemented")

    def get_result_type(self) -> Union[DataType, str]:
        """
        Returns the DataType of the AggregateFunction's result.

        :return: The :class:`~pyflink.table.types.DataType` of the AggregateFunction's result.
        :raises NotImplementedError: unless overridden by the concrete function.
        """
        raise NotImplementedError("Method get_result_type is not implemented")

    def get_accumulator_type(self) -> Union[DataType, str]:
        """
        Returns the DataType of the AggregateFunction's accumulator.

        :return: The :class:`~pyflink.table.types.DataType` of the AggregateFunction's accumulator.
        :raises NotImplementedError: unless overridden by the concrete function.
        """
        raise NotImplementedError("Method get_accumulator_type is not implemented")
[docs]
class AggregateFunction(ImperativeAggregateFunction):
    """
    Base interface for user-defined aggregate functions. An aggregate function reduces the
    scalar values of many rows to one new scalar value.

    .. versionadded:: 1.12.0
    """

    @abc.abstractmethod
    def get_value(self, accumulator: ACC) -> T:  # type: ignore[type-var]
        """
        Materializes the aggregation result. It is invoked whenever a result is needed, which
        may be an early, incomplete result emitted periodically as data arrives, or the final
        result of the aggregation.

        :param accumulator: the accumulator which contains the current intermediate results
        :return: the aggregation result
        """
[docs]
class TableAggregateFunction(ImperativeAggregateFunction):
    """
    Base class for user-defined table aggregate functions. A table aggregate function maps the
    scalar values of many rows to zero, one or several rows (or structured types). When an
    output record has a single field, the structured record may be omitted and a scalar value
    emitted instead; the runtime then wraps it into a row implicitly.

    .. versionadded:: 1.13.0
    """

    @abc.abstractmethod
    def emit_value(self, accumulator: ACC) -> Iterable[T]:
        """
        Materializes the aggregation results. It is invoked whenever results are needed, which
        may be early, incomplete results emitted periodically as data arrives, or the final
        results of the aggregation.

        :param accumulator: the accumulator which contains the current aggregated results.
        :return: multiple aggregated result
        """
class DelegatingScalarFunction(ScalarFunction):
    """
    Internal adapter that exposes a lambda or plain Python function as a ScalarFunction.
    """

    def __init__(self, func):
        self.func = func

    def eval(self, *args):
        # Forward the call unchanged to the wrapped callable.
        target = self.func
        return target(*args)
class DelegatingAsyncScalarFunction(AsyncScalarFunction):
    """
    Internal adapter that exposes an async lambda or async Python function as an
    AsyncScalarFunction.
    """

    def __init__(self, func):
        self.func = func

    async def eval(self, *args):
        # Invoke the wrapped coroutine function and await its result.
        pending = self.func(*args)
        return await pending
class DelegationTableFunction(TableFunction):
    """
    Internal adapter that exposes a lambda or plain Python function as a TableFunction.
    """

    def __init__(self, func):
        self.func = func

    def eval(self, *args):
        # Whatever the wrapped callable returns (e.g. a generator) is passed through as-is.
        delegate = self.func
        return delegate(*args)
class DelegatingPandasAggregateFunction(AggregateFunction):
    """
    Internal adapter that exposes a lambda or plain Python function as a pandas
    AggregateFunction. The accumulator is a list that collects the callable's results, and the
    first collected result is reported as the aggregate value.
    """

    def __init__(self, func):
        self.func = func

    def create_accumulator(self):
        return list()

    def accumulate(self, accumulator, *args):
        result = self.func(*args)
        accumulator.append(result)

    def get_value(self, accumulator):
        return accumulator[0]
class PandasAggregateFunctionWrapper(object):
    """
    Adapts an AggregateFunction so it can be called like a scalar function: every ``eval``
    call runs one full create/accumulate/get_value cycle on a fresh accumulator.
    """

    def __init__(self, func: AggregateFunction):
        self.func = func

    def open(self, function_context: FunctionContext):
        self.func.open(function_context)

    def eval(self, *args):
        agg = self.func
        acc = agg.create_accumulator()
        agg.accumulate(acc, *args)
        return agg.get_value(acc)

    def close(self):
        self.func.close()
class UserDefinedFunctionWrapper(object):
    """
    Base Wrapper for Python user-defined function. It handles things like converting lambda
    functions to user-defined functions, creating the Java user-defined function representation,
    etc. It's for internal use only.

    Subclasses must implement :meth:`_create_delegate_function` and :meth:`_create_judf`.
    """

    def __init__(self, func, input_types, func_type, deterministic=None, name=None,
                 concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
        # Classes are rejected (they must be instantiated first); anything else must be a
        # UserDefinedFunction instance or a plain callable.
        if inspect.isclass(func) or (
                not isinstance(func, UserDefinedFunction) and not callable(func)):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): {0}"
                .format(type(func)))

        if input_types is not None:
            from pyflink.table.types import RowType

            # Normalize to a list: a RowType contributes its field types, a single type is
            # wrapped, and any other iterable is materialized.
            if isinstance(input_types, RowType):
                input_types = input_types.field_types()
            elif isinstance(input_types, (DataType, str)):
                input_types = [input_types]
            else:
                input_types = list(input_types)

            for input_type in input_types:
                if not isinstance(input_type, (DataType, str)):
                    raise TypeError(
                        "Invalid input_type: input_type should be DataType or str but contains {}"
                        .format(input_type))

        self._func = func
        self._input_types = input_types
        self._name = name or getattr(func, '__name__', func.__class__.__name__)

        if isinstance(func, UserDefinedFunction):
            func_deterministic = func.is_deterministic()
            if deterministic is not None and deterministic != func_deterministic:
                raise ValueError("Inconsistent deterministic: {} and {}".format(
                    deterministic, func_deterministic))
        else:
            # default deterministic is True
            func_deterministic = True
        # An explicit argument wins; otherwise fall back to the function's own answer.
        self._deterministic = deterministic if deterministic is not None else func_deterministic

        self._func_type = func_type
        self._judf_placeholder = None
        self._takes_row_as_input = False
        self._concurrency = concurrency
        self._batch_size = batch_size
        self._num_gpus = num_gpus
        self._gpu_type = gpu_type

    def __call__(self, *args) -> Expression:
        """
        Builds an expression that calls this function with the given arguments.
        """
        from pyflink.table import expressions as expr
        return expr.call(self, *args)

    def alias(self, *alias_names: str):
        """
        Records alias names for the function's output and returns ``self`` for chaining.
        """
        self._alias_names = alias_names
        return self

    def _set_takes_row_as_input(self):
        self._takes_row_as_input = True
        return self

    def _java_user_defined_function(self):
        """
        Lazily creates (once) and returns the Java representation of this function.
        """
        if self._judf_placeholder is None:
            gateway = get_gateway()

            def get_python_function_kind():
                JPythonFunctionKind = gateway.jvm.org.apache.flink.table.functions.python. \
                    PythonFunctionKind
                if self._func_type == "general":
                    return JPythonFunctionKind.GENERAL
                elif self._func_type == "pandas":
                    return JPythonFunctionKind.PANDAS
                elif self._func_type == "arrow":
                    return JPythonFunctionKind.ARROW
                else:
                    raise TypeError("Unsupported func_type: %s." % self._func_type)

            if self._input_types:
                if isinstance(self._input_types[0], str):
                    j_input_types = java_utils.to_jarray(gateway.jvm.String, self._input_types)
                else:
                    j_input_types = java_utils.to_jarray(
                        gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
            else:
                # Either not specified (None) or an empty list, e.g. for a zero-argument
                # function. Both are passed as None; indexing [0] on an empty list would
                # otherwise raise IndexError.
                j_input_types = None
            j_function_kind = get_python_function_kind()
            func = self._func
            if not isinstance(self._func, UserDefinedFunction):
                func = self._create_delegate_function()

            import cloudpickle
            serialized_func = cloudpickle.dumps(func)
            self._judf_placeholder = \
                self._create_judf(serialized_func, j_input_types, j_function_kind)
        return self._judf_placeholder

    def _create_delegate_function(self) -> UserDefinedFunction:
        """
        Wraps a plain Python callable into a UserDefinedFunction. Implemented by subclasses.
        """
        raise NotImplementedError(
            "{} must implement _create_delegate_function".format(type(self).__name__))

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates the Java function object from the serialized Python function. Implemented by
        subclasses.
        """
        raise NotImplementedError(
            "{} must implement _create_judf".format(type(self).__name__))
class UserDefinedScalarFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper that turns a Python scalar UDF into a Java ``PythonScalarFunction``.
    """

    def __init__(self, func, input_types, result_type, func_type, deterministic, name,
                 concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
        super(UserDefinedScalarFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name, concurrency, batch_size,
            num_gpus, gpu_type)
        valid_result_type = isinstance(result_type, (DataType, str))
        if not valid_result_type:
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str but is {}".format(
                    result_type))
        self._result_type = result_type
        self._judf_placeholder = None

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        gateway = get_gateway()
        # String result types are forwarded verbatim; DataTypes are converted to Java.
        j_result_type = (_to_java_data_type(self._result_type)
                         if isinstance(self._result_type, DataType)
                         else self._result_type)
        j_python_pkg = gateway.jvm.org.apache.flink.table.functions.python
        j_function = j_python_pkg.PythonScalarFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        # Optional resource / batching settings are only applied when configured.
        if self._concurrency is not None:
            j_function.setParallelism(self._concurrency)
        if self._batch_size is not None:
            j_function.setMaxArrowBatchSize(self._batch_size)
        if self._num_gpus is not None:
            j_function.setNumGpus(self._num_gpus)
        if self._gpu_type is not None:
            j_function.setGpuType(self._gpu_type)
        return j_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        return DelegatingScalarFunction(self._func)
class UserDefinedAsyncScalarFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper that turns a Python async scalar UDF into a Java ``PythonAsyncScalarFunction``.
    """

    def __init__(self, func, input_types, result_type, func_type, deterministic, name,
                 concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
        super(UserDefinedAsyncScalarFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name, concurrency, batch_size,
            num_gpus, gpu_type)
        valid_result_type = isinstance(result_type, (DataType, str))
        if not valid_result_type:
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str but is {}".format(
                    result_type))
        self._result_type = result_type
        self._judf_placeholder = None

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        gateway = get_gateway()
        # String result types are forwarded verbatim; DataTypes are converted to Java.
        j_result_type = (_to_java_data_type(self._result_type)
                         if isinstance(self._result_type, DataType)
                         else self._result_type)
        # Unlike the synchronous wrapper, concurrency is a constructor argument here;
        # -1 is passed when it is not configured.
        concurrency = -1 if self._concurrency is None else self._concurrency
        j_python_pkg = gateway.jvm.org.apache.flink.table.functions.python
        j_function = j_python_pkg.PythonAsyncScalarFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env(),
            concurrency)
        if self._batch_size is not None:
            j_function.setMaxArrowBatchSize(self._batch_size)
        if self._num_gpus is not None:
            j_function.setNumGpus(self._num_gpus)
        if self._gpu_type is not None:
            j_function.setGpuType(self._gpu_type)
        return j_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        return DelegatingAsyncScalarFunction(self._func)
class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined table function.

    ``result_types`` may be a RowType, a row type string, a single DataType, or an iterable
    of DataTypes / type strings.
    """

    def __init__(self, func, input_types, result_types, deterministic=None, name=None):
        super(UserDefinedTableFunctionWrapper, self).__init__(
            func, input_types, "general", deterministic, name)

        from pyflink.table.types import RowType

        if isinstance(result_types, str):
            # ROW<f0 INT, f1 BIGINT>
            # A single type string is kept verbatim. It must not be iterated below, which would
            # "validate" it one character at a time.
            pass
        else:
            if result_types is None:
                raise TypeError(
                    "Invalid result_types: result_types should be DataType, str or a list of "
                    "them but is None")
            if isinstance(result_types, RowType):
                # DataTypes.ROW([DataTypes.FIELD("f0", DataTypes.INT()),
                #                DataTypes.FIELD("f1", DataTypes.BIGINT())])
                result_types = result_types.field_types()
            elif isinstance(result_types, DataType):
                # DataTypes.INT()
                result_types = [result_types]
            else:
                # [DataTypes.INT(), DataTypes.BIGINT()]
                result_types = list(result_types)
            for result_type in result_types:
                if not isinstance(result_type, (DataType, str)):
                    raise TypeError(
                        "Invalid result_type: result_type should be DataType or str but contains {}"
                        .format(result_type))
        self._result_types = result_types

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        gateway = get_gateway()
        if isinstance(self._result_types, str):
            j_result_type = self._result_types
        elif isinstance(self._result_types[0], DataType):
            # A list of DataTypes becomes a Java ROW type of those fields.
            j_result_types = java_utils.to_jarray(
                gateway.jvm.DataType, [_to_java_data_type(i) for i in self._result_types])
            j_result_type = gateway.jvm.DataTypes.ROW(j_result_types)
        else:
            # A list of type strings becomes 'Row<f0 T0,f1 T1,...>'.
            j_result_type = 'Row<{0}>'.format(','.join(
                ['f{0} {1}'.format(i, result_type)
                 for i, result_type in enumerate(self._result_types)]))
        PythonTableFunction = gateway.jvm \
            .org.apache.flink.table.functions.python.PythonTableFunction
        j_table_function = PythonTableFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        return j_table_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        return DelegationTableFunction(self._func)
class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined aggregate function or user-defined table aggregate function.
    """

    def __init__(self, func, input_types, result_type, accumulator_type, func_type,
                 deterministic, name, is_table_aggregate=False):
        super(UserDefinedAggregateFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name)

        # Missing types are asked from the AggregateFunction itself.
        if accumulator_type is None and func_type == "general":
            accumulator_type = func.get_accumulator_type()
        if result_type is None:
            result_type = func.get_result_type()
        if not isinstance(result_type, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str but is {}"
                .format(result_type))
        from pyflink.table.types import MapType
        if func_type == 'pandas' and isinstance(result_type, MapType):
            raise TypeError(
                "Invalid returnType: Pandas UDAF doesn't support DataType type {} currently"
                .format(result_type))
        if accumulator_type is not None and not isinstance(accumulator_type, (DataType, str)):
            raise TypeError(
                "Invalid accumulator_type: accumulator_type should be DataType or str but is {}"
                .format(accumulator_type))
        if func_type == "general":
            # Both types must be of the same kind. The previous check used the tuple literal
            # `(accumulator_type, str)` instead of isinstance(), which is always truthy and let
            # a str result_type be combined with any accumulator_type.
            both_str = isinstance(result_type, str) and isinstance(accumulator_type, str)
            both_data_type = (isinstance(result_type, DataType)
                              and isinstance(accumulator_type, DataType))
            if not (both_str or both_data_type):
                raise TypeError("result_type and accumulator_type should be DataType or str "
                                "at the same time.")
        self._result_type = result_type
        self._accumulator_type = accumulator_type
        self._is_table_aggregate = is_table_aggregate

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        if self._func_type == "pandas":
            # Pandas UDAFs always use an ARRAY of the result type as accumulator, in the same
            # representation (DataType or str) as the result type.
            if isinstance(self._result_type, DataType):
                from pyflink.table.types import DataTypes
                self._accumulator_type = DataTypes.ARRAY(self._result_type)
            else:
                self._accumulator_type = 'ARRAY<{0}>'.format(self._result_type)

        # j_input_types is used as built by the base class, which already produces a DataType
        # array for DataType inputs and a String array for string inputs. It used to be
        # rebuilt here via _to_java_data_type, which cannot convert string input types.
        if isinstance(self._result_type, DataType):
            j_result_type = _to_java_data_type(self._result_type)
        else:
            j_result_type = self._result_type

        if isinstance(self._accumulator_type, DataType):
            j_accumulator_type = _to_java_data_type(self._accumulator_type)
        else:
            j_accumulator_type = self._accumulator_type

        gateway = get_gateway()
        j_python_pkg = gateway.jvm.org.apache.flink.table.functions.python
        if self._is_table_aggregate:
            PythonAggregateFunction = j_python_pkg.PythonTableAggregateFunction
        else:
            PythonAggregateFunction = j_python_pkg.PythonAggregateFunction
        j_aggregate_function = PythonAggregateFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_accumulator_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        return j_aggregate_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        # Only pandas UDAFs may be plain callables; general UDAFs must be AggregateFunctions.
        assert self._func_type == 'pandas'
        return DelegatingPandasAggregateFunction(self._func)
# TODO: support to configure the python execution environment
def _get_python_env():
    """Returns a Java ``PythonEnv`` that always uses the PROCESS execution type."""
    j_python_pkg = get_gateway().jvm.org.apache.flink.table.functions.python
    process_exec_type = j_python_pkg.PythonEnv.ExecType.PROCESS
    return j_python_pkg.PythonEnv(process_exec_type)
def _create_udf(f, input_types, result_type, func_type, deterministic, name, concurrency=None,
                batch_size=None, num_gpus=None, gpu_type=None):
    """
    Builds the scalar UDF wrapper. AsyncScalarFunction instances and coroutine functions get
    the async wrapper; everything else gets the synchronous one.
    """
    is_async = isinstance(f, AsyncScalarFunction) or inspect.iscoroutinefunction(f)
    wrapper_cls = (UserDefinedAsyncScalarFunctionWrapper if is_async
                   else UserDefinedScalarFunctionWrapper)
    return wrapper_cls(f, input_types, result_type, func_type, deterministic, name,
                       concurrency, batch_size, num_gpus, gpu_type)
def _create_udtf(f, input_types, result_types, deterministic, name):
    """Builds the wrapper for a user-defined table function."""
    return UserDefinedTableFunctionWrapper(
        f, input_types, result_types, deterministic=deterministic, name=name)
def _create_udaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
    """Builds the wrapper for a user-defined aggregate function."""
    return UserDefinedAggregateFunctionWrapper(
        f, input_types, result_type, accumulator_type, func_type, deterministic, name,
        is_table_aggregate=False)
def _create_udtaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
    """Builds the wrapper for a user-defined table aggregate function."""
    return UserDefinedAggregateFunctionWrapper(
        f, input_types, result_type, accumulator_type, func_type, deterministic, name,
        is_table_aggregate=True)
# Typing overload used when a function is passed in: udf wraps it immediately and returns
# the wrapper, e.g.
#   udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
#   udf(MyScalarFunction())
@overload
def udf(f: Union[Callable, ScalarFunction, AsyncScalarFunction, Type],
        input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
        result_type: Optional[Union[DataType, str]] = ...,
        deterministic: Optional[bool] = ...,
        name: Optional[str] = ...,
        func_type: str = ...,
        udf_type: Optional[str] = ...,
        concurrency: Optional[int] = ...,
        batch_size: Optional[int] = ...,
        num_gpus: Optional[float] = ...,
        gpu_type: Optional[str] = ...
        ) -> Union[UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper]:
    ...


# Typing overload used when no function is passed in: udf acts as a decorator factory and
# returns a decorator that receives the function afterwards, e.g.
#   @udf(result_type=DataTypes.BIGINT())
#   def add(i, j): ...
@overload
def udf(f: None = ...,
        input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
        result_type: Optional[Union[DataType, str]] = ...,
        deterministic: Optional[bool] = ...,
        name: Optional[str] = ...,
        func_type: str = ...,
        udf_type: Optional[str] = ...,
        concurrency: Optional[int] = ...,
        batch_size: Optional[int] = ...,
        num_gpus: Optional[float] = ...,
        gpu_type: Optional[str] = ...
        ) -> Callable[
            [Union[Callable, ScalarFunction, AsyncScalarFunction, Type]],
            Union[UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper]]:
    ...
[docs]
def udf(f: Optional[Union[Callable, ScalarFunction, AsyncScalarFunction, Type]] = None,
        input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
        result_type: Optional[Union[DataType, str]] = None,
        deterministic: Optional[bool] = None, name: Optional[str] = None,
        func_type: str = "general",
        udf_type: Optional[str] = None, concurrency: Optional[int] = None,
        batch_size: Optional[int] = None, num_gpus: Optional[float] = None,
        gpu_type: Optional[str] = None) -> Union[
            UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined scalar function.

    This decorator can automatically detect whether the function is async (defined with `async def`
    or is an instance of AsyncScalarFunction).

    Example:
    ::

        >>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())

        >>> # The input_types is optional.
        >>> @udf(result_type=DataTypes.BIGINT())
        ... def add(i, j):
        ...     return i + j

        >>> # Specify result_type via string.
        >>> @udf(result_type='BIGINT')
        ... def add(i, j):
        ...     return i + j

        >>> # Async function will be automatically detected
        >>> @udf(result_type=DataTypes.STRING())
        ... async def async_lookup(key):
        ...     await asyncio.sleep(0.1)
        ...     return f"value_for_{key}"

        >>> class SubtractOne(ScalarFunction):
        ...     def eval(self, i):
        ...         return i - 1
        >>> subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())

        >>> # AsyncScalarFunction will be automatically detected
        >>> class AsyncLookup(AsyncScalarFunction):
        ...     async def eval(self, key):
        ...         await asyncio.sleep(0.1)
        ...         return f"value_for_{key}"
        >>> async_lookup = udf(AsyncLookup(), result_type=DataTypes.STRING())

    :param f: lambda function, user-defined function, or async function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general, pandas, arrow,
                      (default: general)
    :param udf_type: deprecated since 1.12, use ``func_type`` instead. When given, it overrides
                     ``func_type``.
    :param concurrency: optional, the concurrency (parallelism) for the UDF operator. If specified,
                        the operator running this UDF will use this parallelism. Must be a
                        positive integer.
    :param batch_size: optional, the maximum number of elements to buffer before processing them
                       in a batch. Only valid for batch-wise UDF.
    :param num_gpus: optional, a positive number of GPUs for the UDF. When set, ``gpu_type`` must
                     also be specified.
    :param gpu_type: optional, the GPU type; required when ``num_gpus`` is set.
    :return: UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper, or function.
    :raises ValueError: if func_type, batch_size, concurrency, num_gpus or gpu_type is invalid.

    .. versionadded:: 1.10.0
    """
    if udf_type:
        import warnings
        # stacklevel=2 attributes the warning to the caller of udf() rather than to this line.
        warnings.warn("The param udf_type is deprecated in 1.12. Use func_type instead.",
                      stacklevel=2)
        func_type = udf_type

    if func_type not in ('general', 'pandas', 'arrow'):
        raise ValueError("The func_type must be one of 'general, pandas, arrow', got %s."
                         % func_type)

    if func_type == 'general' and batch_size is not None:
        raise ValueError("batch_size is only supported for pandas or arrow UDFs.")

    # bool is a subclass of int, so it is rejected explicitly; otherwise e.g. concurrency=True
    # would silently be accepted as a parallelism of 1.
    if concurrency is not None and (isinstance(concurrency, bool)
                                    or not isinstance(concurrency, int)
                                    or concurrency <= 0):
        raise ValueError("concurrency must be a positive integer, got: {}".format(concurrency))

    if num_gpus is not None and (isinstance(num_gpus, bool)
                                 or not isinstance(num_gpus, (int, float))
                                 or num_gpus <= 0):
        raise ValueError("num_gpus must be a positive number, got: {}".format(num_gpus))

    if num_gpus is not None and gpu_type is None:
        raise ValueError("gpu_type must be specified when num_gpus is set")

    # decorator
    if f is None:
        return functools.partial(_create_udf, input_types=input_types, result_type=result_type,
                                 func_type=func_type, deterministic=deterministic,
                                 name=name, concurrency=concurrency, batch_size=batch_size,
                                 num_gpus=num_gpus, gpu_type=gpu_type)
    else:
        return _create_udf(f, input_types, result_type, func_type, deterministic, name,
                           concurrency=concurrency, batch_size=batch_size,
                           num_gpus=num_gpus, gpu_type=gpu_type)
# Typing overload used when a function is passed in: udtf wraps it immediately.
@overload
def udtf(f: Union[Callable, TableFunction, Type],
         input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
         result_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
         deterministic: Optional[bool] = ...,
         name: Optional[str] = ...
         ) -> UserDefinedTableFunctionWrapper:
    ...


# Typing overload used when no function is passed in: udtf returns a decorator.
@overload
def udtf(f: None = ...,
         input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
         result_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
         deterministic: Optional[bool] = ...,
         name: Optional[str] = ...
         ) -> Callable[[Union[Callable, TableFunction, Type]], UserDefinedTableFunctionWrapper]:
    ...
[docs]
def udtf(f: Optional[Union[Callable, TableFunction, Type]] = None,
         input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
         result_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
         deterministic: Optional[bool] = None,
         name: Optional[str] = None) -> Union[UserDefinedTableFunctionWrapper, Callable]:
    """
    Creates a user-defined table function, either directly or as a decorator.

    Example:
    ::

        >>> # The input_types is optional.
        >>> @udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
        ... def range_emit(s, e):
        ...     for i in range(e):
        ...         yield s, i

        >>> # Specify result_types via string
        >>> @udtf(result_types=['BIGINT', 'BIGINT'])
        ... def range_emit(s, e):
        ...     for i in range(e):
        ...         yield s, i

        >>> # Specify result_types via row string
        >>> @udtf(result_types='Row<a BIGINT, b BIGINT>')
        ... def range_emit(s, e):
        ...     for i in range(e):
        ...         yield s, i

        >>> class MultiEmit(TableFunction):
        ...     def eval(self, i):
        ...         return range(i)
        >>> multi_emit = udtf(MultiEmit(), DataTypes.BIGINT(), DataTypes.BIGINT())

    :param f: user-defined table function.
    :param input_types: optional, the input data types.
    :param result_types: the result data types.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :return: UserDefinedTableFunctionWrapper, or a decorator when ``f`` is not given.

    .. versionadded:: 1.11.0
    """
    if f is not None:
        return _create_udtf(f, input_types, result_types, deterministic, name)
    # No function given: return a decorator that completes the wrapping later.
    return functools.partial(_create_udtf, input_types=input_types, result_types=result_types,
                             deterministic=deterministic, name=name)
# Typing overload for the direct-call form: the aggregate function is handed to
# udaf itself, so the wrapper object comes back straight away.
@overload
def udaf(
    f: Union[Callable, AggregateFunction, Type],
    input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
    result_type: Optional[Union[DataType, str]] = ...,
    accumulator_type: Optional[Union[DataType, str]] = ...,
    deterministic: Optional[bool] = ...,
    name: Optional[str] = ...,
    func_type: str = ...,
) -> UserDefinedAggregateFunctionWrapper:
    """Typing-only signature for ``udaf(f, ...)`` with ``f`` supplied."""
# Typing overload for the decorator form: with no function given, udaf hands back
# a callable that later turns the decorated function into the wrapper.
@overload
def udaf(
    f: None = ...,
    input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
    result_type: Optional[Union[DataType, str]] = ...,
    accumulator_type: Optional[Union[DataType, str]] = ...,
    deterministic: Optional[bool] = ...,
    name: Optional[str] = ...,
    func_type: str = ...,
) -> Callable[[Union[Callable, AggregateFunction, Type]], UserDefinedAggregateFunctionWrapper]:
    """Typing-only signature for ``@udaf(...)`` used as a decorator factory."""
[docs]
def udaf(f: Optional[Union[Callable, AggregateFunction, Type]] = None,
         input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
         result_type: Optional[Union[DataType, str]] = None,
         accumulator_type: Optional[Union[DataType, str]] = None,
         deterministic: Optional[bool] = None, name: Optional[str] = None,
         func_type: str = "general") -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined aggregate function.

    Supports two calling forms: pass the aggregate function as ``f`` to get the wrapper
    immediately, or omit ``f`` and use the result as a decorator.

    Example:
    ::

        >>> # The input_types is optional.
        >>> @udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
        ... def mean_udaf(v):
        ...     return v.mean()

        >>> # Specify result_type via string
        >>> @udaf(result_type='FLOAT', func_type="pandas")
        ... def mean_udaf(v):
        ...     return v.mean()

    :param f: user-defined aggregate function; leave as None to use ``udaf`` as a decorator.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param accumulator_type: optional, the accumulator data type.
    :param deterministic: the determinism of the function's results. True if and only if a call
                          to this function is guaranteed to always return the same result given
                          the same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, either 'general' or 'pandas'
                      (default: general).
    :return: a UserDefinedAggregateFunctionWrapper when ``f`` is given, otherwise a callable
             that produces one from the decorated function.
    :raises ValueError: if ``func_type`` is neither 'general' nor 'pandas'.

    .. versionadded:: 1.12.0
    """
    # Validate eagerly so the decorator form also fails at decoration time.
    supported_func_types = ('general', 'pandas')
    if func_type not in supported_func_types:
        raise ValueError("The func_type must be one of 'general, pandas', got %s."
                         % func_type)

    options = {
        'input_types': input_types,
        'result_type': result_type,
        'accumulator_type': accumulator_type,
        'func_type': func_type,
        'deterministic': deterministic,
        'name': name,
    }
    if f is not None:
        return _create_udaf(f, **options)
    # Decorator form: wrap the decorated function once it is supplied.
    return functools.partial(_create_udaf, **options)
# Typing overload for the direct-call form: the table aggregate function is handed
# to udtaf itself, so the wrapper object comes back straight away.
@overload
def udtaf(
    f: Union[Callable, TableAggregateFunction, Type],
    input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
    result_type: Optional[Union[DataType, str]] = ...,
    accumulator_type: Optional[Union[DataType, str]] = ...,
    deterministic: Optional[bool] = ...,
    name: Optional[str] = ...,
    func_type: str = ...,
) -> UserDefinedAggregateFunctionWrapper:
    """Typing-only signature for ``udtaf(f, ...)`` with ``f`` supplied."""
# Typing overload for the decorator form: with no function given, udtaf hands back
# a callable that later turns the decorated function into the wrapper.
@overload
def udtaf(
    f: None = ...,
    input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
    result_type: Optional[Union[DataType, str]] = ...,
    accumulator_type: Optional[Union[DataType, str]] = ...,
    deterministic: Optional[bool] = ...,
    name: Optional[str] = ...,
    func_type: str = ...,
) -> Callable[[Union[Callable, TableAggregateFunction, Type]], UserDefinedAggregateFunctionWrapper]:
    """Typing-only signature for ``@udtaf(...)`` used as a decorator factory."""
[docs]
def udtaf(f: Optional[Union[Callable, TableAggregateFunction, Type]] = None,
          input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
          result_type: Optional[Union[DataType, str]] = None,
          accumulator_type: Optional[Union[DataType, str]] = None,
          deterministic: Optional[bool] = None, name: Optional[str] = None,
          func_type: str = 'general') -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined table aggregate function.

    Supports two calling forms: pass the table aggregate function as ``f`` to get the
    wrapper immediately, or omit ``f`` and use the result as a decorator.

    Example:
    ::

        >>> # The input_types is optional.
        >>> class Top2(TableAggregateFunction):
        ...     def emit_value(self, accumulator):
        ...         yield Row(accumulator[0])
        ...         yield Row(accumulator[1])
        ...
        ...     def create_accumulator(self):
        ...         return [None, None]
        ...
        ...     def accumulate(self, accumulator, *args):
        ...         if args[0] is not None:
        ...             if accumulator[0] is None or args[0] > accumulator[0]:
        ...                 accumulator[1] = accumulator[0]
        ...                 accumulator[0] = args[0]
        ...             elif accumulator[1] is None or args[0] > accumulator[1]:
        ...                 accumulator[1] = args[0]
        ...
        ...     def retract(self, accumulator, *args):
        ...         accumulator[0] = accumulator[0] - 1
        ...
        ...     def merge(self, accumulator, accumulators):
        ...         for other_acc in accumulators:
        ...             self.accumulate(accumulator, other_acc[0])
        ...             self.accumulate(accumulator, other_acc[1])
        ...
        ...     def get_accumulator_type(self):
        ...         return 'ARRAY<BIGINT>'
        ...
        ...     def get_result_type(self):
        ...         return 'ROW<a BIGINT>'
        >>> top2 = udtaf(Top2())

    :param f: user-defined table aggregate function; leave as None to use ``udtaf`` as a
              decorator.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param accumulator_type: optional, the accumulator data type.
    :param deterministic: the determinism of the function's results. True if and only if a call
                          to this function is guaranteed to always return the same result given
                          the same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function; only 'general' is accepted
                      (default: general).
    :return: a UserDefinedAggregateFunctionWrapper when ``f`` is given, otherwise a callable
             that produces one from the decorated function.
    :raises ValueError: if ``func_type`` is not 'general'.

    .. versionadded:: 1.13.0
    """
    # Validate eagerly so the decorator form also fails at decoration time.
    if func_type != 'general':
        raise ValueError("The func_type must be 'general', got %s."
                         % func_type)

    options = {
        'input_types': input_types,
        'result_type': result_type,
        'accumulator_type': accumulator_type,
        'func_type': func_type,
        'deterministic': deterministic,
        'name': name,
    }
    if f is not None:
        return _create_udtaf(f, **options)
    # Decorator form: wrap the decorated function once it is supplied.
    return functools.partial(_create_udtaf, **options)