#!/usr/bin/python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import sys, os
from subprocess import PIPE, Popen
import logging
import traceback
from os.path import exists, join
from signal import alarm, signal, SIGALRM, SIGKILL


class CloudRuntimeException(Exception):
    """Error raised when a setup step or shell command cannot be completed.

    The message is kept on ``errMsg`` and is also forwarded to ``Exception``
    so that ``args`` and ``repr()`` carry it as well.

    Args:
        errMsg: Human-readable description of the failure.
    """

    def __init__(self, errMsg):
        # Forward the message to the base class; without this ``args`` stays
        # empty and the exception cannot be re-created from its args.
        super().__init__(errMsg)
        self.errMsg = errMsg

    def __str__(self):
        # __str__ must return a str, so coerce in case a non-string was given.
        return str(self.errMsg)


def formatExceptionInfo(maxTBlevel=5):
    """Describe the exception currently being handled as one string.

    The first line is ``str()`` of the active exception value; it is followed
    by at most ``maxTBlevel`` formatted traceback entries.

    Args:
        maxTBlevel: Limit passed to ``traceback.format_tb``.

    Returns:
        The exception text and traceback entries joined into a single string.
    """
    _, value, tb = sys.exc_info()
    frames = traceback.format_tb(tb, maxTBlevel)
    return "".join([str(value) + "\n"] + frames)


class bash:
    """Run a shell command synchronously and capture its result.

    The command is executed as soon as the object is constructed.  Afterwards
    ``isSuccess()`` tells whether it exited with status 0 and
    ``getStdout()`` / ``getLines()`` / ``getStderr()`` expose the captured
    output as text.  A non-zero exit status does not raise; callers are
    expected to check ``isSuccess()``.

    Args:
        args: Command line handed to the shell (``shell=True``).
        timeout: Seconds after which the command is killed; ``-1`` disables
            the timeout.  The timeout is implemented with ``SIGALRM``.

    Raises:
        CloudRuntimeException: if the command times out or cannot be run.
    """

    def __init__(self, args, timeout=600):
        self.args = args
        logging.debug("execute:%s" % args)
        self.timeout = timeout
        self.process = None
        self.success = False
        # Defined up front so the accessors never hit a missing attribute.
        self.stdout = ""
        self.stderr = ""
        self.run()

    def run(self):
        """Execute ``self.args`` and record stdout, stderr and success."""

        class Alarm(Exception):
            pass

        def alarm_handler(signum, frame):
            raise Alarm

        useAlarm = self.timeout != -1
        try:
            # universal_newlines=True makes communicate() return str instead
            # of bytes, which the str-based accessors below rely on.
            self.process = Popen(
                self.args,
                shell=True,
                stdout=PIPE,
                stderr=PIPE,
                universal_newlines=True,
            )
            if useAlarm:
                signal(SIGALRM, alarm_handler)
                alarm(self.timeout)

            try:
                self.stdout, self.stderr = self.process.communicate()
            except Alarm:
                os.kill(self.process.pid, SIGKILL)
                # Reap the killed shell so it does not linger as a zombie.
                self.process.wait()
                raise CloudRuntimeException("Timeout during command execution")
            finally:
                if useAlarm:
                    # Cancel the pending alarm on every exit path.
                    alarm(0)

            self.success = self.process.returncode == 0
        except CloudRuntimeException:
            # Already a meaningful error (the timeout above); do not re-wrap.
            raise
        except Exception:
            raise CloudRuntimeException(formatExceptionInfo())

    def isSuccess(self):
        """Return True if the command exited with status 0."""
        return self.success

    def getStdout(self):
        """Return captured stdout with leading/trailing newlines stripped."""
        return self.stdout.strip("\n")

    def getLines(self):
        """Return captured stdout split on newline characters."""
        return self.stdout.split("\n")

    def getStderr(self):
        """Return captured stderr with leading/trailing newlines stripped."""
        return self.stderr.strip("\n")


