#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
'''
Get direction of maximum polarizability from DFTB+ TD data
'''

import numpy as np
import argparse
from scipy import linalg, constants
import sys
# Planck constant converted from eV s to eV fs (1 s = 1.0E15 fs).
_PLANCK_EV_S = constants.physical_constants['Planck constant in eV s'][0]
hplanck = _PLANCK_EV_S * 1.0E15

# Help text handed to argparse as the program description.
DESCRIPTION = (
    "Reads output from TD calculation (after kick) and excitation energy\n"
    "(or wavelength) of interest and diagonalizes polarizability tensor, to\n"
    "calculate direction of maximum polarizability of that particular excitation.\n"
    "Needs mux.dat, muy.dat and muz.dat files in working directory."
)

def _excitation_energy(args):
    """Return the excitation energy in eV selected on the command line.

    ``args.energy`` is used directly when given; otherwise ``args.wvlength``
    (in nm) is converted through E = h * c / wavelength.
    """
    if args.energy is not None:
        return args.energy
    # Speed of light converted from m/s to nm/fs, matching hplanck in eV fs.
    cspeednm = constants.speed_of_light * 1.0e9 / 1.0e15
    return hplanck * cspeednm / args.wvlength


def _ask_spin_factor():
    """Prompt for singlet/triplet and return the sign (+1 / -1) to apply.

    Terminates the program with a non-zero exit status when the answer is
    neither ``s`` nor ``t`` (surrounding whitespace and case are ignored).
    """
    print('This propagation was done using collinear spin polarisation')
    spintype = input('Please select singlet or triplet excitations (s/t)')
    spintype = spintype.strip().lower()
    if spintype == 's':
        return 1
    if spintype == 't':
        return -1
    # A string argument is written to stderr and yields exit status 1.
    sys.exit("Wrong spin type")


def _dipole_tensor(kicks, factor):
    """Build the (3, 3, nsteps) dipole array from the three loaded files.

    ``kicks`` holds the arrays read from mux.dat, muy.dat and muz.dat (in
    that order); entry ``[i, j]`` of the result is column ``j + 1`` of file
    ``i``.  When ``factor`` is not None, columns 4-6 of each file (present
    for collinear spin-polarised runs) are added, multiplied by ``factor``.
    """
    # Explicit index lists (not slices) so that a missing column raises an
    # IndexError instead of silently producing a narrower array.
    mu = np.stack([kick[:, [1, 2, 3]].T for kick in kicks])
    if factor is not None:
        mu = mu + factor * np.stack([kick[:, [4, 5, 6]].T for kick in kicks])
    return mu


def main():
    """Print the direction of maximum polarizability for one excitation.

    Parses the damping constant and the excitation energy (or wavelength),
    reads mux.dat, muy.dat and muz.dat from the working directory, evaluates
    the damped Fourier transform of the dipole signals at the frequency bin
    closest to the requested energy, diagonalises the resulting 3x3 tensor
    and prints the eigenvector belonging to the largest eigenvalue.

    Exits through ``parser.error`` on a non-positive damping constant or
    wavelength, and through ``sys.exit`` with an error message when the data
    are too short or the spin type answer is invalid.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument("-d", "--damping", dest="tau", required=True,
                        type=float, help="damping constant in fs")

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("-e", "--energy", dest="energy", type=float,
                       help="excitation energy in eV")
    group.add_argument("-w", "--wavelength", dest="wvlength", type=float,
                       help="wavelength of excitation in nm")

    args = parser.parse_args()

    # Both values are used as divisors below; reject zero/negative input
    # up front instead of propagating inf/nan into the result.
    if args.tau <= 0.0:
        parser.error("damping constant must be positive")
    if args.wvlength is not None and args.wvlength <= 0.0:
        parser.error("wavelength must be positive")

    energy = _excitation_energy(args)

    # ndmin=2 keeps a single-row file two-dimensional so the length check
    # below can report the problem clearly.
    kicks = [np.loadtxt(fname, ndmin=2)
             for fname in ('mux.dat', 'muy.dat', 'muz.dat')]
    mux = kicks[0]
    time = mux[:, 0]
    nsteps = time.shape[0]
    if nsteps < 2:
        sys.exit("mux.dat must contain at least two time steps")

    # More than the time column plus three dipole columns signals a
    # spin-polarised propagation.
    spinpol = mux.shape[1] > 4
    factor = _ask_spin_factor() if spinpol else None
    mu = _dipole_tensor(kicks, factor)

    damp = np.exp(-time / args.tau)
    # Transform length of 10 * nsteps: rfft zero-pads the signal to it.
    nfft = 10 * nsteps
    energsev = np.fft.rfftfreq(nfft, time[1] - time[0]) * hplanck
    idx = np.abs(energsev - energy).argmin()

    # Subtract the initial dipole, apply the damping window and transform
    # all nine components at once along the time axis; only the bin closest
    # to the requested energy is kept.
    signal = damp * (mu - mu[:, :, :1])
    spectrum_imag = np.fft.rfft(signal, nfft, axis=-1).imag
    alfa = -spectrum_imag[:, :, idx] * 2.0 * energsev[idx] / np.pi

    # NOTE(review): alfa is not symmetrised, so linalg.eig may return complex
    # eigenvalues/eigenvectors for noisy data; the printed direction would
    # then contain complex numbers.
    w, v = linalg.eig(alfa)
    maxidx = np.argmax(w)
    print('PolarizationDirection = {:.8f} {:.8f} {:.8f}'.format(*v[:, maxidx]))

# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
