#!/usr/bin/env python
import functools
import logging
import math
import re
import sys
import time
from dataclasses import dataclass, field
from enum import Enum, auto

# Module-level logger; used for conversion warnings and timing messages below.
logger = logging.getLogger(__name__)


@dataclass
class Frames:
    """Plain ``start``/``end`` interval record (units chosen by the producer)."""

    start: float = field(default=0)
    end: float = field(default=0)


@dataclass
class FrameDurations:
    """Duration record beginning at ``start``, plus optional aggregate metrics."""

    start: float = field(default=0)
    duration: float = field(default=0)
    # Optional metric aggregates; left at 0 when not populated.
    min_value: float = field(default=0)
    max_value: float = field(default=0)
    total_value: float = field(default=0)

@dataclass
class TimeSlice:
    """One scheduling time slice derived from OS scheduler events."""

    start: float = field(default=0)
    end: float = field(default=0)
    cpu: int = field(default=0)
    gtid: int = field(default=0)  # global thread id of the scheduled thread


@dataclass
class CallStackFrame:
    """A single frame of a call stack.

    ``-1`` is the default for every field; presumably it marks "unknown".
    ``function`` and ``module`` are ints, i.e. ids rather than names
    (TODO confirm what table they index into).
    """

    function: int = field(default=-1)
    module: int = field(default=-1)
    depth: int = field(default=-1)


class CallStackType(Enum):
    """Kind of captured call stack: ``SAMPLED``, ``CSWITCH`` or ``EVENT``."""

    # Explicit values match what ``auto()`` assigns (1, 2, 3 in order).
    SAMPLED = 1
    CSWITCH = 2
    EVENT = 3


@dataclass
class CallStack:
    """Call stack record: owning thread/process, capture time, kind and frames."""

    id: int = field(default=-1)
    tid: int = field(default=-1)
    pid: int = field(default=-1)
    time: int = field(default=-1)
    stack_type: CallStackType = field(default=CallStackType.EVENT)
    # Each instance gets its own list; presumably of CallStackFrame -- confirm.
    stack: list = field(default_factory=list)


@dataclass
class GPUMetric:
    """One GPU metric sample: a ``timestamp`` and its integer ``value``."""

    timestamp: int = field(default=0)
    value: int = field(default=0)


@dataclass
class SourceMetaInfo:
    """Provenance metadata for a report; only ``report_name`` is required.

    Numeric ids default to ``-1`` and string fields to ``''`` when unknown.
    """

    report_name: str
    gtl_dq_job_id: int = field(default=-1)
    gtl_application_id: int = field(default=-1)
    gtl_application_name: str = field(default='')
    fc_job_id: str = field(default='')
    fc_task_id: str = field(default='')
    report_resolved_nsys_id: str = field(default='')
    report_resolved_nsys_path: str = field(default='')
    report_source_nsys_id: str = field(default='')
    report_source_nsys_path: str = field(default='')
    report_source_sqlite_id: str = field(default='')
    report_source_sqlite_path: str = field(default='')


@dataclass
class CPUConfig:
    """Hybrid CPU layout: physical/logical P-core and E-core counts."""

    physical_p_core_count: int = field(default=0)
    logical_p_core_count: int = field(default=0)
    physical_e_core_count: int = field(default=0)
    logical_e_core_count: int = field(default=0)
    # Index of the first P core; non-zero when E cores are numbered first.
    p_core_starting_index: int = field(default=0)


def progressbar(it, prefix="", size=60, quiet=False, file=sys.stdout):
    """Yield every item of *it* while drawing a text progress bar.

    Adapted from https://stackoverflow.com/questions/3160699/python-progress-bar

    Args:
        it: Iterable to walk; must support ``len()`` unless *quiet* is true.
        prefix: Text written in front of the bar.
        size: Width of the bar in characters.
        quiet: When true, yield the items without writing anything.
        file: Text stream the bar is written to.

    Yields:
        The items of *it*, in order.
    """
    if quiet:
        yield from it
        return

    count = len(it)
    # Redraw about every half bar-cell. Clamp to 1: for short inputs
    # (count < 2 * size) the raw value is 0 and ``i % 0`` would raise
    # ZeroDivisionError on the first item.
    redraw_every = max(1, int(count / size / 2))

    def show(done):
        # Nothing sensible to draw for an empty input (avoids dividing by 0).
        if count == 0:
            return
        filled = int(size * done / count)
        file.write(f'{prefix}[{"#" * filled}{"." * (size - filled)}] {done}/{count}\r')
        file.flush()

    show(0)
    for i, item in enumerate(it):
        yield item
        if i % redraw_every == 0:
            show(i + 1)
    show(count)
    file.write("\n")
    file.flush()


