#!/usr/bin/python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import sys, os
from subprocess import PIPE, Popen
import logging
import traceback
from os.path import exists, join
from signal import alarm, signal, SIGALRM, SIGKILL

class CloudRuntimeException(Exception):
    """Error whose string form is exactly the message it was created with.

    The message is also kept on the instance as ``errMsg``.
    """

    def __init__(self, errMsg):
        super().__init__(errMsg)
        self.errMsg = errMsg

    def __str__(self):
        # Return the stored message verbatim (no class name, no quoting).
        return self.errMsg


def formatExceptionInfo(maxTBlevel=5):
    """Describe the exception currently being handled.

    :param maxTBlevel: maximum number of traceback entries to include.
    :return: ``str(exception)``, a newline, then the formatted traceback
        entries (as produced by ``traceback.format_tb``) concatenated.
    """
    _, excValue, excTraceback = sys.exc_info()
    frames = traceback.format_tb(excTraceback, maxTBlevel)
    return str(excValue) + "\n" + "".join(frames)

class bash:
    """Run a shell command synchronously and capture its output.

    The command is executed from the constructor (via run()), so creating an
    instance blocks until the command finishes or the timeout expires.

    :param args: command line, run through the shell (``shell=True``); do not
        pass untrusted text.
    :param timeout: seconds after which the command is killed with SIGKILL and
        CloudRuntimeException is raised; -1 disables the timeout. The timeout
        is implemented with SIGALRM, and ``signal.signal`` may only be called
        from the main thread.
    """
    def __init__(self, args, timeout=600):
        self.args = args
        logging.debug("execute:%s"%args)
        self.timeout = timeout
        self.process = None
        self.success = False
        # Defined up front so the accessors never hit a missing attribute.
        self.stdout = ""
        self.stderr = ""
        self.run()

    def run(self):
        """Execute self.args and record stdout, stderr and success.

        :raises CloudRuntimeException: on timeout, or wrapping (with a
            traceback) any other error raised while starting or talking to
            the process.
        """
        class Alarm(Exception):
            pass

        def alarm_handler(signum, frame):
            raise Alarm

        useAlarm = self.timeout != -1
        oldHandler = None
        try:
            self.process = Popen(self.args, shell=True, stdout=PIPE, stderr=PIPE)
            if useAlarm:
                oldHandler = signal(SIGALRM, alarm_handler)
                alarm(self.timeout)

            try:
                out, err = self.process.communicate()
            except Alarm:
                os.kill(self.process.pid, SIGKILL)
                # Reap the killed shell so it does not linger as a zombie.
                self.process.wait()
                raise CloudRuntimeException("Timeout during command execution")
            finally:
                # Always cancel the pending alarm and put the previous handler
                # back, even if communicate() failed for another reason.
                if useAlarm:
                    alarm(0)
                    if oldHandler is not None:
                        signal(SIGALRM, oldHandler)

            # Popen pipes yield bytes on Python 3; decode so the str-based
            # accessors below work. Undecodable bytes are replaced, not fatal.
            self.stdout = out.decode("utf-8", "replace")
            self.stderr = err.decode("utf-8", "replace")
            self.success = self.process.returncode == 0
        except CloudRuntimeException:
            # Already descriptive (e.g. the timeout); do not wrap it again.
            raise
        except Exception:
            raise CloudRuntimeException(formatExceptionInfo())

    def isSuccess(self):
        """Return True if the command exited with status 0."""
        return self.success

    def getStdout(self):
        """Return captured stdout with leading/trailing newlines stripped."""
        return self.stdout.strip("\n")

    def getLines(self):
        """Return captured stdout split on "\\n".

        A trailing newline produces a final empty string.
        """
        return self.stdout.split("\n")

    def getStderr(self):
        """Return captured stderr with leading/trailing newlines stripped."""
        return self.stderr.strip("\n")


