################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
DataType for DataFrame API.
"""
import datetime
import decimal
import typing
from typing import (
Union,
List,
Dict,
Any,
get_origin,
get_args,
)
from pyflink.table.types import (
DataType as TableDataType,
DataTypes as TableDataTypes,
)
# Public API of this module: ``DataType`` is the only name exported by
# ``from <module> import *``.
__all__ = ["DataType"]
class DataType:
    """
    DataType for DataFrame API.

    This class provides a factory method style for creating data types.
    It wraps the underlying Flink Table API DataType for compatibility.

    Example::

        >>> from pyflink.dataframe import DataType
        >>> DataType.int32()
        >>> DataType.int64()
        >>> DataType.string()
        >>> DataType.list(DataType.int32())
        >>> DataType.struct({"name": DataType.string(), "age": DataType.int32()})
    """

    def __init__(self, _table_type: "TableDataType"):
        """
        Initialize DataType with underlying Table DataType.

        Args:
            _table_type: The underlying Flink Table API DataType.
        """
        self._table_type = _table_type

    def __repr__(self) -> str:
        return f"DataType({self._table_type}, nullable={self._table_type._nullable})"

    def __eq__(self, other: object) -> bool:
        """Two DataTypes are equal when their wrapped Table API types are equal."""
        if not isinstance(other, DataType):
            # Let Python try the reflected comparison; ``==`` against an
            # unrelated object still evaluates to False.
            return NotImplemented
        return self._table_type == other._table_type

    def __hash__(self) -> int:
        # Hash the string form of the wrapped type so that the hash is defined
        # alongside ``__eq__``.
        return hash(str(self._table_type))

    def not_null(self) -> "DataType":
        """Return a non-nullable version of this type."""
        return DataType(self._table_type.not_null())

    def nullable(self) -> "DataType":
        """Return a nullable version of this type."""
        return DataType(self._table_type.nullable())

    # ======================== Numeric Types ========================

    @classmethod
    def int8(cls) -> "DataType":
        """8-bit signed integer type (TINYINT)."""
        return cls(TableDataTypes.TINYINT())

    @classmethod
    def int16(cls) -> "DataType":
        """16-bit signed integer type (SMALLINT)."""
        return cls(TableDataTypes.SMALLINT())

    @classmethod
    def int32(cls) -> "DataType":
        """32-bit signed integer type (INT)."""
        return cls(TableDataTypes.INT())

    @classmethod
    def int64(cls) -> "DataType":
        """64-bit signed integer type (BIGINT)."""
        return cls(TableDataTypes.BIGINT())

    @classmethod
    def float32(cls) -> "DataType":
        """32-bit floating point type (FLOAT)."""
        return cls(TableDataTypes.FLOAT())

    @classmethod
    def float64(cls) -> "DataType":
        """64-bit floating point type (DOUBLE)."""
        return cls(TableDataTypes.DOUBLE())

    @classmethod
    def decimal(cls, precision: int, scale: int) -> "DataType":
        """
        Decimal type with specified precision and scale.

        Args:
            precision: Total number of digits.
            scale: Number of digits after decimal point.
        """
        return cls(TableDataTypes.DECIMAL(precision, scale))

    # ======================== String Types ========================

    @classmethod
    def string(cls) -> "DataType":
        """Variable-length string type (VARCHAR)."""
        return cls(TableDataTypes.STRING())

    @classmethod
    def fixed_size_string(cls, length: int) -> "DataType":
        """
        Fixed-length character type (CHAR(n)).

        Args:
            length: The number of characters ``n``.
        """
        return cls(TableDataTypes.CHAR(length))

    @classmethod
    def variant(cls) -> "DataType":
        """Variant type that can store a value of any type."""
        return cls(TableDataTypes.VARIANT())

    # ======================== Binary Types ========================

    @classmethod
    def binary(cls) -> "DataType":
        """Binary type (BYTES)."""
        return cls(TableDataTypes.BYTES())

    @classmethod
    def fixed_size_binary(cls, length: int) -> "DataType":
        """
        Fixed-length binary type (BINARY(n)).

        Args:
            length: The number of bytes ``n``.
        """
        return cls(TableDataTypes.BINARY(length))

    # ======================== Null Type ========================

    @classmethod
    def null(cls) -> "DataType":
        """Null type (NULL)."""
        return cls(TableDataTypes.NULL())

    # ======================== Boolean Type ========================

    @classmethod
    def bool(cls) -> "DataType":
        """Boolean type."""
        return cls(TableDataTypes.BOOLEAN())

    # Alias for bool
    boolean = bool

    # ======================== Date/Time Types ========================

    @classmethod
    def date(cls) -> "DataType":
        """Date type (year, month, day)."""
        return cls(TableDataTypes.DATE())

    @classmethod
    def time(cls, precision: int = 0) -> "DataType":
        """
        Time type (hour, minute, second, fractional seconds).

        Args:
            precision: Precision passed to the underlying ``TIME`` type.
        """
        return cls(TableDataTypes.TIME(precision))

    @classmethod
    def timestamp(cls, precision: int = 6) -> "DataType":
        """
        Timestamp type without time zone.

        Args:
            precision: Precision passed to the underlying ``TIMESTAMP`` type.
        """
        return cls(TableDataTypes.TIMESTAMP(precision))

    @classmethod
    def timestamp_ltz(cls, precision: int = 6) -> "DataType":
        """
        Timestamp type with local time zone.

        Args:
            precision: Precision passed to the underlying ``TIMESTAMP_LTZ`` type.
        """
        return cls(TableDataTypes.TIMESTAMP_LTZ(precision))

    # ======================== Composite Types ========================

    @classmethod
    def list(cls, dtype: "DataType") -> "DataType":
        """
        List type with element type.

        Args:
            dtype: The element data type.
        """
        return cls(TableDataTypes.ARRAY(dtype._table_type))

    @classmethod
    def map(cls, key_type: "DataType", value_type: "DataType") -> "DataType":
        """
        Map type with key and value types.

        Args:
            key_type: The key data type.
            value_type: The value data type.
        """
        return cls(TableDataTypes.MAP(key_type._table_type, value_type._table_type))

    @classmethod
    def struct(
        cls,
        fields: typing.Union[Dict[str, "DataType"], List[tuple]]
    ) -> "DataType":
        """
        Struct/Row type with named fields.

        Args:
            fields: Either a dict of {name: type} or list of (name, type) tuples.

        Example::

            >>> DataType.struct({"name": DataType.string(), "age": DataType.int32()})
            >>> DataType.struct([("name", DataType.string()), ("age", DataType.int32())])
        """
        # Normalise both accepted shapes to an iterable of (name, dtype) pairs.
        pairs = fields.items() if isinstance(fields, dict) else fields
        table_fields = [
            TableDataTypes.FIELD(name, dtype._table_type)
            for name, dtype in pairs
        ]
        return cls(TableDataTypes.ROW(table_fields))

    # Alias for struct
    row = struct

    # ======================== Tensor Factory Methods ========================

    @classmethod
    def tensor(cls, dtype: "DataType", shape: tuple) -> "DataType":
        """
        Fixed-shape tensor type.

        Uses the first-class ``TENSOR`` logical type so that shape metadata
        is carried natively by Flink's type system rather than via a
        sideband metadata mechanism.

        Args:
            dtype: Element type, e.g. ``DataType.float32()`` or ``DataType.int8()``.
            shape: Shape tuple, e.g. ``(3, 224, 224)``. Must be fully specified
                (no dynamic dimensions).

        Example::

            >>> DataType.tensor(DataType.float32(), shape=(3, 224, 224))
            >>> DataType.tensor(DataType.int8(), shape=(224, 224, 3))
        """
        return cls(TableDataTypes.TENSOR(dtype._table_type, shape))

    # ======================== Type Inference ========================

    @classmethod
    def _infer_from_type(cls, type_hint: Any) -> "DataType":
        """
        Infer DataType from Python type hint.

        Supports:
            - Basic types: int, float, str, bool, bytes
            - Optional[T] / ``T | None`` -> T (nullable)
            - list[T] -> array(T)
            - dict[K, V] -> map(K, V)
            - datetime types

        Args:
            type_hint: Python type or type hint.

        Returns:
            The inferred DataType.

        Raises:
            TypeError: If the type hint cannot be inferred.
        """
        # Local import: only needed to recognise PEP 604 unions.
        import types

        origin = get_origin(type_hint)
        args = get_args(type_hint)

        # ``X | Y`` reports ``types.UnionType`` as its origin on interpreters
        # where that differs from ``typing.Union``. Default to ``Union`` (not
        # ``None``) when the attribute is missing, because ``origin`` is
        # ``None`` for plain classes.
        pep604_union = getattr(types, "UnionType", Union)
        if origin is Union or origin is pep604_union:
            return cls._infer_union_type(args)

        # Handle list[T]; a bare ``list`` has no origin, so match it directly
        # to give the caller the specific hint below.
        if origin is list or type_hint is list:
            if not args:
                raise TypeError(
                    "Cannot infer DataType from list without type argument. "
                    "Use list[T] syntax, e.g., list[int]"
                )
            return cls.list(cls._infer_from_type(args[0]))

        # Handle dict[K, V]; same treatment for a bare ``dict``.
        if origin is dict or type_hint is dict:
            if len(args) != 2:
                raise TypeError(
                    "Cannot infer DataType from dict without key/value types. "
                    "Use dict[K, V] syntax, e.g., dict[str, int]"
                )
            key_type = cls._infer_from_type(args[0])
            value_type = cls._infer_from_type(args[1])
            return cls.map(key_type, value_type)

        # Handle basic types
        dtype = cls._infer_basic_type(type_hint)
        if dtype is not None:
            return dtype

        # Cannot infer
        raise TypeError(
            f"Cannot infer DataType from type hint '{type_hint}'. "
            f"Please specify return_dtype explicitly."
        )

    @classmethod
    def _infer_union_type(cls, args: tuple) -> "DataType":
        """
        Infer DataType from Union type arguments.

        Args:
            args: The members of the union, as returned by ``typing.get_args``.

        Returns:
            The nullable DataType of ``T`` when the union is ``Optional[T]``.

        Raises:
            TypeError: If the union does not have exactly one non-None member.
        """
        # Identity checks avoid hashing the members, which a set lookup needs.
        non_none_types = [
            t for t in args if t is not None and t is not type(None)
        ]
        if len(non_none_types) == 1:
            # This is Optional[T]
            return cls._infer_from_type(non_none_types[0]).nullable()

        # General Union - not supported
        raise TypeError(
            f"Cannot infer DataType from Union type with multiple non-None types: {args}. "
            f"Please specify return_dtype explicitly."
        )

    @classmethod
    def _infer_basic_type(cls, type_hint: Any) -> typing.Optional["DataType"]:
        """
        Infer DataType from basic Python types.

        Args:
            type_hint: Python type or type hint.

        Returns:
            The matching DataType, or ``None`` when ``type_hint`` is not one
            of the directly mapped types (or ``typing.Any``).
        """
        # Map each Python type to a factory, not to an instance, so that only
        # the DataType actually requested is constructed.
        factories = {
            bool: cls.bool,
            int: cls.int64,
            float: cls.float64,
            str: cls.string,
            bytes: cls.binary,
            bytearray: cls.binary,
            decimal.Decimal: lambda: cls.decimal(38, 18),
            datetime.date: cls.date,
            datetime.time: cls.time,
            datetime.datetime: cls.timestamp,
        }
        try:
            factory = factories.get(type_hint)
        except TypeError:
            # Unhashable hint (e.g. a list instance): report "no match" so the
            # caller raises its descriptive error instead.
            return None
        if factory is not None:
            return factory()

        # Handle typing extensions
        # typing.Any -> STRING as fallback
        if type_hint is Any:
            return cls.string()
        return None

    @classmethod
    def _from_sql(cls, sql_type: str) -> "DataType":
        """
        Create DataType from a SQL type string.

        Args:
            sql_type: SQL type string, e.g. ``'INT'``, ``'ARRAY<BIGINT NOT NULL>'``,
                ``'ROW<name STRING, age INT>'``.

        Returns:
            The corresponding :class:`DataType`.

        Raises:
            ValueError: If the string cannot be parsed.
        """
        from pyflink.util.exceptions import JavaException
        from pyflink.java_gateway import get_gateway
        from pyflink.table.types import _from_java_data_type

        try:
            gateway = get_gateway()
            # Parse on the JVM side, then convert the logical type into a
            # Java DataType that the Python bridge can translate.
            j_logical_type = (
                gateway.jvm.org.apache.flink.table.types.logical.utils
                .LogicalTypeParser.parse(sql_type)
            )
            j_data_type = (
                gateway.jvm.org.apache.flink.table.types.utils
                .TypeConversions.fromLogicalToDataType(j_logical_type)
            )
            return cls(_from_java_data_type(j_data_type))
        except JavaException as e:
            # Surface the Java-side failure as a plain ValueError and drop the
            # chained Java traceback.
            raise ValueError(str(e)) from None
# Type alias for type hints: a value that is either a ``DataType`` instance,
# a Python ``type``, or a ``str``.
# NOTE(review): presumably the ``type`` and ``str`` forms are resolved through
# ``DataType._infer_from_type`` and ``DataType._from_sql`` respectively --
# confirm against the callers, which are not visible in this file.
DataTypeLike = typing.Union[DataType, type, str]