def convert_global_tid(gtid: int):
    """Split a global thread id into ``(pid, tid)``.

    The tid is the low 24 bits of *gtid*; the pid is the next 24 bits
    (bits 24-47). Any higher bits are discarded.

    Returns:
        ``(pid, tid)`` tuple.
    """
    # Floor division keeps this in exact integer arithmetic. True division
    # goes through a float, which rounds once the quotient needs more than
    # 53 bits of precision (e.g. when bits above 48 are set) and can push
    # the pid up by one.
    pid = int((gtid // 0x1000000) % 0x1000000)
    tid = gtid % 0x1000000
    return pid, tid


def get_pid(gtid: int):
    """Return the pid stored in bits 24-47 of the global thread id *gtid*."""
    # Integer floor division avoids the float rounding that true division
    # suffers for large gtids (which could make the pid off by one).
    return int((gtid // 0x1000000) % 0x1000000)


def get_tid(gtid: int):
    """Return the thread id held in the low 24 bits of *gtid*."""
    _, tid = divmod(gtid, 1 << 24)
    return tid


def get_gtid(pid: int, tid: int):
    """Pack *pid* (shifted above the low 24 bits) and *tid* into one global id."""
    pid_field = pid * (1 << 24)
    return pid_field + tid


def compare_gtid(gtid1: int, gtid2: int) -> bool:
    """Return True when two global thread ids share the same pid and tid.

    Only the low 48 bits (pid in bits 24-47, tid in bits 0-23) take part;
    anything above bit 47 is ignored.
    """
    low_48_bits = (1 << 48) - 1
    # Bits that differ between the ids; equal pid/tid means none in the mask.
    return ((gtid1 ^ gtid2) & low_48_bits) == 0


def get_PE_core_counts_from_filename(nsys_report_filename: str) -> CPUConfig:
    """Infer the P-core/E-core layout encoded in a report filename.

    Patterns are tried in this order; the first match wins:

    * ``[P_E_Freq.Freq_SMT]`` (e.g. ``[8_8_3.2_On]``) or ``[P_E_SMT]``
      (e.g. ``[8_8_Off]``). With SMT ``On`` each P core counts as two
      logical cores; with ``Off`` logical equals physical.
    * ``<P>C<T>T<E>E`` (e.g. ``8C16T8E``): ``T`` is the logical P-core count.
    * ``<E>E<P>P`` (e.g. ``8E8P``): E cores are listed first, so the P cores
      start at index ``E``.
    * ``<P>P<E>E`` (e.g. ``8P8E``): P cores start at index 0.

    E cores always count as one logical core each.

    Args:
        nsys_report_filename: File name (or path) to search.

    Returns:
        A new ``CPUConfig`` instance, or ``None`` when no pattern matches.
    """
    # Build a fresh instance on every call. Binding the class itself
    # (``CPUConfig`` without parentheses) would write class attributes,
    # silently changing the defaults of every CPUConfig and leaking values
    # from one call into the next.
    cpu_config = CPUConfig()

    # P_E_Freq.Freq_SMT[Off|On]
    # NOTE(review): the '.' between the frequency digits is unescaped, so
    # any single character is accepted there; kept as-is for compatibility.
    matches = re.search(r"\[(\d+)_(\d+)_\d+.\d+_(Off|On)]", nsys_report_filename)
    if not matches:
        matches = re.search(r"\[(\d+)_(\d+)_(Off|On)]", nsys_report_filename)

    if matches:
        cpu_config.physical_p_core_count = int(matches.group(1))
        cpu_config.physical_e_core_count = int(matches.group(2))
        cpu_config.logical_e_core_count = cpu_config.physical_e_core_count
        # SMT doubles the logical P-core count; without SMT it equals the
        # physical count (previously left unset for "Off").
        smt_factor = 2 if matches.group(3) == "On" else 1
        cpu_config.logical_p_core_count = cpu_config.physical_p_core_count * smt_factor
        return cpu_config

    # 8C16T8E
    matches = re.search(r"(\d+)C(\d+)T(\d+)E", nsys_report_filename)
    if matches:
        cpu_config.physical_p_core_count = int(matches.group(1))
        cpu_config.logical_p_core_count = int(matches.group(2))
        cpu_config.physical_e_core_count = int(matches.group(3))
        cpu_config.logical_e_core_count = cpu_config.physical_e_core_count
        cpu_config.p_core_starting_index = 0
        return cpu_config

    # 8E8P
    matches = re.search(r"(\d+)E(\d+)P", nsys_report_filename)
    if matches:
        cpu_config.physical_p_core_count = int(matches.group(2))
        cpu_config.logical_p_core_count = cpu_config.physical_p_core_count
        cpu_config.physical_e_core_count = int(matches.group(1))
        cpu_config.logical_e_core_count = cpu_config.physical_e_core_count
        # E cores come first in this naming, so P cores are numbered after them.
        cpu_config.p_core_starting_index = cpu_config.physical_e_core_count
        return cpu_config

    # 8P8E
    matches = re.search(r"(\d+)P(\d+)E", nsys_report_filename)
    if matches:
        cpu_config.physical_p_core_count = int(matches.group(1))
        cpu_config.logical_p_core_count = cpu_config.physical_p_core_count
        cpu_config.physical_e_core_count = int(matches.group(2))
        cpu_config.logical_e_core_count = cpu_config.physical_e_core_count
        cpu_config.p_core_starting_index = 0
        return cpu_config

    return None


def safe_type(val, to_type, default=None):
    """Convert *val* with ``to_type(val)``, falling back to *default*.

    Args:
        val: Value to convert. ``None`` short-circuits to *default*.
        to_type: Callable performing the conversion (e.g. ``int``, ``float``).
        default: Returned when *val* is ``None`` or the conversion fails.

    Returns:
        ``to_type(val)`` on success, otherwise *default*. Failed conversions
        are logged as warnings.
    """
    if val is None:
        return default
    try:
        return to_type(val)
    except (ValueError, TypeError, OverflowError) as e:
        # OverflowError covers cases like int(float('inf')), which is neither
        # a ValueError nor a TypeError and would otherwise escape.
        # Not every callable has __name__ (e.g. functools.partial); fall back
        # to repr so building the log message cannot raise here.
        type_name = getattr(to_type, '__name__', repr(to_type))
        logger.warning("Failed to convert %s to %s; %s", val, type_name, e)
        return default


def safe_float(val, default=None, check_bool=False):
    """Return *val* as a non-NaN float, or *default* when that is impossible.

    Args:
        val: Value to convert; ``None`` yields *default* immediately.
        default: Fallback for ``None``, unconvertible or NaN values.
        check_bool: When true and *val* is a string, map it to ``1.0`` if its
            lowercase form is ``'true'``, ``'1'`` or ``'t'``, else ``0.0``.
    """
    if val is None:
        return default
    candidate = val
    if check_bool and isinstance(candidate, str):
        candidate = candidate.lower() in ('true', '1', 't')
    converted = safe_type(candidate, float, default)
    # Reject anything that did not end up as a real float, and NaN.
    if not isinstance(converted, float) or math.isnan(converted):
        return default
    return converted


def get_metric_unit(metric_name: str) -> str:
    """Guess a display unit from keywords in *metric_name* (case-insensitive).

    Keywords are checked in priority order: ``frequency`` -> ``'MHz'``,
    ``throughput`` -> ``'pct'``, ``avg`` -> ``'avg'``; otherwise ``'float'``.
    """
    lowered = metric_name.lower()
    keyword_units = (('frequency', 'MHz'), ('throughput', 'pct'), ('avg', 'avg'))
    for keyword, unit in keyword_units:
        if keyword in lowered:
            return unit
    return 'float'


def safe_list_get(input_list: list, idx: int, default: object = None):
    """Return ``input_list[idx]``, or *default* if the index is out of range.

    Negative indices follow normal Python semantics (``-1`` is the last
    item). An out-of-range index, in either direction, or a ``None`` list
    yields *default* instead of raising.
    """
    if input_list is None:
        return default
    # Bound both sides: the old ``len > idx`` test let any negative index
    # through, so e.g. ``safe_list_get([], -1)`` raised IndexError.
    if -len(input_list) <= idx < len(input_list):
        return input_list[idx]
    return default


def timeit(method):
    """Decorator that logs, at DEBUG level, how long each call to *method* takes.

    The returned wrapper preserves the wrapped function's ``__name__``,
    ``__doc__`` and other metadata (via ``functools.wraps``) and returns
    whatever *method* returns.
    """
    @functools.wraps(method)
    def timed(*args, **kw):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump when the system clock is adjusted.
        ts = time.perf_counter()
        result = method(*args, **kw)
        te = time.perf_counter()

        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug('%r took %2.2f sec', method.__name__, te - ts)
        return result

    return timed