def initLoging(logFile=None):
    """Configure root logging at DEBUG level, writing to logFile if given.

    If configuring the file handler raises (for example because the file
    cannot be opened), fall back to the default console handler so logging
    setup never aborts the script.

    :param logFile: path of the log file, or None to log to the console.
    """
    try:
        if logFile is None:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(filename=logFile, level=logging.DEBUG)
    except Exception:
        # Deliberate best-effort fallback; Exception (not a bare except) so
        # KeyboardInterrupt/SystemExit are not swallowed.
        logging.basicConfig(level=logging.DEBUG)

def writeProgressBar(msg, result=None):
    """Write one part of a "label ... [ OK ]" progress line to stdout.

    Call once with a label, which is left-padded to 80 columns and written
    without a newline. Then call again with msg=None and the outcome to close
    the line with a coloured status tag.

    :param msg: label text, or None when reporting a result.
    :param result: True for a green "OK", False for a red "FAILED"; only
        consulted when msg is None.
    :raises ValueError: if msg is None and result is neither True nor False.
    """
    if msg is not None:
        output = "%-80s"%msg
    elif result is True:
        output = "[ \033[92m%-2s\033[0m ]\n"%"OK"
    elif result is False:
        output = "[ \033[91m%-6s\033[0m ]\n"%"FAILED"
    else:
        # Previously this fell through to an UnboundLocalError on `output`.
        raise ValueError("writeProgressBar needs a message or a True/False result")
    sys.stdout.write(output)
    sys.stdout.flush()
    
def printError(msg):
    """Write msg followed by a newline to stderr and flush immediately."""
    err = sys.stderr
    for chunk in (msg, "\n"):
        err.write(chunk)
    err.flush()

def printMsg(msg):
    """Write msg plus a trailing newline to stdout and flush it."""
    line = msg + "\n"
    out = sys.stdout
    out.write(line)
    out.flush()

def checkRpm(pkgName):
    """Report whether RPM package pkgName is installed, with a progress line.

    :param pkgName: package name passed to ``rpm -q``.
    :return: True if ``rpm -q`` exits successfully; otherwise False, after
        printing an install hint to stderr.
    """
    # Show the label before querying, as checkEnv and bashWithResult do, so
    # the user sees what is being checked while rpm runs.
    writeProgressBar("Checking %s"%pkgName, None)
    chkPkg = bash("rpm -q %s"%pkgName)
    if not chkPkg.isSuccess():
        writeProgressBar(None, False)
        printError("%s is not found, please make sure it is installed. You may try 'yum install %s'\n"%(pkgName, pkgName))
        return False
    writeProgressBar(None, True)
    return True
      
def checkEnv():
    """Verify the script runs as root and the required RPMs are installed.

    :return: True if all checks pass; False at the first failing check
        (each check prints its own progress/error output).
    """
    writeProgressBar("Checking is root")
    # Effective UID 0 is root regardless of the account's name, and checking
    # it needs no subprocess (the old version compared `whoami` output).
    if os.geteuid() != 0:
        writeProgressBar(None, False)
        printError("This script must run as root")
        return False
    writeProgressBar(None, True)

    pkgList = ['tftp-server', 'syslinux', 'xinetd', 'chkconfig', 'dhcp']
    for pkg in pkgList:
        if not checkRpm(pkg):
            return False
    return True

def exitIfFail(ret):
    """Terminate the process with exit status 1 when ret is falsy."""
    if ret:
        return
    sys.exit(1)
    
def bashWithResult(cmd):
    """Run cmd under a progress line and report OK or FAILED.

    On failure the command's stderr is echoed through writeProgressBar.

    :return: True if the command exited with status 0, else False.
    """
    writeProgressBar("Executing '%s'"%cmd)
    outcome = bash(cmd)
    if outcome.isSuccess():
        writeProgressBar(None, True)
        return True

    writeProgressBar(None, False)
    writeProgressBar(outcome.getStderr() + '\n')
    return False
    
