Source code for pyflink.table.udf

################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import abc
import functools
import inspect
from typing import (
    Any, Awaitable, Callable, Generic, Iterable, List, Optional, TYPE_CHECKING, Type,
    TypeVar, Union, overload,
)

from pyflink.java_gateway import get_gateway
from pyflink.metrics import MetricGroup
from pyflink.table import Expression
from pyflink.table.types import DataType, _to_java_data_type
from pyflink.util import java_utils

# Public API of this module; names are listed in the order they are exported.
__all__ = [
    'FunctionContext',
    'AggregateFunction',
    'ScalarFunction',
    'TableFunction',
    'TableAggregateFunction',
    'AsyncScalarFunction',
    'udf',
    'udtf',
    'udaf',
    'udtaf',
]


class FunctionContext(object):
    """
    Used to obtain global runtime information about the context in which the
    user-defined function is executed. The information includes the metric group,
    and global job parameters, etc.
    """

    def __init__(self, base_metric_group, job_parameters):
        # None when metrics are disabled; get_metric_group() reports that case.
        self._base_metric_group = base_metric_group
        # Parameter map shared by job parameters and selected configuration entries.
        # A None map is tolerated: lookups then fall back to the caller's default.
        self._job_parameters = job_parameters

    def get_metric_group(self) -> MetricGroup:
        """
        Returns the metric group for this parallel subtask.

        :raises RuntimeError: if metrics were not enabled for this function.

        .. versionadded:: 1.11.0
        """
        if self._base_metric_group is None:
            raise RuntimeError("Metric has not been enabled. You can enable "
                               "metric with the 'python.metric.enabled' configuration.")
        return self._base_metric_group

    def get_job_parameter(self, key: str, default_value: str) -> str:
        """
        Gets the global job parameter value associated with the given key as a string.

        :param key: The key pointing to the associated value.
        :param default_value: The default value which is returned in case global job parameter is
                              null or there is no value associated with the given key.

        .. versionadded:: 1.17.0
        """
        return self._lookup_parameter(key, default_value)

    def get_config_parameter(self, key: str, default_value: str) -> str:
        """
        Gets the framework configuration value associated with the given key as a string.

        This is separate from user-defined job parameters at the API boundary, even though
        Python function execution currently transports selected configuration entries through
        the same serialized parameter map.

        :param key: The key pointing to the associated configuration value.
        :param default_value: The default value returned when the key is absent.
        """
        return self._lookup_parameter(key, default_value)

    def _lookup_parameter(self, key, default_value):
        """
        Shared lookup behind get_job_parameter and get_config_parameter.

        Returns the value stored under ``key`` in the serialized parameter map, or
        ``default_value`` when the map is missing (None) or has no entry for ``key``.
        """
        params = self._job_parameters
        # Keep the membership-test + indexing protocol of the original code rather than
        # `.get`, so any mapping-like object supporting `in` and `[]` keeps working.
        if params is None or key not in params:
            return default_value
        return params[key]


class UserDefinedFunction(abc.ABC):
    """
    Common base class shared by every Python user-defined function.

    It supplies default lifecycle hooks (``open``/``close``), which do nothing, and a
    determinism flag, which defaults to ``True``. Subclasses override whichever they need.

    .. versionadded:: 1.10.0
    """

    def open(self, function_context: FunctionContext):
        """
        Set-up hook invoked once before any of the function's working methods run,
        making it the place for one-off initialization.

        :param function_context: runtime context in which the function executes
        :type function_context: FunctionContext
        """

    def close(self):
        """
        Tear-down hook invoked after the final call to the function's working methods.
        """

    def is_deterministic(self) -> bool:
        """
        Tells whether the function's results are deterministic.

        Return true only when calling the function with the same parameters is guaranteed
        to always yield the same result; this is the default assumption. Functions that are
        not purely functional, such as random(), date() or now(), must return false.

        :return: whether the function's results are deterministic.
        """
        return True


class ScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined scalar function. A user-defined scalar function maps
    zero, one, or multiple scalar values to a new scalar value.

    .. versionadded:: 1.10.0
    """

    if TYPE_CHECKING:
        # Type checkers see ``eval`` as a plain callable attribute; presumably this lets
        # subclasses declare ``eval`` with any concrete signature without an
        # incompatible-override error.
        eval: Callable[..., Any]
    else:
        # At runtime ``eval`` is abstract, so subclasses must implement it.
        @abc.abstractmethod
        def eval(self, *args):
            """
            Method which defines the logic of the scalar function.
            """
            pass
class AsyncScalarFunction(UserDefinedFunction):
    """
    Base interface for user-defined async scalar function. A user-defined async scalar
    function maps zero, one, or multiple scalar values to a new scalar value asynchronously.

    This function is similar to ScalarFunction but is executed asynchronously. It's useful
    when interacting with external systems (e.g., databases, REST APIs) where I/O operations
    would otherwise block.

    The eval method should be an async coroutine function that returns the result
    asynchronously.

    Example:
    ::

        >>> import asyncio
        >>> class AsyncLookupFunction(AsyncScalarFunction):
        ...     async def eval(self, key):
        ...         # Simulate async I/O operation
        ...         await asyncio.sleep(0.1)
        ...         return f"value_for_{key}"

    .. versionadded:: 2.3.0
    """

    if TYPE_CHECKING:
        # Type checkers see ``eval`` as a callable returning an awaitable, so any
        # ``async def eval(...)`` signature in a subclass type-checks.
        eval: Callable[..., Awaitable[Any]]
    else:
        # At runtime ``eval`` is an abstract coroutine function that subclasses must implement.
        @abc.abstractmethod
        async def eval(self, *args):
            """
            Async method which defines the logic of the async scalar function.
            This method should be an async coroutine.
            """
            pass
class TableFunction(UserDefinedFunction):
    """
    Base interface for user-defined table function. A user-defined table function takes
    zero, one, or multiple scalar values as input and produces zero, one, or multiple rows.

    .. versionadded:: 1.11.0
    """

    if TYPE_CHECKING:
        # Type checkers see ``eval`` as a callable yielding an iterable of output rows.
        eval: Callable[..., Iterable[Any]]
    else:
        # At runtime ``eval`` is abstract, so subclasses must implement it.
        @abc.abstractmethod
        def eval(self, *args):
            """
            Method which defines the logic of the table function.
            """
            pass
# Result type of an aggregate function.
T = TypeVar('T')
# Accumulator type of an aggregate function.
ACC = TypeVar('ACC')


class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
    """
    Base interface for user-defined aggregate function and table aggregate function.

    This class is used for unified handling of imperative aggregating functions. Concrete
    implementations should extend from :class:`~pyflink.table.AggregateFunction` or
    :class:`~pyflink.table.TableAggregateFunction`.

    .. versionadded:: 1.13.0
    """

    @abc.abstractmethod
    def create_accumulator(self) -> ACC:
        """
        Creates and initializes the accumulator for this AggregateFunction.

        :return: the accumulator with the initial value
        """
        pass

    @abc.abstractmethod
    def accumulate(self, accumulator: ACC, *args):
        """
        Processes the input values and updates the provided accumulator instance.

        :param accumulator: the accumulator which contains the current aggregated results
        :param args: the input value (usually obtained from new arrived data)
        """
        pass

    def retract(self, accumulator: ACC, *args):
        """
        Retracts the input values from the accumulator instance. The current design assumes
        the inputs are the values that have been previously accumulated.

        :param accumulator: the accumulator which contains the current aggregated results
        :param args: the input value (usually obtained from new arrived data).
        :raises RuntimeError: unless overridden by a subclass.
        """
        raise RuntimeError("Method retract is not implemented")

    def merge(self, accumulator: ACC, accumulators):
        """
        Merges a group of accumulator instances into one accumulator instance. This method must
        be implemented for unbounded session window grouping aggregates and bounded grouping
        aggregates.

        :param accumulator: the accumulator which will keep the merged aggregate results. It
                            should be noted that the accumulator may contain the previous
                            aggregated results. Therefore user should not replace or clean this
                            instance in the custom merge method.
        :param accumulators: a group of accumulators that will be merged.
        :raises RuntimeError: unless overridden by a subclass.
        """
        raise RuntimeError("Method merge is not implemented")

    def get_result_type(self) -> Union[DataType, str]:
        """
        Returns the DataType of the AggregateFunction's result.

        :return: The :class:`~pyflink.table.types.DataType` of the AggregateFunction's result.
        :raises RuntimeError: unless overridden by a subclass.
        """
        raise RuntimeError("Method get_result_type is not implemented")

    def get_accumulator_type(self) -> Union[DataType, str]:
        """
        Returns the DataType of the AggregateFunction's accumulator.

        :return: The :class:`~pyflink.table.types.DataType` of the AggregateFunction's
                 accumulator.
        :raises RuntimeError: unless overridden by a subclass.
        """
        raise RuntimeError("Method get_accumulator_type is not implemented")
class AggregateFunction(ImperativeAggregateFunction):
    """
    Base interface for user-defined aggregate function. A user-defined aggregate function maps
    scalar values of multiple rows to a new scalar value.

    .. versionadded:: 1.12.0
    """

    @abc.abstractmethod
    def get_value(self, accumulator: ACC) -> T:  # type: ignore[type-var]
        """
        Called every time when an aggregation result should be materialized. The returned value
        could be either an early and incomplete result (periodically emitted as data arrives) or
        the final result of the aggregation.

        :param accumulator: the accumulator which contains the current intermediate results
        :return: the aggregation result
        """
        pass
class TableAggregateFunction(ImperativeAggregateFunction):
    """
    Base class for a user-defined table aggregate function. A user-defined table aggregate
    function maps scalar values of multiple rows to zero, one, or multiple rows (or structured
    types). If an output record consists of only one field, the structured record can be
    omitted, and a scalar value can be emitted that will be implicitly wrapped into a row by the
    runtime.

    .. versionadded:: 1.13.0
    """

    @abc.abstractmethod
    def emit_value(self, accumulator: ACC) -> Iterable[T]:
        """
        Called every time when an aggregation result should be materialized. The returned value
        could be either an early and incomplete result (periodically emitted as data arrives) or
        the final result of the aggregation.

        :param accumulator: the accumulator which contains the current aggregated results.
        :return: multiple aggregated result
        """
        pass
class DelegatingScalarFunction(ScalarFunction):
    """
    Helper scalar function implementation for lambda expression and python function. It's for
    internal use only.

    ``eval`` forwards all positional arguments to the wrapped callable.
    """

    def __init__(self, func):
        self.func = func

    def eval(self, *args):
        return self.func(*args)


class DelegatingAsyncScalarFunction(AsyncScalarFunction):
    """
    Helper async scalar function implementation for async lambda expression and python async
    function. It's for internal use only.

    ``eval`` awaits the wrapped coroutine function with all positional arguments.
    """

    def __init__(self, func):
        self.func = func

    async def eval(self, *args):
        return await self.func(*args)


class DelegationTableFunction(TableFunction):
    """
    Helper table function implementation for lambda expression and python function. It's for
    internal use only.

    ``eval`` forwards all positional arguments to the wrapped callable and returns its result
    unchanged.
    """

    def __init__(self, func):
        self.func = func

    def eval(self, *args):
        return self.func(*args)


class DelegatingPandasAggregateFunction(AggregateFunction):
    """
    Helper pandas aggregate function implementation for lambda expression and python function.
    It's for internal use only.

    The accumulator is a list that collects the wrapped callable's result for every
    ``accumulate`` call; ``get_value`` reports only the first collected result.
    """

    def __init__(self, func):
        self.func = func

    def get_value(self, accumulator):
        return accumulator[0]

    def create_accumulator(self):
        return []

    def accumulate(self, accumulator, *args):
        accumulator.append(self.func(*args))


class PandasAggregateFunctionWrapper(object):
    """
    Wrapper for Pandas Aggregate function.

    Each ``eval`` call creates a fresh accumulator, feeds it the given arguments once and
    immediately materializes the result. ``open``/``close`` are forwarded to the wrapped
    function.
    """

    def __init__(self, func: AggregateFunction):
        self.func = func

    def open(self, function_context: FunctionContext):
        self.func.open(function_context)

    def eval(self, *args):
        accumulator = self.func.create_accumulator()
        self.func.accumulate(accumulator, *args)
        return self.func.get_value(accumulator)

    def close(self):
        self.func.close()


class UserDefinedFunctionWrapper(object):
    """
    Base Wrapper for Python user-defined function. It handles things like converting lambda
    functions to user-defined functions, creating the Java user-defined function representation,
    etc. It's for internal use only.

    :param func: a :class:`UserDefinedFunction` instance or a plain callable (classes are
                 rejected).
    :param input_types: optional input types: a RowType (expanded to its field types), a single
                        DataType/str, or an iterable of DataType/str.
    :param func_type: one of "general", "pandas" or "arrow".
    :param deterministic: optional determinism flag; must agree with
                          ``func.is_deterministic()`` when ``func`` is a UserDefinedFunction.
    :param name: optional function name; defaults to ``func.__name__`` or the class name.
    :raises TypeError: if ``func`` is not usable or an input type is not DataType/str.
    :raises ValueError: if ``deterministic`` contradicts the wrapped function.
    """

    def __init__(self, func, input_types, func_type, deterministic=None, name=None,
                 concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
        if inspect.isclass(func) or (
                not isinstance(func, UserDefinedFunction) and not callable(func)):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): {0}"
                .format(type(func)))

        if input_types is not None:
            from pyflink.table.types import RowType

            # RowType is itself a DataType, so it has to be checked first.
            if isinstance(input_types, RowType):
                input_types = input_types.field_types()
            elif isinstance(input_types, (DataType, str)):
                input_types = [input_types]
            else:
                input_types = list(input_types)

            for input_type in input_types:
                if not isinstance(input_type, (DataType, str)):
                    raise TypeError(
                        "Invalid input_type: input_type should be DataType or str but contains {}"
                        .format(input_type))

        self._func = func
        self._input_types = input_types
        self._name = name or getattr(func, '__name__', func.__class__.__name__)

        if (deterministic is not None and isinstance(func, UserDefinedFunction)
                and deterministic != func.is_deterministic()):
            raise ValueError("Inconsistent deterministic: {} and {}".format(
                deterministic, func.is_deterministic()))

        # default deterministic is True
        self._deterministic = deterministic if deterministic is not None else (
            func.is_deterministic() if isinstance(func, UserDefinedFunction) else True)
        self._func_type = func_type
        # Cache for the Java function object built by _java_user_defined_function.
        self._judf_placeholder = None
        self._takes_row_as_input = False
        self._concurrency = concurrency
        self._batch_size = batch_size
        self._num_gpus = num_gpus
        self._gpu_type = gpu_type

    def __call__(self, *args) -> Expression:
        """Builds an expression calling this function with ``args`` via ``expressions.call``."""
        from pyflink.table import expressions as expr
        return expr.call(self, *args)

    def alias(self, *alias_names: str):
        """Records ``alias_names`` on the wrapper and returns ``self`` for chaining."""
        self._alias_names = alias_names
        return self

    def _set_takes_row_as_input(self):
        """Marks the function as taking a whole row as input and returns ``self``."""
        self._takes_row_as_input = True
        return self

    def _java_user_defined_function(self):
        """
        Builds (once) and returns the Java-side function object.

        Input types become a Java ``String[]`` when the first input type is a string, a
        ``DataType[]`` otherwise, or ``None`` when no input types were given. Plain callables
        are wrapped via ``_create_delegate_function`` before being serialized with
        cloudpickle; the result of ``_create_judf`` is cached for subsequent calls.
        """
        if self._judf_placeholder is None:
            gateway = get_gateway()

            def get_python_function_kind():
                JPythonFunctionKind = gateway.jvm.org.apache.flink.table.functions.python. \
                    PythonFunctionKind
                if self._func_type == "general":
                    return JPythonFunctionKind.GENERAL
                elif self._func_type == "pandas":
                    return JPythonFunctionKind.PANDAS
                elif self._func_type == "arrow":
                    return JPythonFunctionKind.ARROW
                else:
                    raise TypeError("Unsupported func_type: %s." % self._func_type)

            if self._input_types is not None:
                # An empty list has no first element to inspect; it becomes an empty
                # DataType array instead of raising IndexError.
                if self._input_types and isinstance(self._input_types[0], str):
                    j_input_types = java_utils.to_jarray(gateway.jvm.String, self._input_types)
                else:
                    j_input_types = java_utils.to_jarray(
                        gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types])
            else:
                j_input_types = None
            j_function_kind = get_python_function_kind()
            func = self._func
            if not isinstance(self._func, UserDefinedFunction):
                func = self._create_delegate_function()

            import cloudpickle
            serialized_func = cloudpickle.dumps(func)
            self._judf_placeholder = \
                self._create_judf(serialized_func, j_input_types, j_function_kind)
        return self._judf_placeholder

    def _create_delegate_function(self) -> UserDefinedFunction:
        """Wraps a plain callable into a UserDefinedFunction; provided by subclasses."""
        pass

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """Creates the Java function object; provided by subclasses."""
        pass


class UserDefinedScalarFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined scalar function.

    :raises TypeError: if ``result_type`` is neither a DataType nor a str.
    """

    def __init__(self, func, input_types, result_type, func_type, deterministic, name,
                 concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
        super(UserDefinedScalarFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name, concurrency, batch_size,
            num_gpus, gpu_type)

        if not isinstance(result_type, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str but is {}".format(
                    result_type))
        self._result_type = result_type
        self._judf_placeholder = None

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """Creates a Java ``PythonScalarFunction`` and applies the optional settings."""
        gateway = get_gateway()
        if isinstance(self._result_type, DataType):
            j_result_type = _to_java_data_type(self._result_type)
        else:
            j_result_type = self._result_type
        PythonScalarFunction = gateway.jvm \
            .org.apache.flink.table.functions.python.PythonScalarFunction
        j_scalar_function = PythonScalarFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        if self._concurrency is not None:
            j_scalar_function.setParallelism(self._concurrency)
        if self._batch_size is not None:
            j_scalar_function.setMaxArrowBatchSize(self._batch_size)
        if self._num_gpus is not None:
            j_scalar_function.setNumGpus(self._num_gpus)
        if self._gpu_type is not None:
            j_scalar_function.setGpuType(self._gpu_type)
        return j_scalar_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        return DelegatingScalarFunction(self._func)


class UserDefinedAsyncScalarFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined async scalar function.

    :raises TypeError: if ``result_type`` is neither a DataType nor a str.
    """

    def __init__(self, func, input_types, result_type, func_type, deterministic, name,
                 concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
        super(UserDefinedAsyncScalarFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name, concurrency, batch_size,
            num_gpus, gpu_type)

        if not isinstance(result_type, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str but is {}".format(
                    result_type))
        self._result_type = result_type
        self._judf_placeholder = None

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates a Java ``PythonAsyncScalarFunction``.

        Unlike the sync wrapper, concurrency is passed to the constructor, with -1 meaning
        "not set".
        """
        gateway = get_gateway()
        if isinstance(self._result_type, DataType):
            j_result_type = _to_java_data_type(self._result_type)
        else:
            j_result_type = self._result_type
        PythonAsyncScalarFunction = gateway.jvm \
            .org.apache.flink.table.functions.python.PythonAsyncScalarFunction
        j_async_scalar_function = PythonAsyncScalarFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env(),
            self._concurrency if self._concurrency is not None else -1)
        if self._batch_size is not None:
            j_async_scalar_function.setMaxArrowBatchSize(self._batch_size)
        if self._num_gpus is not None:
            j_async_scalar_function.setNumGpus(self._num_gpus)
        if self._gpu_type is not None:
            j_async_scalar_function.setGpuType(self._gpu_type)
        return j_async_scalar_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        return DelegatingAsyncScalarFunction(self._func)