def initLoging(logFile=None):
    """Configure the root logger at DEBUG level.

    Logging setup is best-effort: when the log file cannot be used, the
    function falls back to ``basicConfig``'s default console handler and
    records a warning instead of failing.

    Args:
        logFile: Path of the file to log to, or None to log to the console.
    """
    if logFile is None:
        logging.basicConfig(level=logging.DEBUG)
        return

    try:
        logging.basicConfig(filename=logFile, level=logging.DEBUG)
    except Exception as err:
        # Only ordinary errors are absorbed here; KeyboardInterrupt and
        # SystemExit are deliberately allowed to propagate.
        logging.basicConfig(level=logging.DEBUG)
        logging.warning(
            "Cannot write log file %s (%s); logging to the console instead",
            logFile,
            err,
        )


def writeProgressBar(msg, result=None):
    """Write one half of a progress line to stdout and flush it.

    Called with a message, it writes the message left-justified in an
    80-column field without a trailing newline.  Called with ``msg=None`` it
    finishes the line with a coloured ``[ OK ]`` tag (``result is True``) or
    ``[ FAILED ]`` tag (``result is False``).

    Args:
        msg: Text for the left-hand side of the line, or None to write the
            outcome tag instead.
        result: True or False to select the tag; only consulted when ``msg``
            is None.  Any other value writes nothing.
    """
    if msg is not None:
        output = "%-80s" % msg
    elif result is True:
        output = "[ \033[92m%-2s\033[0m ]\n" % "OK"
    elif result is False:
        output = "[ \033[91m%-6s\033[0m ]\n" % "FAILED"
    else:
        # Neither a message nor a boolean outcome: nothing to display.
        return
    sys.stdout.write(output)
    sys.stdout.flush()


def printError(msg):
    """Write ``msg`` followed by a newline to stderr, then flush the stream."""
    stream = sys.stderr
    stream.write(msg + "\n")
    stream.flush()


def printMsg(msg):
    """Write ``msg`` followed by a newline to stdout, then flush the stream."""
    line = msg + "\n"
    out = sys.stdout
    out.write(line)
    out.flush()


def checkRpm(pkgName):
    """Report whether the RPM package ``pkgName`` is installed.

    Runs ``rpm -q`` for the package, shows a progress line with the outcome
    and, when the query fails, prints an installation hint via printError.

    Args:
        pkgName: Name of the package to query.

    Returns:
        True if ``rpm -q`` succeeded, False otherwise.
    """
    query = bash("rpm -q %s" % pkgName)
    installed = query.isSuccess()
    writeProgressBar("Checking %s" % pkgName, None)

    if installed:
        writeProgressBar(None, True)
        return True

    writeProgressBar(None, False)
    hint = "%s is not found, please make sure it is installed. You may try 'yum install %s'\n"
    printError(hint % (pkgName, pkgName))
    return False


def checkEnv():
    """Verify the prerequisites for the setup.

    First checks that ``whoami`` prints ``root``, then checks each required
    RPM package in turn, stopping at the first one that is missing.

    Returns:
        True if every check passed, False as soon as one fails.
    """
    writeProgressBar("Checking is root")
    currentUser = bash("whoami").getStdout()
    if currentUser != "root":
        writeProgressBar(None, False)
        printError("This script must run as root")
        return False
    writeProgressBar(None, True)

    requiredPackages = ("tftp-server", "syslinux", "xinetd", "chkconfig", "dhcp")
    # all() stops at the first failing package, like an early return would.
    return all(checkRpm(pkg) for pkg in requiredPackages)


def exitIfFail(ret):
    """Terminate the script with exit status 1 when ``ret`` is falsy."""
    if ret:
        return
    sys.exit(1)


def bashWithResult(cmd):
    """Run ``cmd`` through the shell and show a progress line for it.

    Args:
        cmd: Shell command line to execute.

    Returns:
        True if the command exited with status 0.  Otherwise the progress
        line is marked FAILED, the command's stderr is printed and False is
        returned.
    """
    writeProgressBar("Executing '%s'" % cmd)
    ret = bash(cmd)
    if ret.isSuccess():
        writeProgressBar(None, True)
        return True

    writeProgressBar(None, False)
    # Report through printError rather than writeProgressBar: the latter pads
    # its argument to 80 columns, which left trailing blanks after the
    # newline and shifted whatever was printed next.
    printError(ret.getStderr())
    return False


