#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined table function related classes and functions
"""
import pickle
import sys
import warnings
from typing import Any, Type, TYPE_CHECKING, Optional, Union
from py4j.java_gateway import JavaObject
from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import _to_java_column, _to_seq
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.types import StructType, _parse_datatype_string
from pyspark.sql.udf import _wrap_function
if TYPE_CHECKING:
    from pyspark.sql._typing import ColumnOrName
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.session import SparkSession
# Names exported by ``from pyspark.sql.udtf import *``.
__all__ = ["UDTFRegistration"]
def _create_udtf(
    cls: Type,
    returnType: Union[StructType, str],
    name: Optional[str] = None,
    evalType: int = PythonEvalType.SQL_TABLE_UDF,
    deterministic: bool = False,
) -> "UserDefinedTableFunction":
    """
    Build a :class:`UserDefinedTableFunction` for ``cls`` with the given eval type.

    All arguments are forwarded unchanged to the ``UserDefinedTableFunction``
    constructor, which also validates the handler class.
    """
    return UserDefinedTableFunction(
        cls,
        returnType=returnType,
        name=name,
        evalType=evalType,
        deterministic=deterministic,
    )
def _create_py_udtf(
    cls: Type,
    returnType: Union[StructType, str],
    name: Optional[str] = None,
    deterministic: bool = False,
    useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
    """
    Create either a regular or an Arrow-optimized Python UDTF.

    An explicit ``useArrow`` decides the variant. When it is ``None``, the
    ``spark.sql.execution.pythonUDTF.arrow.enabled`` config of the already
    instantiated SparkSession is consulted (Arrow stays off when there is no
    such session). If Arrow is requested but the pandas/pyarrow requirements are
    not met, a ``UserWarning`` is emitted and a regular UDTF is created instead.
    """
    if useArrow is None:
        from pyspark.sql import SparkSession

        # Only inspect a session that already exists; never create one here.
        active = SparkSession._instantiatedSession
        conf_value = (
            None
            if active is None
            else active.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
        )
        use_arrow = isinstance(conf_value, str) and conf_value.lower() == "true"
    else:
        use_arrow = useArrow

    eval_type: int = PythonEvalType.SQL_TABLE_UDF
    if use_arrow:
        try:
            require_minimum_pandas_version()
            require_minimum_pyarrow_version()
        except ImportError as e:
            # Missing/old dependencies: keep the regular eval type and warn.
            warnings.warn(
                f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
                f"Falling back to using regular Python UDTFs.",
                UserWarning,
            )
        else:
            eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF

    return _create_udtf(
        cls=cls,
        returnType=returnType,
        name=name,
        evalType=eval_type,
        deterministic=deterministic,
    )
def _validate_udtf_handler(cls: Any) -> None:
    """Validate the handler class of a UDTF."""
    if not isinstance(cls, type):
        raise PySparkTypeError(
            error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__}
        )
    if not hasattr(cls, "eval"):
        raise PySparkAttributeError(
            error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": cls.__name__}
        )
class UserDefinedTableFunction:
    """
    User-defined table function in Python

    .. versionadded:: 3.5.0

    Notes
    -----
    The constructor of this class is not supposed to be directly called.
    Use :meth:`pyspark.sql.functions.udtf` to create this instance.

    This API is evolving.
    """

    def __init__(
        self,
        func: Type,
        returnType: Union[StructType, str],
        name: Optional[str] = None,
        evalType: int = PythonEvalType.SQL_TABLE_UDF,
        deterministic: bool = False,
    ):
        # Fail fast on handlers that are not classes or that lack an ``eval`` attribute.
        _validate_udtf_handler(func)

        self.func = func
        # Raw return type as given (DDL string or StructType); resolved lazily by ``returnType``.
        self._returnType = returnType
        # Lazily populated caches.
        self._returnType_placeholder: Optional[StructType] = None
        self._inputTypes_placeholder = None  # not read by any code in this class
        self._judtf_placeholder = None
        # Defaults to the handler class name when no explicit name is given.
        self._name = name or func.__name__
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self) -> StructType:
        """
        The return type of this UDTF as a :class:`StructType`.

        A DDL-formatted string is parsed on first access and the result is cached.

        Raises
        ------
        PySparkTypeError
            If the declared or parsed return type is not a :class:`StructType`
            (error class ``UDTF_RETURN_TYPE_MISMATCH``).
        """
        # `_parse_datatype_string` accesses the JVM to parse a DDL-formatted string.
        # Parsing lazily makes sure this happens after the SparkContext is initialized.
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, str):
                parsed = _parse_datatype_string(self._returnType)
            else:
                parsed = self._returnType
            if not isinstance(parsed, StructType):
                raise PySparkTypeError(
                    error_class="UDTF_RETURN_TYPE_MISMATCH",
                    message_parameters={
                        "name": self._name,
                        "return_type": f"{parsed}",
                    },
                )
            self._returnType_placeholder = parsed
        return self._returnType_placeholder

    @property
    def _judtf(self) -> JavaObject:
        """The JVM-side UDTF object, created on first access and cached."""
        if self._judtf_placeholder is None:
            self._judtf_placeholder = self._create_judtf(self.func)
        return self._judtf_placeholder

    def _create_judtf(self, func: Type) -> JavaObject:
        """
        Build the JVM ``UserDefinedPythonTableFunction`` wrapping ``func``.

        Raises
        ------
        PySparkRuntimeError
            If ``func`` cannot be pickled (error class ``UDTF_SERIALIZATION_ERROR``).
        """
        from pyspark.sql import SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext

        try:
            wrapped_func = _wrap_function(sc, func)
        except pickle.PicklingError as e:
            if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
                # Known, actionable cause: suppress the noisy pickling traceback.
                raise PySparkRuntimeError(
                    error_class="UDTF_SERIALIZATION_ERROR",
                    message_parameters={
                        "name": self._name,
                        "message": "it appears that you are attempting to reference SparkSession "
                        "inside a UDTF. SparkSession can only be used on the driver, "
                        "not in code that runs on workers. Please remove the reference "
                        "and try again.",
                    },
                ) from None
            # Unknown cause: chain explicitly so the original pickling error is reported
            # as the direct cause the user is told to inspect.
            raise PySparkRuntimeError(
                error_class="UDTF_SERIALIZATION_ERROR",
                message_parameters={
                    "name": self._name,
                    "message": "Please check the stack trace and make sure the "
                    "function is serializable.",
                },
            ) from e

        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        assert sc._jvm is not None
        judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic
        )
        return judtf

    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
        """Apply this UDTF to ``cols`` and return the resulting DataFrame."""
        from pyspark.sql import DataFrame, SparkSession

        spark = SparkSession._getActiveSessionOrCreate()
        sc = spark.sparkContext

        judtf = self._judtf
        jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
        return DataFrame(jPythonUDTF, spark)

    def asDeterministic(self) -> "UserDefinedTableFunction":
        """
        Updates UserDefinedTableFunction to deterministic.
        """
        # Explicitly clear the cache so the next access creates a new JVM UDTF instance
        # that carries ``deterministic=True``.
        self._judtf_placeholder = None
        self.deterministic = True
        return self
class UDTFRegistration:
    """
    Wrapper for user-defined table function registration. This instance can be accessed by
    :attr:`spark.udtf` or :attr:`sqlContext.udtf`.

    .. versionadded:: 3.5.0
    """

    def __init__(self, sparkSession: "SparkSession"):
        self.sparkSession = sparkSession

    def register(
        self,
        name: str,
        f: "UserDefinedTableFunction",
    ) -> "UserDefinedTableFunction":
        """Register a Python user-defined table function as a SQL table function.

        .. versionadded:: 3.5.0

        Parameters
        ----------
        name : str
            The name of the user-defined table function in SQL statements.
        f : :class:`UserDefinedTableFunction`
            The user-defined table function, e.g. a handler class decorated with
            :meth:`pyspark.sql.functions.udtf`.

        Returns
        -------
        :class:`UserDefinedTableFunction`
            The registered user-defined table function.

        Raises
        ------
        PySparkTypeError
            If ``f`` does not have a table-function eval type (``SQL_TABLE_UDF`` or
            ``SQL_ARROW_TABLE_UDF``).

        Notes
        -----
        Spark uses the return type of the given user-defined table function as the return
        type of the registered user-defined function.

        To register a nondeterministic Python table function, users need to first build
        a nondeterministic user-defined table function and then register it as a SQL function.
        The registered function keeps the ``deterministic`` flag of ``f``.

        Examples
        --------
        >>> from pyspark.sql.functions import udtf
        >>> @udtf(returnType="c1: int, c2: int")
        ... class PlusOne:
        ...     def eval(self, x: int):
        ...         yield x, x + 1
        ...
        >>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
        >>> spark.sql("SELECT * FROM plus_one(1)").collect()
        [Row(c1=1, c2=2)]

        Use it with lateral join

        >>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
        [Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
        """
        # Only table-function eval types can be registered as SQL table functions.
        if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
            raise PySparkTypeError(
                error_class="INVALID_UDTF_EVAL_TYPE",
                message_parameters={
                    "name": name,
                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
                },
            )

        # Rebuild the UDTF under the SQL name so the JVM-side function is created with
        # ``name``; handler, return type, eval type and determinism are copied from ``f``.
        register_udtf = _create_udtf(
            cls=f.func,
            returnType=f.returnType,
            name=name,
            evalType=f.evalType,
            deterministic=f.deterministic,
        )
        self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
        return register_udtf
def _test() -> None:
    """Run this module's doctests against a local SparkSession; exit non-zero on failure."""
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udtf

    globs = pyspark.sql.udtf.__dict__.copy()
    spark = SparkSession.builder.master("local[4]").appName("sql.udtf tests").getOrCreate()
    globs["spark"] = spark
    try:
        failure_count, _ = doctest.testmod(
            pyspark.sql.udtf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
        )
    finally:
        # Always shut down the local Spark session, even if doctest itself raises.
        spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()