# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import re

import numpy as np
import pandas as pd

from nsys_recipe.lib import cpu_perf, helpers, nvtx

# An NVTX range is marked as "too short" when the value computed for it in
# `_get_indices_of_too_short_nvtx_ranges` (minimum over perf events of the
# per-CPU maximum sample count) is below this number.
_MIN_CPU_SAMPLE_COUNT_THRESHOLD = 3
# Maximum relative difference between an instance duration in one report and
# the mean over all reports for the duration to be considered stable
# (0.1 is rendered as "10.0%" in the summary notes).
_INSTANCE_DURATION_DIFF_THRESHOLD = 0.1


def get_cpu_arch(target_info_df):
    """
    Determine the CPU architecture from the target info data frame.

    The data frame is expected to have "name" and "value" columns.
    Returns a `cpu_perf.Architecture` member for the known architectures
    (`UNKNOWN` when there is no "CpuArchitecture" row), otherwise the raw
    architecture value stored in the data frame.
    """

    def rows_named(entry_name):
        return target_info_df[target_info_df["name"] == entry_name]

    arch_rows = rows_named("CpuArchitecture")
    if arch_rows.empty:
        return cpu_perf.Architecture.UNKNOWN
    cpu_arch = arch_rows.iloc[0]["value"]

    # A missing "IsTegraBasedDevice" row is treated as "not Tegra".
    tegra_rows = rows_named("IsTegraBasedDevice")
    is_tegra = False if tegra_rows.empty else tegra_rows.iloc[0]["value"]

    if cpu_arch == "x86_64":
        return cpu_perf.Architecture.X86_64

    if cpu_arch != "aarch64":
        # Return the raw value from the TARGET_INFO_SYSTEM_ENV
        # database table.
        return cpu_arch

    if is_tegra:
        return cpu_perf.Architecture.AARCH64_TEGRA
    return cpu_perf.Architecture.AARCH64_SBSA


def _get_indices_of_too_short_nvtx_ranges(cpu_sample_count_df):
    """
    Return the index of the NVTX ranges that are considered too short.

    Columns of `cpu_sample_count_df` are named
    `<perf_event>_<cpu>_sample_count`. For every row the maximum sample
    count among CPUs is taken per perf event, and the row is too short when
    the smallest of these maxima is below `_MIN_CPU_SAMPLE_COUNT_THRESHOLD`
    (a row without any sample count yields 0).
    """
    # Captures the perf event name, e.g. "OP_RETIRED" from
    # "OP_RETIRED_48_sample_count".
    column_pattern = re.compile(r"(.*)_\d*_sample_count")

    def min_of_per_event_maxima(row):
        # Example of `row`:
        # OP_RETIRED_48_sample_count             85.279287
        # OP_SPEC_48_sample_count                85.279287
        # OP_RETIRED_55_sample_count             85.187535
        # OP_SPEC_55_sample_count                85.187535
        # which gives the per-event maxima
        # { "OP_RETIRED": 85.279287, "OP_SPEC": 85.279287 }
        present = row.dropna()
        best_by_event = {}
        for column in present.index:
            event = column_pattern.match(column).group(1)
            count = present[column]
            if event in best_by_event and not count > best_by_event[event]:
                continue
            best_by_event[event] = count

        # The minimum number of samples among perf events.
        return min(best_by_event.values(), default=0)

    per_range_minimum = cpu_sample_count_df.apply(min_of_per_event_maxima, axis=1)
    is_too_short = per_range_minimum < _MIN_CPU_SAMPLE_COUNT_THRESHOLD
    return cpu_sample_count_df[is_too_short].index