def configurePxeStuff():
    """Enable the PXE-related services and open the TFTP port if needed.

    Runs ``chkconfig --level 345 <service> on`` for tftp, xinetd and dhcpd,
    restarts xinetd, and then, if the output of
    ``chkconfig --list iptables`` contains "on", inserts an iptables rule
    accepting UDP port 69 and saves the rules.

    Returns:
        True if every command succeeded, False at the first failure.
    """
    enableTemplate = "chkconfig --level 345 %s on"
    commands = [enableTemplate % svc for svc in ("tftp", "xinetd", "dhcpd")]
    commands.append("/etc/init.d/xinetd restart")
    for command in commands:
        if not bashWithResult(command):
            return False

    iptablesState = bash("chkconfig --list iptables").getStdout()
    if "on" not in iptablesState:
        return True

    printMsg("Detected iptables is running, need to open tftp port 69")
    firewallCommands = (
        "iptables -I INPUT 1 -p udp --dport 69 -j ACCEPT",
        "/etc/init.d/iptables save",
    )
    # all() short-circuits, so the save is skipped when the insert fails.
    return all(bashWithResult(command) for command in firewallCommands)


def getTftpRootDir(tftpRootDirList):
    """Look up the TFTP root directory from /etc/xinetd.d/tftp.

    The ``server_args`` text is grepped out of the file, and everything from
    its first "/" onward is appended to ``tftpRootDirList``.

    Args:
        tftpRootDirList: List that receives the directory; used as the
            output parameter of this function.

    Returns:
        True if a directory was appended, False (after printing an error)
        if the grep failed or its output contains no "/".
    """
    grepResult = bash("cat /etc/xinetd.d/tftp | grep server_args")
    if not grepResult.isSuccess():
        printError(
            "Cannot get tftp root directory from /etc/xinetd.d/tftp, here may be something wrong with your tftp-server, try reinstall it\n"
        )
        return False

    serverArgs = grepResult.getStdout()
    # Split at the first "/"; an empty separator means no path was present.
    _, slash, remainder = serverArgs.partition("/")
    if not slash:
        printError("Wrong server_arg in /etc/xinetd.d/tftp (%s)" % serverArgs)
        return False

    tftpRootDirList.append(slash + remainder)
    return True


def preparePING(tftpRootDir):
    """Copy the PING boot files into the TFTP root directory.

    Each required file is looked up under /usr/share/PING and copied with
    ``cp -f``; afterwards the ``pxelinux.cfg`` directory is created inside
    the TFTP root with ``mkdir -p``.

    Args:
        tftpRootDir: Destination directory for the files.

    Returns:
        True if all files were copied and the directory was created, False
        at the first missing file or failing command.
    """
    pingDir = "/usr/share/PING"
    for name in ("boot.msg", "initrd.gz", "kernel", "pxelinux.0"):
        source = join(pingDir, name)
        if not exists(source):
            printError(
                "Cannot find %s, please make sure PING-3.01 is installed" % source
            )
            return False
        copied = bashWithResult("cp -f %s %s" % (source, tftpRootDir))
        if not copied:
            return False

    return bashWithResult("mkdir -p %s/pxelinux.cfg" % tftpRootDir)


if __name__ == "__main__":
    # Script entry point: run each setup step in order and stop with exit
    # status 1 as soon as one of them reports failure.
    initLoging("/tmp/cloud-setup-baremetal.log")

    exitIfFail(checkEnv())
    exitIfFail(configurePxeStuff())

    # getTftpRootDir hands back its result by appending to the list given.
    tftpRootDirList = []
    exitIfFail(getTftpRootDir(tftpRootDirList))
    tftpRootDir = tftpRootDirList[0].strip()

    exitIfFail(preparePING(tftpRootDir))

    for line in (
        "",
        "Setup BareMetal PXE server successfully",
        "TFTP root directory is: %s\n" % tftpRootDir,
    ):
        printMsg(line)
    sys.exit(0)