def configurePxeStuff():
    """Enable the PXE-related services and open TFTP in iptables if needed.

    Runs ``chkconfig --level 345 <svc> on`` for tftp, xinetd and dhcpd, then
    restarts xinetd. If the output of ``chkconfig --list iptables`` contains
    "on", it inserts an ACCEPT rule for UDP port 69 and saves the rules.

    :return: False as soon as any command fails, True otherwise.
    """
    commands = []
    for service in ('tftp', 'xinetd', 'dhcpd'):
        commands.append('chkconfig --level 345 %s on' % service)
    commands.append('/etc/init.d/xinetd restart')

    # all() short-circuits, so execution stops at the first failing command.
    if not all(bashWithResult(c) for c in commands):
        return False

    iptablesState = bash('chkconfig --list iptables').getStdout()
    # NOTE(review): plain substring test; presumably this reflects
    # chkconfig's boot-time setting rather than whether iptables runs now.
    if 'on' not in iptablesState:
        return True

    printMsg("Detected iptables is running, need to open tftp port 69")
    for cmd in ('iptables -I INPUT 1 -p udp --dport 69 -j ACCEPT',
                '/etc/init.d/iptables save'):
        if not bashWithResult(cmd):
            return False
    return True
    
def getTftpRootDir(tftpRootDirList, tftpConfFile="/etc/xinetd.d/tftp"):
    """Find the TFTP root directory configured for xinetd's tftp service.

    Reads the first non-comment ``server_args`` line of the xinetd tftp
    service file. The root directory is the whitespace-delimited token that
    starts at the first "/" on that line.

    :param tftpRootDirList: list the directory is appended to (out-parameter
        kept for compatibility with existing callers).
    :param tftpConfFile: path of the xinetd tftp service file.
    :return: True on success; False (after printing an error) otherwise.
    """
    serverArgs = None
    try:
        with open(tftpConfFile, errors="replace") as conf:
            for line in conf:
                stripped = line.strip()
                # Skip commented-out settings. The old `grep server_args`
                # picked them up too and glued several matches together.
                if stripped.startswith("#") or "server_args" not in stripped:
                    continue
                serverArgs = stripped
                break
    except (IOError, OSError):
        serverArgs = None

    if serverArgs is None:
        printError("Cannot get tftp root directory from %s, here may be something wrong with your tftp-server, try reinstall it\n" % tftpConfFile)
        return False

    index = serverArgs.find("/")
    if index == -1:
        printError("Wrong server_arg in %s (%s)" % (tftpConfFile, serverArgs))
        return False

    # Keep only the path token so trailing server options are not included.
    tftpRootDir = serverArgs[index:].split()[0]
    tftpRootDirList.append(tftpRootDir)
    return True

def preparePING(tftpRootDir):
    """Copy the PING boot files into the TFTP root and create pxelinux.cfg.

    Each required file must exist under /usr/share/PING. A missing file or a
    failing command aborts with False.

    :param tftpRootDir: destination TFTP root directory.
    :return: True if every file was copied and the directory created.
    """
    pingDir = "/usr/share/PING"
    for name in ('boot.msg', 'initrd.gz', 'kernel', 'pxelinux.0'):
        src = join(pingDir, name)
        if not exists(src):
            printError("Cannot find %s, please make sure PING-3.01 is installed"%src)
            return False
        if not bashWithResult("cp -f %s %s"%(src, tftpRootDir)):
            return False

    # bashWithResult already returns exactly True/False.
    return bashWithResult("mkdir -p %s/pxelinux.cfg"%tftpRootDir)
            
        
if __name__ == "__main__":
    initLoging("/tmp/cloud-setup-baremetal.log")

    # Each step prints its own errors; stop with status 1 at the first failure.
    for step in (checkEnv, configurePxeStuff):
        exitIfFail(step())

    rootDirs = []
    exitIfFail(getTftpRootDir(rootDirs))
    tftpRootDir = rootDirs[0].strip()

    exitIfFail(preparePING(tftpRootDir))
    printMsg("")
    printMsg("Setup BareMetal PXE server successfully")
    printMsg("TFTP root directory is: %s\n"%tftpRootDir)
    sys.exit(0)
    