class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined table function.

    ``result_types`` may be a RowType (expanded to its field types), a row type string (kept
    as-is), a single DataType, or an iterable of DataType/str.

    :raises TypeError: if an element of ``result_types`` is neither a DataType nor a str.
    """

    def __init__(self, func, input_types, result_types, deterministic=None, name=None):
        super(UserDefinedTableFunctionWrapper, self).__init__(
            func, input_types, "general", deterministic, name)

        from pyflink.table.types import RowType

        if isinstance(result_types, RowType):
            # DataTypes.ROW([DataTypes.FIELD("f0", DataTypes.INT()),
            #                DataTypes.FIELD("f1", DataTypes.BIGINT())])
            result_types = result_types.field_types()
        elif isinstance(result_types, DataType):
            # DataTypes.INT()
            result_types = [result_types]
        elif not isinstance(result_types, str):
            # [DataTypes.INT(), DataTypes.BIGINT()]
            result_types = list(result_types)

        # A row type string such as 'ROW<f0 INT, f1 BIGINT>' is handed to Java unchanged, so
        # only list-shaped result types need per-element validation.
        if not isinstance(result_types, str):
            for result_type in result_types:
                if not isinstance(result_type, (DataType, str)):
                    raise TypeError(
                        "Invalid result_type: result_type should be DataType or str but "
                        "contains {}".format(result_type))

        self._result_types = result_types

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates a Java ``PythonTableFunction``.

        List-shaped result types are combined into a single row type: a Java ``ROW`` of
        DataTypes, or a ``Row<f0 ..., f1 ...>`` string for type strings.
        """
        gateway = get_gateway()
        if isinstance(self._result_types, str):
            j_result_type = self._result_types
        elif isinstance(self._result_types[0], DataType):
            j_result_types = java_utils.to_jarray(
                gateway.jvm.DataType, [_to_java_data_type(i) for i in self._result_types])
            j_result_type = gateway.jvm.DataTypes.ROW(j_result_types)
        else:
            j_result_type = 'Row<{0}>'.format(','.join(
                ['f{0} {1}'.format(i, result_type)
                 for i, result_type in enumerate(self._result_types)]))
        PythonTableFunction = gateway.jvm \
            .org.apache.flink.table.functions.python.PythonTableFunction
        j_table_function = PythonTableFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        return j_table_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        return DelegationTableFunction(self._func)