def get_nvtx_w_cpu_events_n_stack(
    nvtx_df, core_perf_df, thread_sched_df, rely_on_nvtx_tid, agg_parallel_nvtx
):
    """
    Get NVTX enriched with CPU Perf events and NVTX stack information.

    The CPU events are attached before the optional consolidation of
    parallel ranges when `rely_on_nvtx_tid` is set, and after it otherwise.
    Ranges reported by `_get_indices_of_too_short_nvtx_ranges` get
    "tooShort" set to 1; the helper "*_sample_count" columns are removed
    from the result.
    """

    def attach_cpu_events(ranges_df):
        events_df, samples_df = cpu_perf.compute_core_perf_events(
            ranges_df, core_perf_df, thread_sched_df, rely_on_nvtx_tid
        )
        return ranges_df.join(events_df).join(
            samples_df.add_suffix("_sample_count")
        )

    nvtx_df = nvtx.compute_callstack(nvtx_df)

    if rely_on_nvtx_tid:
        nvtx_df = attach_cpu_events(nvtx_df)
        if agg_parallel_nvtx:
            nvtx_df = nvtx.consolidate_parallel_ranges(nvtx_df)
    else:
        if agg_parallel_nvtx:
            nvtx_df = nvtx.consolidate_parallel_ranges(nvtx_df)
        nvtx_df = attach_cpu_events(nvtx_df)

    # The sample counts are only needed to detect the too short ranges.
    sample_counts_df = nvtx_df.filter(like="_sample_count")
    too_short_indices = _get_indices_of_too_short_nvtx_ranges(sample_counts_df)

    nvtx_df = nvtx_df.drop(columns=sample_counts_df.columns)
    nvtx_df.loc[too_short_indices, "tooShort"] = 1

    return nvtx_df


def create_nvtx_grouper(nvtx_df, nvtx_grouping_strategy):
    """
    Build an `nvtx.NvtxGrouper` for a copy of the provided NVTX data frame
    using the provided NVTX grouping strategy.
    The input data frame is left untouched; the grouper can be reused
    to aggregate the grouped data in various ways.
    """
    ranges_df = nvtx_df.copy()

    if "originalIndices" not in ranges_df.columns:
        ranges_df = nvtx.add_original_indices(ranges_df)

    durations = ranges_df["end"] - ranges_df["start"]
    ranges_df["instances"] = 1
    ranges_df["instDuration"] = durations
    # Starts as the raw duration; the aggregation step turns it
    # into a deviation.
    ranges_df["instDurationStd"] = durations

    # The rows are ordered by instDuration so that the aggregation
    # can pick the middle element of every group.
    ranges_df = ranges_df.sort_values(by="instDuration")

    return nvtx.NvtxGrouper(ranges_df, nvtx_grouping_strategy)


