Source code for pyflink.dataframe.datatype

################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
DataType for DataFrame API.

Provides :class:`DataType`, a factory-style wrapper around the Flink Table API
``pyflink.table.types.DataType``, together with helpers that infer a data type
from a Python type hint or parse one from a SQL type string.
"""

import datetime
import decimal
import typing
from typing import (
    Union,
    List,
    Dict,
    Any,
    get_origin,
    get_args,
)

from pyflink.table.types import (
    DataType as TableDataType,
    DataTypes as TableDataTypes,
)


# Public names exported by ``from pyflink.dataframe.datatype import *``.
__all__ = [
    "DataType",
]


class DataType:
    """
    DataType for DataFrame API.

    This class provides a factory method style for creating data types.
    It wraps the underlying Flink Table API DataType for compatibility.

    Example::

        >>> from pyflink.dataframe import DataType
        >>> DataType.int32()
        >>> DataType.int64()
        >>> DataType.string()
        >>> DataType.list(DataType.int32())
        >>> DataType.struct({"name": DataType.string(), "age": DataType.int32()})
    """

    # NOTE: several factory methods below (``decimal``, ``bool``, ``date``,
    # ``time``, ``list``, ``map``) reuse builtin/module names. Inside method
    # bodies those names still resolve to the builtins/modules, but in the class
    # body itself they are rebound once their ``def`` has executed. Signature
    # annotations are evaluated in the class body, so do not annotate with e.g.
    # ``-> bool`` or ``: list`` below those definitions.

    def __init__(self, _table_type: TableDataType):
        """
        Initialize DataType with underlying Table DataType.

        Args:
            _table_type: The underlying Flink Table API DataType.
        """
        self._table_type = _table_type

    def __repr__(self) -> str:
        return f"DataType({self._table_type}, nullable={self._table_type._nullable})"

    def __eq__(self, other) -> bool:
        # Two wrappers are equal when their wrapped Table API types are equal.
        if not isinstance(other, DataType):
            return False
        return self._table_type == other._table_type

    def __hash__(self) -> int:
        # Hashes the string form of the wrapped type. This stays consistent with
        # __eq__ on the assumption that equal Table API types render identically.
        return hash(str(self._table_type))

    def not_null(self) -> "DataType":
        """Return a non-nullable version of this type."""
        return DataType(self._table_type.not_null())

    def nullable(self) -> "DataType":
        """Return a nullable version of this type."""
        return DataType(self._table_type.nullable())

    # ======================== Numeric Types ========================

    @classmethod
    def int8(cls) -> "DataType":
        """8-bit signed integer type (TINYINT)."""
        return cls(TableDataTypes.TINYINT())

    @classmethod
    def int16(cls) -> "DataType":
        """16-bit signed integer type (SMALLINT)."""
        return cls(TableDataTypes.SMALLINT())

    @classmethod
    def int32(cls) -> "DataType":
        """32-bit signed integer type (INT)."""
        return cls(TableDataTypes.INT())

    @classmethod
    def int64(cls) -> "DataType":
        """64-bit signed integer type (BIGINT)."""
        return cls(TableDataTypes.BIGINT())

    @classmethod
    def float32(cls) -> "DataType":
        """32-bit floating point type (FLOAT)."""
        return cls(TableDataTypes.FLOAT())

    @classmethod
    def float64(cls) -> "DataType":
        """64-bit floating point type (DOUBLE)."""
        return cls(TableDataTypes.DOUBLE())

    @classmethod
    def decimal(cls, precision: int, scale: int) -> "DataType":
        """
        Decimal type with specified precision and scale.

        Args:
            precision: Total number of digits.
            scale: Number of digits after decimal point.
        """
        return cls(TableDataTypes.DECIMAL(precision, scale))

    # ======================== String Types ========================

    @classmethod
    def string(cls) -> "DataType":
        """Variable-length string type (VARCHAR)."""
        return cls(TableDataTypes.STRING())

    @classmethod
    def fixed_size_string(cls, length: int) -> "DataType":
        """Fixed-length character type (CHAR(n))."""
        return cls(TableDataTypes.CHAR(length))

    @classmethod
    def variant(cls) -> "DataType":
        """Variant type that can store a value of any type."""
        return cls(TableDataTypes.VARIANT())

    # ======================== Binary Types ========================

    @classmethod
    def binary(cls) -> "DataType":
        """Binary type (BYTES)."""
        return cls(TableDataTypes.BYTES())

    @classmethod
    def fixed_size_binary(cls, length: int) -> "DataType":
        """Fixed-length binary type (BINARY(n))."""
        return cls(TableDataTypes.BINARY(length))

    # ======================== Null Type ========================

    @classmethod
    def null(cls) -> "DataType":
        """Null type (NULL)."""
        return cls(TableDataTypes.NULL())

    # ======================== Boolean Type ========================

    @classmethod
    def bool(cls) -> "DataType":
        """Boolean type."""
        return cls(TableDataTypes.BOOLEAN())

    # Alias for bool (in the class body ``bool`` now names the classmethod above).
    boolean = bool

    # ======================== Date/Time Types ========================

    @classmethod
    def date(cls) -> "DataType":
        """Date type (year, month, day)."""
        return cls(TableDataTypes.DATE())

    @classmethod
    def time(cls, precision: int = 0) -> "DataType":
        """Time type (hour, minute, second, fractional seconds)."""
        return cls(TableDataTypes.TIME(precision))

    @classmethod
    def timestamp(cls, precision: int = 6) -> "DataType":
        """Timestamp type without time zone."""
        return cls(TableDataTypes.TIMESTAMP(precision))

    @classmethod
    def timestamp_ltz(cls, precision: int = 6) -> "DataType":
        """Timestamp type with local time zone."""
        return cls(TableDataTypes.TIMESTAMP_LTZ(precision))

    # ======================== Composite Types ========================

    @classmethod
    def list(cls, dtype: "DataType") -> "DataType":
        """
        List type with element type.

        Args:
            dtype: The element data type.
        """
        return cls(TableDataTypes.ARRAY(dtype._table_type))

    @classmethod
    def map(cls, key_type: "DataType", value_type: "DataType") -> "DataType":
        """
        Map type with key and value types.

        Args:
            key_type: The key data type.
            value_type: The value data type.
        """
        return cls(TableDataTypes.MAP(key_type._table_type, value_type._table_type))

    @classmethod
    def struct(
        cls, fields: typing.Union[Dict[str, "DataType"], List[tuple]]
    ) -> "DataType":
        """
        Struct/Row type with named fields.

        Args:
            fields: Either a dict of {name: type} or list of (name, type) tuples.

        Example::

            >>> DataType.struct({"name": DataType.string(), "age": DataType.int32()})
            >>> DataType.struct([("name", DataType.string()), ("age", DataType.int32())])
        """
        if isinstance(fields, dict):
            # ``list`` here is the builtin: method bodies do not see class scope.
            fields = list(fields.items())
        table_fields = [
            TableDataTypes.FIELD(name, dtype._table_type) for name, dtype in fields
        ]
        return cls(TableDataTypes.ROW(table_fields))

    # Alias for struct
    row = struct

    # ======================== Tensor Factory Methods ========================

    @classmethod
    def tensor(cls, dtype: "DataType", shape: tuple) -> "DataType":
        """
        Fixed-shape tensor type.

        Uses the first-class ``TENSOR`` logical type so that shape metadata is
        carried natively by Flink's type system rather than via a sideband
        metadata mechanism.

        Args:
            dtype: Element type, e.g. ``DataType.float32()`` or ``DataType.int8()``.
            shape: Shape tuple, e.g. ``(3, 224, 224)``. Must be fully specified
                (no dynamic dimensions).

        Example::

            >>> DataType.tensor(DataType.float32(), shape=(3, 224, 224))
            >>> DataType.tensor(DataType.int8(), shape=(224, 224, 3))
        """
        table_type = TableDataTypes.TENSOR(dtype._table_type, shape)
        return cls(table_type)

    # ======================== Type Inference ========================

    @classmethod
    def _infer_from_type(cls, type_hint: Any) -> "DataType":
        """
        Infer DataType from Python type hint.

        Supports:
            - Basic types: int, float, str, bool, bytes
            - Optional[T] / Union[T, None] / ``T | None`` -> T (nullable)
            - list[T] -> array(T)
            - dict[K, V] -> map(K, V)
            - datetime types

        Args:
            type_hint: Python type or type hint.

        Returns:
            The inferred DataType.

        Raises:
            TypeError: If the type hint cannot be inferred.
        """
        import types

        origin = get_origin(type_hint)
        args = get_args(type_hint)

        # ``typing.Union`` / ``Optional`` report ``typing.Union`` as their origin,
        # while PEP 604 unions (``T | None``, Python 3.10+) report
        # ``types.UnionType``. Both are routed to the union handler; the getattr
        # keeps this working on Python versions without ``types.UnionType``.
        union_type = getattr(types, "UnionType", None)
        if origin is Union or (union_type is not None and origin is union_type):
            return cls._infer_union_type(args)

        # Handle list[T]; a bare ``list`` has no origin, so check it explicitly
        # to give the specific error message below.
        if origin is list or type_hint is list:
            if not args:
                raise TypeError(
                    "Cannot infer DataType from list without type argument. "
                    "Use list[T] syntax, e.g., list[int]"
                )
            element_type = cls._infer_from_type(args[0])
            return cls.list(element_type)

        # Handle dict[K, V]; likewise a bare ``dict``.
        if origin is dict or type_hint is dict:
            if len(args) != 2:
                raise TypeError(
                    "Cannot infer DataType from dict without key/value types. "
                    "Use dict[K, V] syntax, e.g., dict[str, int]"
                )
            key_type = cls._infer_from_type(args[0])
            value_type = cls._infer_from_type(args[1])
            return cls.map(key_type, value_type)

        # Handle basic types
        dtype = cls._infer_basic_type(type_hint)
        if dtype is not None:
            return dtype

        # Cannot infer
        raise TypeError(
            f"Cannot infer DataType from type hint '{type_hint}'. "
            f"Please specify return_dtype explicitly."
        )

    @classmethod
    def _infer_union_type(cls, args: tuple) -> "DataType":
        """
        Infer DataType from Union type arguments.

        Only ``Optional[T]`` (exactly one non-None member) is supported; the
        result is the inferred type of ``T`` made nullable.

        Raises:
            TypeError: If the union has more than one non-None member.
        """
        none_types = {type(None), None}
        non_none_types = [t for t in args if t not in none_types]

        if len(non_none_types) == 1:
            # This is Optional[T]
            inner_type = cls._infer_from_type(non_none_types[0])
            return inner_type.nullable()

        # General Union - not supported
        raise TypeError(
            f"Cannot infer DataType from Union type with multiple non-None types: {args}. "
            f"Please specify return_dtype explicitly."
        )

    @classmethod
    def _infer_basic_type(cls, type_hint: Any) -> typing.Optional["DataType"]:
        """
        Infer DataType from basic Python types.

        Returns:
            The mapped DataType, ``DataType.string()`` for ``typing.Any``, or
            ``None`` when ``type_hint`` is not a supported basic type.
        """
        # Map each supported type to a factory so only the matching DataType is
        # built. Keys are looked up by exact type (no subclass matching), which
        # keeps bool distinct from int and datetime distinct from date.
        factories = {
            bool: cls.bool,
            int: cls.int64,
            float: cls.float64,
            str: cls.string,
            bytes: cls.binary,
            bytearray: cls.binary,
            decimal.Decimal: lambda: cls.decimal(38, 18),
            datetime.date: cls.date,
            datetime.time: cls.time,
            datetime.datetime: cls.timestamp,
        }
        factory = factories.get(type_hint)
        if factory is not None:
            return factory()

        # typing.Any -> STRING as fallback
        if type_hint is Any:
            return cls.string()

        return None

    @classmethod
    def _from_sql(cls, sql_type: str) -> "DataType":
        """
        Create DataType from a SQL type string.

        Args:
            sql_type: SQL type string, e.g. ``'INT'``, ``'ARRAY<BIGINT NOT NULL>'``,
                ``'ROW<name STRING, age INT>'``.

        Returns:
            The corresponding :class:`DataType`.

        Raises:
            ValueError: If the string cannot be parsed.
        """
        from pyflink.util.exceptions import JavaException
        from pyflink.java_gateway import get_gateway
        from pyflink.table.types import _from_java_data_type

        try:
            gateway = get_gateway()
            # Parse with Flink's Java LogicalTypeParser, then convert the logical
            # type into a DataType before bridging it back to Python.
            j_logical_type = (
                gateway.jvm.org.apache.flink.table.types.logical.utils
                .LogicalTypeParser.parse(sql_type)
            )
            j_data_type = (
                gateway.jvm.org.apache.flink.table.types.utils
                .TypeConversions.fromLogicalToDataType(j_logical_type)
            )
            return cls(_from_java_data_type(j_data_type))
        except JavaException as e:
            # Surface parse failures as a plain Python ValueError; the Java
            # exception text is kept as the message.
            raise ValueError(str(e)) from None
# Type alias for the forms a data type may be supplied in: a ``DataType``
# instance, a Python ``type`` (see ``DataType._infer_from_type``) or a ``str``
# (presumably a SQL type string for ``DataType._from_sql`` -- confirm at callers).
DataTypeLike = Union[DataType, type, str]