class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
    """
    Wrapper for Python user-defined aggregate function or user-defined table aggregate function.

    For "general" functions, missing result/accumulator types are taken from
    ``func.get_result_type()`` / ``func.get_accumulator_type()``, and both types must be
    strings, or both must be DataTypes.

    :raises TypeError: on invalid result/accumulator types, a MapType result for a pandas
                       UDAF, or mixed str/DataType result and accumulator types.
    """

    def __init__(self, func, input_types, result_type, accumulator_type, func_type,
                 deterministic, name, is_table_aggregate=False):
        super(UserDefinedAggregateFunctionWrapper, self).__init__(
            func, input_types, func_type, deterministic, name)

        if accumulator_type is None and func_type == "general":
            accumulator_type = func.get_accumulator_type()
        if result_type is None:
            result_type = func.get_result_type()
        if not isinstance(result_type, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str but is {}"
                .format(result_type))
        from pyflink.table.types import MapType
        if func_type == 'pandas' and isinstance(result_type, MapType):
            raise TypeError(
                "Invalid returnType: Pandas UDAF doesn't support DataType type {} currently"
                .format(result_type))
        if accumulator_type is not None and not isinstance(accumulator_type, (DataType, str)):
            raise TypeError(
                "Invalid accumulator_type: accumulator_type should be DataType or str but is {}"
                .format(accumulator_type))
        # Both types must share one representation: both str, or both DataType.
        if func_type == "general" and not (
                (isinstance(result_type, str) and isinstance(accumulator_type, str)) or
                (isinstance(result_type, DataType) and isinstance(accumulator_type, DataType))):
            raise TypeError("result_type and accumulator_type should be DataType or str "
                            "at the same time.")
        self._result_type = result_type
        self._accumulator_type = accumulator_type
        self._is_table_aggregate = is_table_aggregate

    def _create_judf(self, serialized_func, j_input_types, j_function_kind):
        """
        Creates a Java ``PythonAggregateFunction`` or ``PythonTableAggregateFunction``.

        For pandas UDAFs the accumulator type is (re)set to an array of the result type.
        """
        if self._func_type == "pandas":
            if isinstance(self._result_type, DataType):
                from pyflink.table.types import DataTypes
                self._accumulator_type = DataTypes.ARRAY(self._result_type)
            else:
                self._accumulator_type = 'ARRAY<{0}>'.format(self._result_type)

        # j_input_types was already built by _java_user_defined_function from
        # self._input_types (String[] for type strings, DataType[] otherwise), so it is
        # passed through unchanged rather than feeding strings to _to_java_data_type.
        if isinstance(self._result_type, DataType):
            j_result_type = _to_java_data_type(self._result_type)
        else:
            j_result_type = self._result_type
        if isinstance(self._accumulator_type, DataType):
            j_accumulator_type = _to_java_data_type(self._accumulator_type)
        else:
            j_accumulator_type = self._accumulator_type
        gateway = get_gateway()
        if self._is_table_aggregate:
            PythonAggregateFunction = gateway.jvm \
                .org.apache.flink.table.functions.python.PythonTableAggregateFunction
        else:
            PythonAggregateFunction = gateway.jvm \
                .org.apache.flink.table.functions.python.PythonAggregateFunction
        j_aggregate_function = PythonAggregateFunction(
            self._name,
            bytearray(serialized_func),
            j_input_types,
            j_result_type,
            j_accumulator_type,
            j_function_kind,
            self._deterministic,
            self._takes_row_as_input,
            _get_python_env())
        return j_aggregate_function

    def _create_delegate_function(self) -> UserDefinedFunction:
        # Only pandas UDAFs may be plain callables; general UDAFs are AggregateFunctions.
        assert self._func_type == 'pandas'
        return DelegatingPandasAggregateFunction(self._func)


# TODO: support to configure the python execution environment
def _get_python_env():
    """Returns a Java ``PythonEnv`` using the PROCESS execution type."""
    gateway = get_gateway()
    exec_type = gateway.jvm.org.apache.flink.table.functions.python.PythonEnv.ExecType.PROCESS
    return gateway.jvm.org.apache.flink.table.functions.python.PythonEnv(exec_type)