def aggregate_nvtx_ranges(nvtx_grouper, grouping_key_cols):
    """
    Aggregate the NVTX ranges of the provided NVTX grouper.

    Per group the original indices and the instance counts are summed,
    "instDuration" becomes the middle element of the group and
    "instDurationStd" the root mean square deviation from that element.
    The result is ordered by the minimal original NVTX index to preserve
    the call order and holds the grouping key columns, "text",
    "instances", "instDuration" and "instDurationStd".
    """

    def pick_middle(values):
        # Positional access, independent of the labels of the group.
        return values.reset_index(drop=True).iloc[len(values) // 2]

    def deviation_from_middle(values):
        center = pick_middle(values)
        squared_diffs = [np.float64(value - center) ** 2 for value in values]
        return np.sqrt(sum(squared_diffs) / len(values))

    aggregated_df = nvtx_grouper.aggregate(
        {
            "originalIndices": "sum",
            "instances": "sum",
            "instDuration": pick_middle,
            "instDurationStd": deviation_from_middle,
        }
    )

    # The smallest original index of a group defines its position
    # in the call order.
    aggregated_df["minOriginalIndex"] = aggregated_df["originalIndices"].apply(min)
    ordered_df = aggregated_df.sort_values(by="minOriginalIndex")

    output_cols = list(grouping_key_cols) + [
        "text",
        "instances",
        "instDuration",
        "instDurationStd",
    ]
    return ordered_df[output_cols].reset_index(drop=True)


def number_nvtx_instances(nvtx_grouper, inst_idx_col):
    """
    Sequentially number the NVTX instances in each group.

    Within every group of `nvtx_grouper.df_grouped` the instances are
    ordered by "rangeId" and numbered starting from 0.
    The resulting data frame contains the NVTX range ID ("rangeId") and
    the instance index (`inst_idx_col`); its index restarts from 0 for
    every group. When there are no groups, an empty data frame with these
    two columns is returned.
    """
    numbered_dfs = []
    for _, group_df in nvtx_grouper.df_grouped:
        ordered_df = group_df.sort_values(by="rangeId").reset_index(drop=True)
        ordered_df[inst_idx_col] = ordered_df.index
        numbered_dfs.append(ordered_df[["rangeId", inst_idx_col]])

    if not numbered_dfs:
        # `pd.concat` raises a ValueError for an empty list.
        # NOTE(review): "rangeId" is assumed to be an integer column,
        # hence the int64 dtype of the empty result - confirm.
        return pd.DataFrame(
            {
                "rangeId": pd.Series(dtype="int64"),
                inst_idx_col: pd.Series(dtype="int64"),
            }
        )

    return pd.concat(numbered_dfs)


def _fill_callstack_path(df, td_key_cols):
    """
    Fill the "callstackPath" column of `df` in place.

    `td_key_cols` is a pair (key column, parent key column). Starting from
    the rows whose parent key is -1, every reachable row gets the list of
    the "text" values from its top-level ancestor down to the row itself.
    Rows that are not reachable from a parent key of -1 are left without
    a path.
    """
    td_key_col, par_td_key_col = td_key_cols

    # Depth-first traversal: (key of the parent, path leading to the parent).
    pending = [(-1, [])]
    while pending:
        parent_key, parent_path = pending.pop()

        is_child = df[par_td_key_col] == parent_key
        df.loc[is_child, "callstackPath"] = df.loc[is_child, "text"].apply(
            lambda text: parent_path + [text]
        )

        children_df = df[is_child]
        if children_df.empty:
            continue

        # Several rows may share one key (e.g. the same range coming from
        # multiple concatenated summaries). Descending once per key is
        # enough: the first row of a key is the one whose path would
        # otherwise be written last, and skipping the repeats avoids
        # walking the same sub-tree again for every duplicate.
        unique_children_df = children_df.drop_duplicates(subset=td_key_col)
        for child_key, child_text in zip(
            unique_children_df[td_key_col], unique_children_df["text"]
        ):
            pending.append((child_key, parent_path + [child_text]))

def add_callstack_to_duplicated_names(df, td_key_cols):
    """
    Make duplicated NVTX range names distinguishable by their callstack.

    Every "text" value that occurs more than once is replaced by the
    callstack path of the range, innermost range first, joined with a left
    arrow. The data frame is modified in place and returned.
    """
    separator = " \u2190 "  # ←

    _fill_callstack_path(df, td_key_cols)

    has_name_clash = df["text"].duplicated(keep=False)
    qualified_names = df.loc[has_name_clash, "callstackPath"].apply(
        lambda callstack: separator.join(reversed(callstack))
    )
    df.loc[has_name_clash, "text"] = qualified_names

    # The path column is only a temporary helper.
    df.drop(columns=["callstackPath"], inplace=True)
    return df


def _check_nvtx_stability_btw_nsys_reps(nvtx_summary_dfs, merge_by_cols):
    """
    Compare the NVTX summaries of all reports and annotate each of them.

    The summaries are merged by `merge_by_cols`. Every data frame in
    `nvtx_summary_dfs` is then replaced in place (in the list) by a version
    with three additional columns holding lists of report names:
      - "absentInReports": reports that do not contain the range,
      - "instCountDiffersInReports": reports whose instance count differs
        from the mean instance count over all reports,
      - "instDurDiffersInReports": reports whose instance duration differs
        from the mean instance duration over all reports by more than
        `_INSTANCE_DURATION_DIFF_THRESHOLD` (relative to the duration
        in that report).
    """
    dfs = [df[["instances", "instDuration", *merge_by_cols]] for df in nvtx_summary_dfs]
    merged_df = helpers.merge(dfs, merge_by_cols, how="outer")

    inst_count_means = merged_df.filter(like="instances").mean(axis=1)
    inst_dur_means = merged_df.filter(like="instDuration").mean(axis=1)

    # One column per report is added below: True (stable), False (differs)
    # or NaN (the range is absent in the report).
    inst_count_stability_df = pd.DataFrame(merged_df[merge_by_cols])
    inst_dur_stability_df = pd.DataFrame(merged_df[merge_by_cols])

    for idx, df in enumerate(nvtx_summary_dfs):
        report = df["report"].iloc[0]
        inst_count_col = f"instances#{idx}"
        inst_dur_col = f"instDuration#{idx}"

        def is_inst_count_stable(x):
            if np.isnan(x[inst_count_col]):
                return np.nan
            return x[inst_count_col] == inst_count_means[x.name]

        inst_count_stability_df[report] = merged_df.apply(is_inst_count_stable, axis=1)

        def is_inst_dur_stable(x):
            if np.isnan(x[inst_dur_col]):
                return np.nan
            inst_dur_diff = abs(x[inst_dur_col] - inst_dur_means[x.name])
            return (
                inst_dur_diff / x[inst_dur_col]
            ) <= _INSTANCE_DURATION_DIFF_THRESHOLD

        inst_dur_stability_df[report] = merged_df.apply(is_inst_dur_stable, axis=1)

    def fill_report_names(row, is_flagged, res_column):
        res = row[merge_by_cols]
        # Only the per-report columns are inspected. The merge key columns
        # must be excluded: a key equal to 0 compares equal to False and
        # would otherwise be listed as if it were a report name.
        report_flags = row.drop(labels=merge_by_cols)
        res[res_column] = [
            report for report, flag in report_flags.items() if is_flagged(flag)
        ]
        return res

    def is_unstable(flag):
        # NaN means "absent in the report", which is reported separately.
        return not pd.isna(flag) and not flag

    range_absent_in_reports_df = inst_count_stability_df.apply(
        lambda x: fill_report_names(x, pd.isna, "absentInReports"), axis=1
    )
    inst_count_differs_in_reports_df = inst_count_stability_df.apply(
        lambda x: fill_report_names(x, is_unstable, "instCountDiffersInReports"),
        axis=1,
    )
    inst_dur_differs_in_reports_df = inst_dur_stability_df.apply(
        lambda x: fill_report_names(x, is_unstable, "instDurDiffersInReports"),
        axis=1,
    )

    for idx, df in enumerate(nvtx_summary_dfs):
        df = df.merge(range_absent_in_reports_df, on=merge_by_cols)
        df = df.merge(inst_count_differs_in_reports_df, on=merge_by_cols)
        df = df.merge(inst_dur_differs_in_reports_df, on=merge_by_cols)
        nvtx_summary_dfs[idx] = df


def _compute_nvtx_summary(nvtx_summary_dfs, td_key_cols):
    """
    Annotate the per-report NVTX summaries with the stability information
    (the list is updated in place) and concatenate them into a single
    data frame.
    """
    _check_nvtx_stability_btw_nsys_reps(nvtx_summary_dfs, td_key_cols)
    combined_df = pd.concat(nvtx_summary_dfs)
    return combined_df


def _reflect_callstack_in_name(df, td_key_cols):
    """
    Replace "text" with a tree-like representation of the callstack.

    Top-level ranges keep their plain name; nested ranges get a branch
    sign in front of the name, preceded by five non-breaking spaces for
    every nesting level beyond the first one. The data frame is modified
    in place and returned.
    """
    branch_prefix = "\u2514" + "─ "  # └─
    level_indent = "\u00a0" * 5  # non-breaking spaces

    def render(callstack):
        depth = len(callstack)
        if depth == 0:
            return ""
        if depth == 1:
            return callstack[-1]
        return level_indent * (depth - 2) + branch_prefix + callstack[-1]

    _fill_callstack_path(df, td_key_cols)

    df["text"] = df["callstackPath"].apply(render)
    df.drop(columns=["callstackPath"], inplace=True)

    return df


def _prepare_nvtx_summary_for_display(df, td_key_cols):
    """
    Turn the computed NVTX summary into its displayable form.

    The instance counts and durations are converted to text (durations in
    milliseconds, with the deviation when it is positive) and prefixed with
    a warning sign when they differ between reports. A "Notes" column is
    added if at least one range has a note, the range names are indented
    according to the callstack, and the columns get their display names.
    `df` is modified in place; the returned data frame holds the display
    columns only.
    """
    warn_sign = "\u26a0"  # ⚠
    this_report_label = "<i>this report</i>"

    def with_warning_sign(row, value_col, differing_reports_col):
        # Ranges that are filtered out (too short or absent in some
        # reports) never get the warning sign.
        if row["tooShort"] or len(row["absentInReports"]) > 0:
            return row[value_col]
        if len(row[differing_reports_col]) > 0:
            return f"{warn_sign} {row[value_col]}"
        return row[value_col]

    def format_inst_duration(row):
        # The stored values are scaled by 1e-6 to get milliseconds.
        duration_ms = row["instDuration"] * 1e-6
        deviation_ms = row["instDurationStd"] * 1e-6
        formatted = f"{round(duration_ms, 2):.2f} ms"
        if deviation_ms > 0:
            return formatted + f" ± {round(deviation_ms, 2):.2f} ms"
        return formatted

    def list_reports(row, reports_col):
        # The report of the row itself is listed first under a generic label.
        names = row[reports_col]
        if row["report"] in names:
            names = [this_report_label] + names
            names.remove(row["report"])
        return ", ".join(names)

    def build_note(row):
        if row["tooShort"]:
            return (
                "This range is filtered out, because it contains fewer than "
                f"{_MIN_CPU_SAMPLE_COUNT_THRESHOLD} PMU samples "
                "in at least one report."
            )

        absent_in = row["absentInReports"]
        if len(absent_in) > 0:
            return (
                "This range is filtered out, "
                f"because it is absent in {', '.join(absent_in)}."
            )

        remarks = []
        if len(row["instCountDiffersInReports"]) > 0:
            listed = list_reports(row, "instCountDiffersInReports")
            remarks.append(
                f"The number of instances in {listed}"
                " differs from the mean value for all reports."
            )

        if len(row["instDurDiffersInReports"]) > 0:
            listed = list_reports(row, "instDurDiffersInReports")
            remarks.append(
                f"The instance duration in {listed}"
                " differs from the mean value for all reports by more than "
                f"{_INSTANCE_DURATION_DIFF_THRESHOLD * 100}%."
            )
        return "<br/>".join(remarks)

    df["instances"] = df["instances"].astype(str)
    df["instances"] = df.apply(
        with_warning_sign, axis=1, args=("instances", "instCountDiffersInReports")
    )

    # The duration is formatted first; the warning sign is then put
    # in front of the formatted text.
    df["instDuration"] = df.apply(format_inst_duration, axis=1)
    df["instDuration"] = df.apply(
        with_warning_sign, axis=1, args=("instDuration", "instDurDiffersInReports")
    )

    # The "Notes" column is only added when there is something to show.
    notes = df.apply(build_note, axis=1)
    if not notes[notes != ""].empty:
        df["Notes"] = notes

    df = _reflect_callstack_in_name(df, td_key_cols)

    display_names = {
        "text": "NVTX Range",
        "instances": "Count",
        "instDuration": "Instance Duration",
        "report": "Report",
    }
    df.rename(columns=display_names, inplace=True)

    shown_columns = ["NVTX Range", "Count", "Instance Duration", "Report"]
    if "Notes" in df.columns:
        shown_columns.append("Notes")

    return df[shown_columns]


def create_nvtx_summary(nvtx_summary_dfs, td_key_cols):
    """
    Build the final, displayable NVTX summary from the NVTX summary
    data frames of the individual NSys reports.
    """
    summary_df = _compute_nvtx_summary(nvtx_summary_dfs, td_key_cols)
    display_df = _prepare_nvtx_summary_for_display(summary_df, td_key_cols)
    return display_df


def get_only_main_thread_nvtx_warning():
    """
    Return the warning text about processing main thread NVTX ranges only.
    """
    message = "Only NVTX ranges from the main thread are processed in the recipe."
    return message


def get_all_nvtx_too_short_warning():
    """
    Return the warning text for the case when every NVTX range is filtered
    out because of the `_MIN_CPU_SAMPLE_COUNT_THRESHOLD` limit.
    """
    min_samples = _MIN_CPU_SAMPLE_COUNT_THRESHOLD
    message = (
        "All NVTX ranges are filtered out from processing since "
        f"they contain fewer than {min_samples} PMU samples "
        "in at least one report."
    )
    return message


def log_details(
    details_log_path,
    filename,
    nvtx_df=None,
    nvtx_selected_df=None,
    nvtx_summary_df=None,
    cpu_metrics_df=None,
):
    """
    Log details on NVTX ranges with CPU events and metrics
    to the provided details log path.

    The entry is appended to the file: a "res: <filename>" line followed by
    one titled section for every data frame that is not None, and two
    trailing newlines.
    """
    sections = (
        (
            "\nNVTX Ranges or "
            "Aggregated NVTX Ranges (--aggregate-parallel-nvtx-ranges):\n",
            nvtx_df,
        ),
        ("\nNVTX / NVTX Aggregated Ranges selected from the report:\n", nvtx_selected_df),
        ("\nNVTX Summary:\n", nvtx_summary_df),
        ("\nCPU Metrics:\n", cpu_metrics_df),
    )

    # The encoding is set explicitly: the logged data frames may contain
    # non-ASCII characters (e.g. the warning and branch signs of the NVTX
    # summary), which cannot be written with every platform default encoding.
    with open(details_log_path, "a", encoding="utf-8") as f:
        f.write(f"res: {filename}\n")
        for title, section_df in sections:
            if section_df is not None:
                f.write(title)
                f.write(section_df.to_string())
        f.write("\n\n")
