from typing import Any

import numpy as np
import pytest
from numpy.testing import assert_array_equal

import polars as pl
from polars.testing import assert_series_equal


@pytest.fixture(
    params=[
        ("int8", [1, 3, 2], pl.Int8, np.int8),
        ("int16", [1, 3, 2], pl.Int16, np.int16),
        ("int32", [1, 3, 2], pl.Int32, np.int32),
        ("int64", [1, 3, 2], pl.Int64, np.int64),
        ("uint8", [1, 3, 2], pl.UInt8, np.uint8),
        ("uint16", [1, 3, 2], pl.UInt16, np.uint16),
        ("uint32", [1, 3, 2], pl.UInt32, np.uint32),
        ("uint64", [1, 3, 2], pl.UInt64, np.uint64),
        # NumPy float16 is paired with Polars Float32 (no native f16 in Polars).
        ("float16", [-123.0, 0.0, 456.0], pl.Float32, np.float16),
        ("float32", [21.7, 21.8, 21], pl.Float32, np.float32),
        ("float64", [21.7, 21.8, 21], pl.Float64, np.float64),
        ("bool", [True, False, False], pl.Boolean, np.bool_),
        ("object", [21.7, "string1", object()], pl.Object, np.object_),
        ("str", ["string1", "string2", "string3"], pl.String, np.str_),
        ("intc", [1, 3, 2], pl.Int32, np.intc),
        ("uintc", [1, 3, 2], pl.UInt32, np.uintc),
        # NOTE(review): "str_fixed" uses the same np.str_ dtype as "str", so it
        # currently duplicates that case — confirm whether a distinct
        # fixed-width dtype was intended here.
        ("str_fixed", ["string1", "string2", "string3"], pl.String, np.str_),
        (
            "bytes",
            [b"byte_string1", b"byte_string2", b"byte_string3"],
            pl.Binary,
            np.bytes_,
        ),
    ]
)
def numpy_interop_test_data(request: Any) -> Any:
    """Yield one NumPy/Polars dtype pairing per parametrized test run.

    Each parameter is a 4-tuple ``(name, values, pl_dtype, np_dtype)``:

    - ``name``: label used as the column / series name.
    - ``values``: plain Python values from which both sides are built.
    - ``pl_dtype``: the Polars dtype paired with this case.
    - ``np_dtype``: the NumPy dtype used to build the reference array.
    """
    selected_case = request.param
    return selected_case


def test_df_from_numpy(numpy_interop_test_data: Any) -> None:
    """A DataFrame built from a typed NumPy array gets the paired Polars dtype."""
    col_name, raw_values, expected_dtype, numpy_dtype = numpy_interop_test_data
    source_array = np.array(raw_values, dtype=numpy_dtype)
    frame = pl.DataFrame({col_name: source_array})
    assert [expected_dtype] == frame.dtypes


def test_asarray(numpy_interop_test_data: Any) -> None:
    """``np.asarray`` on a Polars Series equals a NumPy array of the same values."""
    label, raw_values, series_dtype, numpy_dtype = numpy_interop_test_data
    series = pl.Series(label, raw_values, series_dtype)
    converted = np.asarray(series)
    reference = np.asarray(raw_values, dtype=numpy_dtype)
    assert_array_equal(converted, reference)


def test_to_numpy(numpy_interop_test_data: Any) -> None:
    """``Series.to_numpy`` equals a NumPy array built from the same values."""
    label, raw_values, series_dtype, numpy_dtype = numpy_interop_test_data
    series = pl.Series(label, raw_values, series_dtype)
    converted = series.to_numpy()
    reference = np.asarray(raw_values, dtype=numpy_dtype)
    assert_array_equal(converted, reference)


def test_numpy_to_lit() -> None:
    """``pl.lit`` accepts both a NumPy array and a NumPy scalar."""
    cases = [
        (np.array([1, 2, 3]), [1, 2, 3]),  # array -> one row per element
        (np.float32(0), [0.0]),  # scalar -> a single row
    ]
    for literal, expected in cases:
        result = pl.select(pl.lit(literal)).to_series().to_list()
        assert result == expected


def test_numpy_disambiguation() -> None:
    """A NumPy array given as a keyword to ``with_columns`` becomes a new column."""
    arr = np.array([1, 2])
    frame = pl.DataFrame({"a": arr})
    out = frame.with_columns(b=arr).to_dict(as_series=False)
    assert out == {"a": [1, 2], "b": [1, 2]}


def test_respect_dtype_with_series_from_numpy() -> None:
    """An explicit ``dtype`` is honoured when the data is an integer NumPy array."""
    int_array = np.array([1, 2, 3])
    series = pl.Series("foo", int_array, dtype=pl.UInt32)
    assert series.dtype == pl.UInt32


@pytest.mark.parametrize(
    ("np_dtype_cls", "expected_pl_dtype"),
    [
        (np.int8, pl.Int8),
        (np.int16, pl.Int16),
        (np.int32, pl.Int32),
        (np.int64, pl.Int64),
        (np.uint8, pl.UInt8),
        (np.uint16, pl.UInt16),
        (np.uint32, pl.UInt32),
        (np.uint64, pl.UInt64),
        (np.float16, pl.Float32),  # << note: we don't currently have a native f16
        (np.float32, pl.Float32),
        (np.float64, pl.Float64),
    ],
)
def test_init_from_numpy_values(np_dtype_cls: Any, expected_pl_dtype: Any) -> None:
    """Build a Series from raw NumPy scalar values (as opposed to an array).

    Checks both the inferred Polars dtype and that the values survive the
    round trip unchanged.
    """
    raw = (0, 4, 8)
    s = pl.Series("n", [np_dtype_cls(v) for v in raw])
    assert s.dtype == expected_pl_dtype
    # 0, 4 and 8 are exact in every parametrized dtype (including float16),
    # and float results compare equal to the ints (0.0 == 0), so a single
    # integer comparison covers all cases.
    assert s.to_list() == list(raw)


def test_from_numpy_nonbit_bools_24296() -> None:
    """Bool arrays whose underlying bytes are not strictly 0/1 convert correctly.

    The suffix is presumably the tracking issue number (24296). Viewing uint8
    data as ``bool`` yields bool elements backed by arbitrary bytes; every
    nonzero byte is expected to read as ``True`` and zero as ``False``.
    """
    raw_bytes = np.array([24, 15, 32, 1, 0], dtype=np.uint8)
    bool_view = raw_bytes.view(bool)
    actual = pl.Series(bool_view)
    expected = pl.Series([True, True, True, True, False])
    assert_series_equal(actual, expected)