def _create_udf(f, input_types, result_type, func_type, deterministic, name,
                concurrency=None, batch_size=None, num_gpus=None, gpu_type=None):
    """
    Wraps ``f`` as a scalar UDF.

    An async wrapper is chosen for AsyncScalarFunction instances and coroutine functions;
    otherwise a sync wrapper is used.
    """
    if isinstance(f, AsyncScalarFunction) or inspect.iscoroutinefunction(f):
        return UserDefinedAsyncScalarFunctionWrapper(
            f, input_types, result_type, func_type, deterministic, name,
            concurrency, batch_size, num_gpus, gpu_type)
    else:
        return UserDefinedScalarFunctionWrapper(
            f, input_types, result_type, func_type, deterministic, name,
            concurrency, batch_size, num_gpus, gpu_type)


def _create_udtf(f, input_types, result_types, deterministic, name):
    """Wraps ``f`` as a table UDF."""
    return UserDefinedTableFunctionWrapper(f, input_types, result_types, deterministic, name)


def _create_udaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
    """Wraps ``f`` as an aggregate UDF."""
    return UserDefinedAggregateFunctionWrapper(
        f, input_types, result_type, accumulator_type, func_type, deterministic, name)


def _create_udtaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
    """Wraps ``f`` as a table aggregate UDF."""
    return UserDefinedAggregateFunctionWrapper(
        f, input_types, result_type, accumulator_type, func_type, deterministic, name, True)


# Overload 1: function is supplied — udf returns the wrapped UDF directly.
# Covers the direct-call forms:
#     udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())
#     udf(MyScalarFunction())
@overload
def udf(
    f: Union[Callable, ScalarFunction, AsyncScalarFunction, Type],
    input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
    result_type: Optional[Union[DataType, str]] = ...,
    deterministic: Optional[bool] = ...,
    name: Optional[str] = ...,
    func_type: str = ...,
    udf_type: Optional[str] = ...,
    concurrency: Optional[int] = ...,
    batch_size: Optional[int] = ...,
    num_gpus: Optional[float] = ...,
    gpu_type: Optional[str] = ...,
) -> Union[UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper]:
    ...


# Overload 2: no function supplied — udf is being used as a decorator factory
# and must return a decorator that will receive the function next. Covers:
#     @udf(result_type=DataTypes.BIGINT())
#     def add(i, j): ...
@overload
def udf(
    f: None = ...,
    input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ...,
    result_type: Optional[Union[DataType, str]] = ...,
    deterministic: Optional[bool] = ...,
    name: Optional[str] = ...,
    func_type: str = ...,
    udf_type: Optional[str] = ...,
    concurrency: Optional[int] = ...,
    batch_size: Optional[int] = ...,
    num_gpus: Optional[float] = ...,
    gpu_type: Optional[str] = ...,
) -> Callable[
    [Union[Callable, ScalarFunction, AsyncScalarFunction, Type]],
    Union[UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper],
]:
    ...
def udf(f: Optional[Union[Callable, ScalarFunction, AsyncScalarFunction, Type]] = None,
        input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
        result_type: Optional[Union[DataType, str]] = None,
        deterministic: Optional[bool] = None,
        name: Optional[str] = None,
        func_type: str = "general",
        udf_type: Optional[str] = None,
        concurrency: Optional[int] = None,
        batch_size: Optional[int] = None,
        num_gpus: Optional[float] = None,
        gpu_type: Optional[str] = None) -> Union[
            UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined scalar function.

    This decorator can automatically detect whether the function is async (defined with
    `async def` or is an instance of AsyncScalarFunction).

    Example:
        ::

            >>> add_one = udf(lambda i: i + 1, DataTypes.BIGINT(), DataTypes.BIGINT())

            >>> # The input_types is optional.
            >>> @udf(result_type=DataTypes.BIGINT())
            ... def add(i, j):
            ...     return i + j

            >>> # Specify result_type via string.
            >>> @udf(result_type='BIGINT')
            ... def add(i, j):
            ...     return i + j

            >>> # Async function will be automatically detected
            >>> @udf(result_type=DataTypes.STRING())
            ... async def async_lookup(key):
            ...     await asyncio.sleep(0.1)
            ...     return f"value_for_{key}"

            >>> class SubtractOne(ScalarFunction):
            ...     def eval(self, i):
            ...         return i - 1
            >>> subtract_one = udf(SubtractOne(), DataTypes.BIGINT(), DataTypes.BIGINT())

            >>> # AsyncScalarFunction will be automatically detected
            >>> class AsyncLookup(AsyncScalarFunction):
            ...     async def eval(self, key):
            ...         await asyncio.sleep(0.1)
            ...         return f"value_for_{key}"
            >>> async_lookup = udf(AsyncLookup(), result_type=DataTypes.STRING())

    :param f: lambda function, user-defined function, or async function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general, pandas, arrow,
                      (default: general)
    :param udf_type: deprecated since 1.12, use ``func_type`` instead. When set, it overrides
                     ``func_type`` and a warning is emitted.
    :param concurrency: optional, the concurrency (parallelism) for the UDF operator. If
                        specified, the operator running this UDF will use this parallelism.
                        Must be a positive integer.
    :param batch_size: optional, the maximum number of elements to buffer before processing them
                       in a batch. Only valid for pandas or arrow UDFs.
    :param num_gpus: optional, the number of GPUs for the UDF. Must be a positive number and
                     requires ``gpu_type`` to be specified.
    :param gpu_type: optional, the GPU type; must be specified whenever ``num_gpus`` is set.
    :return: UserDefinedScalarFunctionWrapper, UserDefinedAsyncScalarFunctionWrapper, or
             function.
    :raises ValueError: if ``func_type`` is unsupported, ``batch_size`` is given for a general
                        UDF, ``concurrency`` or ``num_gpus`` is not a positive number, or
                        ``num_gpus`` is set without ``gpu_type``.

    .. versionadded:: 1.10.0
    """
    if udf_type:
        import warnings
        warnings.warn("The param udf_type is deprecated in 1.12. Use func_type instead.")
        func_type = udf_type

    if func_type not in ('general', 'pandas', 'arrow'):
        raise ValueError("The func_type must be one of 'general, pandas, arrow', got %s."
                         % func_type)

    if func_type == 'general' and batch_size is not None:
        raise ValueError("batch_size is only supported for pandas or arrow UDFs.")

    # bool is a subclass of int, so True/False would otherwise slip through the numeric checks
    # and only fail later when handed to the Java side.
    if concurrency is not None and (
            isinstance(concurrency, bool) or not isinstance(concurrency, int)
            or concurrency <= 0):
        raise ValueError("concurrency must be a positive integer, got: {}".format(concurrency))

    if num_gpus is not None and (
            isinstance(num_gpus, bool) or not isinstance(num_gpus, (int, float))
            or num_gpus <= 0):
        raise ValueError("num_gpus must be a positive number, got: {}".format(num_gpus))

    if num_gpus is not None and gpu_type is None:
        raise ValueError("gpu_type must be specified when num_gpus is set")

    # decorator
    if f is None:
        return functools.partial(_create_udf, input_types=input_types, result_type=result_type,
                                 func_type=func_type, deterministic=deterministic,
                                 name=name, concurrency=concurrency, batch_size=batch_size,
                                 num_gpus=num_gpus, gpu_type=gpu_type)
    else:
        return _create_udf(f, input_types, result_type, func_type, deterministic, name,
                           concurrency=concurrency, batch_size=batch_size,
                           num_gpus=num_gpus, gpu_type=gpu_type)
# Overload 1: function is supplied — udtf returns the wrapped UDTF directly. @overload def udtf( f: Union[Callable, TableFunction, Type], input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., result_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., deterministic: Optional[bool] = ..., name: Optional[str] = ..., ) -> UserDefinedTableFunctionWrapper: ... # Overload 2: no function supplied — udtf returns a decorator factory. @overload def udtf( f: None = ..., input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., result_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., deterministic: Optional[bool] = ..., name: Optional[str] = ..., ) -> Callable[[Union[Callable, TableFunction, Type]], UserDefinedTableFunctionWrapper]: ...
def udtf(f: Optional[Union[Callable, TableFunction, Type]] = None,
         input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
         result_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
         deterministic: Optional[bool] = None,
         name: Optional[str] = None) -> Union[UserDefinedTableFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined table function.

    Can be called with the function directly, or used as a decorator factory when ``f`` is
    omitted.

    Example:
        ::

            >>> # The input_types is optional.
            >>> @udtf(result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
            ... def range_emit(s, e):
            ...     for i in range(e):
            ...         yield s, i

            >>> # Specify result_types via string
            >>> @udtf(result_types=['BIGINT', 'BIGINT'])
            ... def range_emit(s, e):
            ...     for i in range(e):
            ...         yield s, i

            >>> # Specify result_types via row string
            >>> @udtf(result_types='Row<a BIGINT, b BIGINT>')
            ... def range_emit(s, e):
            ...     for i in range(e):
            ...         yield s, i

            >>> class MultiEmit(TableFunction):
            ...     def eval(self, i):
            ...         return range(i)
            >>> multi_emit = udtf(MultiEmit(), DataTypes.BIGINT(), DataTypes.BIGINT())

    :param f: user-defined table function.
    :param input_types: optional, the input data types.
    :param result_types: the result data types.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :return: UserDefinedTableFunctionWrapper or function.

    .. versionadded:: 1.11.0
    """
    if f is not None:
        return _create_udtf(f, input_types, result_types, deterministic, name)

    # Decorator usage: bind every option now, receive the function later.
    return functools.partial(_create_udtf, input_types=input_types, result_types=result_types,
                             deterministic=deterministic, name=name)
# Overload 1: function is supplied — udaf returns the wrapped UDAF directly. @overload def udaf( f: Union[Callable, AggregateFunction, Type], input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., result_type: Optional[Union[DataType, str]] = ..., accumulator_type: Optional[Union[DataType, str]] = ..., deterministic: Optional[bool] = ..., name: Optional[str] = ..., func_type: str = ..., ) -> UserDefinedAggregateFunctionWrapper: ... # Overload 2: no function supplied — udaf returns a decorator factory. @overload def udaf( f: None = ..., input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., result_type: Optional[Union[DataType, str]] = ..., accumulator_type: Optional[Union[DataType, str]] = ..., deterministic: Optional[bool] = ..., name: Optional[str] = ..., func_type: str = ..., ) -> Callable[[Union[Callable, AggregateFunction, Type]], UserDefinedAggregateFunctionWrapper]: ...
def udaf(f: Optional[Union[Callable, AggregateFunction, Type]] = None,
         input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
         result_type: Optional[Union[DataType, str]] = None,
         accumulator_type: Optional[Union[DataType, str]] = None,
         deterministic: Optional[bool] = None,
         name: Optional[str] = None,
         func_type: str = "general") -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined aggregate function.

    Can be called with the function directly, or used as a decorator factory when ``f`` is
    omitted.

    Example:
        ::

            >>> # The input_types is optional.
            >>> @udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
            ... def mean_udaf(v):
            ...     return v.mean()

            >>> # Specify result_type via string
            >>> @udaf(result_type='FLOAT', func_type="pandas")
            ... def mean_udaf(v):
            ...     return v.mean()

    :param f: user-defined aggregate function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param accumulator_type: optional, the accumulator data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general, pandas,
                     (default: general)
    :return: UserDefinedAggregateFunctionWrapper or function.
    :raises ValueError: if ``func_type`` is neither 'general' nor 'pandas'.

    .. versionadded:: 1.12.0
    """
    supported_func_types = ('general', 'pandas')
    if func_type not in supported_func_types:
        raise ValueError("The func_type must be one of 'general, pandas', got %s." % func_type)

    if f is not None:
        return _create_udaf(f, input_types, result_type, accumulator_type, func_type,
                            deterministic, name)

    # Decorator usage: bind every option now, receive the function later.
    return functools.partial(_create_udaf, input_types=input_types, result_type=result_type,
                             accumulator_type=accumulator_type, func_type=func_type,
                             deterministic=deterministic, name=name)
# Overload 1: function is supplied — udtaf returns the wrapped UDTAF directly. @overload def udtaf( f: Union[Callable, TableAggregateFunction, Type], input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., result_type: Optional[Union[DataType, str]] = ..., accumulator_type: Optional[Union[DataType, str]] = ..., deterministic: Optional[bool] = ..., name: Optional[str] = ..., func_type: str = ..., ) -> UserDefinedAggregateFunctionWrapper: ... # Overload 2: no function supplied — udtaf returns a decorator factory. @overload def udtaf( f: None = ..., input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = ..., result_type: Optional[Union[DataType, str]] = ..., accumulator_type: Optional[Union[DataType, str]] = ..., deterministic: Optional[bool] = ..., name: Optional[str] = ..., func_type: str = ..., ) -> Callable[[Union[Callable, TableAggregateFunction, Type]], UserDefinedAggregateFunctionWrapper]: ...
def udtaf(f: Optional[Union[Callable, TableAggregateFunction, Type]] = None,
          input_types: Optional[Union[List[DataType], DataType, str, List[str]]] = None,
          result_type: Optional[Union[DataType, str]] = None,
          accumulator_type: Optional[Union[DataType, str]] = None,
          deterministic: Optional[bool] = None,
          name: Optional[str] = None,
          func_type: str = 'general') -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
    """
    Helper method for creating a user-defined table aggregate function.

    Can be called with the function directly, or used as a decorator factory when ``f`` is
    omitted. Only the 'general' function type is supported.

    Example:
        ::

            >>> # The input_types is optional.
            >>> class Top2(TableAggregateFunction):
            ...     def emit_value(self, accumulator):
            ...         yield Row(accumulator[0])
            ...         yield Row(accumulator[1])
            ...
            ...     def create_accumulator(self):
            ...         return [None, None]
            ...
            ...     def accumulate(self, accumulator, *args):
            ...         if args[0] is not None:
            ...             if accumulator[0] is None or args[0] > accumulator[0]:
            ...                 accumulator[1] = accumulator[0]
            ...                 accumulator[0] = args[0]
            ...             elif accumulator[1] is None or args[0] > accumulator[1]:
            ...                 accumulator[1] = args[0]
            ...
            ...     def retract(self, accumulator, *args):
            ...         accumulator[0] = accumulator[0] - 1
            ...
            ...     def merge(self, accumulator, accumulators):
            ...         for other_acc in accumulators:
            ...             self.accumulate(accumulator, other_acc[0])
            ...             self.accumulate(accumulator, other_acc[1])
            ...
            ...     def get_accumulator_type(self):
            ...         return 'ARRAY<BIGINT>'
            ...
            ...     def get_result_type(self):
            ...         return 'ROW<a BIGINT>'
            >>> top2 = udtaf(Top2())

    :param f: user-defined table aggregate function.
    :param input_types: optional, the input data types.
    :param result_type: the result data type.
    :param accumulator_type: optional, the accumulator data type.
    :param deterministic: the determinism of the function's results. True if and only if a call to
                          this function is guaranteed to always return the same result given the
                          same parameters. (default True)
    :param name: the function name.
    :param func_type: the type of the python function, available value: general
                     (default: general)
    :return: UserDefinedAggregateFunctionWrapper or function.
    :raises ValueError: if ``func_type`` is not 'general'.

    .. versionadded:: 1.13.0
    """
    if func_type != 'general':
        raise ValueError("The func_type must be 'general', got %s." % func_type)

    if f is not None:
        return _create_udtaf(f, input_types, result_type, accumulator_type, func_type,
                             deterministic, name)

    # Decorator usage: bind every option now, receive the function later.
    return functools.partial(_create_udtaf, input_types=input_types, result_type=result_type,
                             accumulator_type=accumulator_type, func_type=func_type,
                             deterministic=deterministic, name=name)