!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2026 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!
MODULE fftw3_lib
   USE ISO_C_BINDING, ONLY: C_ASSOCIATED, &
                            C_CHAR, &
                            C_DOUBLE, &
                            C_DOUBLE_COMPLEX, &
                            C_INT, &
                            C_PTR
#if defined(__FFTW3)
   USE ISO_C_BINDING, ONLY: &
      C_FLOAT, &
      C_FLOAT_COMPLEX, &
      C_FUNPTR, &
      C_INT32_T, &
      C_INTPTR_T, &
      C_LOC, &
      C_NULL_CHAR, &
      C_SIZE_T, C_F_POINTER
   USE mathconstants, ONLY: z_zero
#endif
   USE cp_files, ONLY: get_unit_number
   USE fft_kinds, ONLY: dp
   USE fft_plan, ONLY: fft_plan_type

!$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num, omp_get_num_threads

#include "../../base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   PUBLIC :: fftw3_do_init, fftw3_do_cleanup, fftw3_get_lengths, fftw33d, fftw31dm
   PUBLIC :: fftw3_destroy_plan, fftw3_create_plan_1dm, fftw3_create_plan_3d
   PUBLIC :: fftw_alloc, fftw_dealloc

#if defined(__FFTW3)
#include "fftw3.f03"
#endif

   ! Generic allocation of a 1-, 2- or 3-dimensional COMPLEX(C_DOUBLE_COMPLEX) pointer array.
   ! With __FFTW3 the storage is obtained from fftw_alloc_complex, otherwise with ALLOCATE;
   ! storage obtained here must be released with fftw_dealloc.
   INTERFACE fftw_alloc
      MODULE PROCEDURE :: fftw_alloc_complex_1d
      MODULE PROCEDURE :: fftw_alloc_complex_2d
      MODULE PROCEDURE :: fftw_alloc_complex_3d
   END INTERFACE fftw_alloc

   ! Generic release of arrays obtained from fftw_alloc: fftw_free with __FFTW3,
   ! DEALLOCATE otherwise. The pointer is disassociated on return.
   INTERFACE fftw_dealloc
      MODULE PROCEDURE :: fftw_dealloc_complex_1d
      MODULE PROCEDURE :: fftw_dealloc_complex_2d
      MODULE PROCEDURE :: fftw_dealloc_complex_3d
   END INTERFACE fftw_dealloc

CONTAINS

   #:set maxdim = 3
   #:for dim in range(1, maxdim+1)
! Comma-separated extents "n(1), n(2), ..." used for the plain ALLOCATE when FFTW3 is not available
      #:set dim_extended = ", ".join(["n("+str(i)+")" for i in range(1, dim+1)])
! **************************************************************************************************
!> \brief Allocates a ${dim}$-dimensional COMPLEX(C_DOUBLE_COMPLEX) array.
!>        With FFTW3 the storage comes from fftw_alloc_complex, otherwise from ALLOCATE.
!>        Release it with fftw_dealloc.
!> \param array pointer that is associated with the new storage on return
!> \param n extent of the array in each dimension
! **************************************************************************************************
      SUBROUTINE fftw_alloc_complex_${dim}$d(array, n)
         COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:${", :"*(dim-1)}$), CONTIGUOUS, POINTER, INTENT(OUT) :: array
         INTEGER, DIMENSION(${dim}$), INTENT(IN) :: n

#if defined(__FFTW3)
         INTEGER(KIND=C_SIZE_T) :: n_elements
         TYPE(C_PTR) :: data_ptr

         ! Multiply the extents in C_SIZE_T arithmetic: a PRODUCT over default-kind integers
         ! overflows as soon as the total number of elements exceeds HUGE(0).
         n_elements = PRODUCT(INT(n, KIND=C_SIZE_T))
         data_ptr = fftw_alloc_complex(n_elements)
         ! fftw_alloc_complex yields a null pointer if the memory could not be obtained;
         ! do not turn that into a Fortran pointer that would crash on first use.
         IF (n_elements > 0 .AND. .NOT. C_ASSOCIATED(data_ptr)) &
            CPABORT("fftw_alloc_complex failed to allocate memory")
         CALL C_F_POINTER(data_ptr, array, n)
#else
! Just allocate the array
         ALLOCATE (array(${dim_extended}$))
#endif

      END SUBROUTINE fftw_alloc_complex_${dim}$d

! **************************************************************************************************
!> \brief Releases a ${dim}$-dimensional array obtained from fftw_alloc
!>        (fftw_free with FFTW3, DEALLOCATE otherwise).
!> \param array pointer to the storage to release; disassociated on return
! **************************************************************************************************
      SUBROUTINE fftw_dealloc_complex_${dim}$d(array)
         COMPLEX(C_DOUBLE_COMPLEX), DIMENSION(:${", :"*(dim-1)}$), CONTIGUOUS, POINTER, INTENT(INOUT) :: array

#if defined(__FFTW3)
         CALL fftw_free(C_LOC(array))
         NULLIFY (array)
#else
! Just deallocate the array
         DEALLOCATE (array)
#endif

      END SUBROUTINE fftw_dealloc_complex_${dim}$d
   #:endfor

#if defined(__FFTW3)
! **************************************************************************************************
!> \brief A workaround that allows us to compile with -Werror=unused-parameter:
!>        every named constant listed here (coming from fftw3.f03) is referenced via MARK_USED.
! **************************************************************************************************
   SUBROUTINE dummy_routine_to_call_mark_used()
      ! Transform directions
      MARK_USED(FFTW_FORWARD)
      MARK_USED(FFTW_BACKWARD)

      ! Real-to-real transform kinds
      MARK_USED(FFTW_R2HC)
      MARK_USED(FFTW_HC2R)
      MARK_USED(FFTW_DHT)
      MARK_USED(FFTW_REDFT00)
      MARK_USED(FFTW_REDFT01)
      MARK_USED(FFTW_REDFT10)
      MARK_USED(FFTW_REDFT11)
      MARK_USED(FFTW_RODFT00)
      MARK_USED(FFTW_RODFT01)
      MARK_USED(FFTW_RODFT10)
      MARK_USED(FFTW_RODFT11)

      ! Planner rigour flags
      MARK_USED(FFTW_ESTIMATE)
      MARK_USED(FFTW_MEASURE)
      MARK_USED(FFTW_PATIENT)
      MARK_USED(FFTW_EXHAUSTIVE)
      MARK_USED(FFTW_WISDOM_ONLY)

      ! Planner flags concerning the input/output arrays
      MARK_USED(FFTW_DESTROY_INPUT)
      MARK_USED(FFTW_PRESERVE_INPUT)
      MARK_USED(FFTW_UNALIGNED)
      MARK_USED(FFTW_CONSERVE_MEMORY)

      ! Additional planner flags
      MARK_USED(FFTW_ESTIMATE_PATIENT)
      MARK_USED(FFTW_BELIEVE_PCOST)
      MARK_USED(FFTW_NO_DFT_R2HC)
      MARK_USED(FFTW_NO_NONTHREADED)
      MARK_USED(FFTW_NO_BUFFERING)
      MARK_USED(FFTW_NO_INDIRECT_OP)
      MARK_USED(FFTW_ALLOW_LARGE_GENERIC)
      MARK_USED(FFTW_NO_RANK_SPLITS)
      MARK_USED(FFTW_NO_VRANK_SPLITS)
      MARK_USED(FFTW_NO_VRECURSE)
      MARK_USED(FFTW_NO_SIMD)
      MARK_USED(FFTW_NO_SLOW)
      MARK_USED(FFTW_NO_FIXED_RADIX_LARGE_N)
      MARK_USED(FFTW_ALLOW_PRUNING)
   END SUBROUTINE dummy_routine_to_call_mark_used
#endif

! **************************************************************************************************
!> \brief Writes the accumulated FFTW3 wisdom to a file (I/O node only) and then releases
!>        FFTW's internal data on every process.
!> \param wisdom_file name of the file the wisdom is exported to
!> \param ionode only the process for which this is .TRUE. writes the wisdom file
!> \note fftw3_do_init calls fftw_init_threads, so the matching fftw_cleanup_threads is used
!>       here; it performs everything fftw_cleanup does and also frees the threads' data.
! **************************************************************************************************
   SUBROUTINE fftw3_do_cleanup(wisdom_file, ionode)

      CHARACTER(LEN=*), INTENT(IN)             :: wisdom_file
      LOGICAL, INTENT(IN)                      :: ionode

#if defined(__FFTW3)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), ALLOCATABLE :: wisdom_file_name_c
      INTEGER                                  :: file_name_length, i, iunit, istat
      INTEGER(KIND=C_INT)                      :: isuccess
      ! Write out FFTW3 wisdom to file (if we can)
      ! only the ionode updates the wisdom
      IF (ionode) THEN
         iunit = get_unit_number()
         ! Check whether the file can be opened in the necessary manner
         OPEN (UNIT=iunit, FILE=wisdom_file, STATUS="UNKNOWN", FORM="FORMATTED", ACTION="WRITE", IOSTAT=istat)
         IF (istat == 0) THEN
            CLOSE (iunit)
            ! Build a NUL-terminated C string from the trimmed file name
            file_name_length = LEN_TRIM(wisdom_file)
            ALLOCATE (wisdom_file_name_c(file_name_length + 1))
            DO i = 1, file_name_length
               wisdom_file_name_c(i) = wisdom_file(i:i)
            END DO
            wisdom_file_name_c(file_name_length + 1) = C_NULL_CHAR
            isuccess = fftw_export_wisdom_to_filename(wisdom_file_name_c)
            IF (isuccess == 0) &
               CALL cp_warn(__LOCATION__, "Error exporting wisdom to file "//TRIM(wisdom_file)//". "// &
                            "Wisdom was not exported.")
         END IF
      END IF

      ! Counterpart of fftw_init_threads in fftw3_do_init (includes fftw_cleanup)
      CALL fftw_cleanup_threads()
#else
      MARK_USED(wisdom_file)
      MARK_USED(ionode)
#endif

   END SUBROUTINE fftw3_do_cleanup

! **************************************************************************************************
!> \brief Initialises FFTW3 threading support and imports FFTW wisdom from a file if it exists.
!> \param wisdom_file name of the wisdom file; a missing or unopenable file is skipped silently,
!>        a file FFTW refuses to import only produces a warning
! **************************************************************************************************
   SUBROUTINE fftw3_do_init(wisdom_file)

      CHARACTER(LEN=*), INTENT(IN)             :: wisdom_file

#if defined(__FFTW3)
      CHARACTER(LEN=1, KIND=C_CHAR), DIMENSION(:), ALLOCATABLE :: wisdom_file_name_c
      INTEGER                                  :: file_name_length, i, istat, iunit
      INTEGER(KIND=C_INT)                      :: isuccess
      LOGICAL :: file_exists

      isuccess = fftw_init_threads()
      IF (isuccess == 0) &
         CPABORT("Error initializing FFTW with threads")

      ! Read FFTW wisdom (if available)
      ! all nodes are opening the file here...
      INQUIRE (FILE=wisdom_file, exist=file_exists)
      IF (file_exists) THEN
         iunit = get_unit_number()
         ! Check whether the file can be opened in the necessary manner
         OPEN (UNIT=iunit, FILE=wisdom_file, STATUS="OLD", FORM="FORMATTED", POSITION="REWIND", &
               ACTION="READ", IOSTAT=istat)
         IF (istat == 0) THEN
            CLOSE (iunit)
            ! Build a NUL-terminated C string from the trimmed file name
            file_name_length = LEN_TRIM(wisdom_file)
            ALLOCATE (wisdom_file_name_c(file_name_length + 1))
            DO i = 1, file_name_length
               wisdom_file_name_c(i) = wisdom_file(i:i)
            END DO
            wisdom_file_name_c(file_name_length + 1) = C_NULL_CHAR
            isuccess = fftw_import_wisdom_from_filename(wisdom_file_name_c)
            IF (isuccess == 0) &
               CALL cp_warn(__LOCATION__, "Error importing wisdom from file "//TRIM(wisdom_file)//". "// &
                            "Maybe the file was created with a different configuration than CP2K is run with. "// &
                            "CP2K continues without importing wisdom.")
         END IF
      END IF
#else
      MARK_USED(wisdom_file)
#endif

   END SUBROUTINE fftw3_do_init

! **************************************************************************************************
!> \brief Returns, in ascending order, the FFT lengths offered by this backend: all numbers
!>        2**h * 3**i * 5**j * 7**k * 11**m with h <= 15, i <= 3, j <= 2, k <= 1, m <= 1
!>        that are smaller than 37748736.
!> \param DATA on return DATA(1:max_length) holds the lengths
!> \param max_length on input the capacity of DATA; on return the number of lengths stored
!> \par History
!>      JGH 23-Jan-2006 : initial version
!>      Adapted for new interface
!>      IAB 09-Jan-2009 : Modified to cache plans in fft_plan_type
!>                        (c) The Numerical Algorithms Group (NAG) Ltd, 2009 on behalf of the HECToR project
!>      IAB 09-Oct-2009 : Added OpenMP directives to 1D FFT, and planning routines
!>                        (c) The Numerical Algorithms Group (NAG) Ltd, 2009 on behalf of the HECToR project
!>      IAB 11-Sep-2012 : OpenMP parallel 3D FFT (Ruyman Reyes, PRACE)
!> \author JGH
! **************************************************************************************************
   SUBROUTINE fftw3_get_lengths(DATA, max_length)

      INTEGER, DIMENSION(*), INTENT(INOUT)               :: DATA
      INTEGER, INTENT(INOUT)                             :: max_length

      ! FFTW can do arbitrary(?) lengths; the list is limited to products of powers of
      ! small primes. These are the largest exponents allowed (factors of 13 are not used).
      INTEGER, PARAMETER                                 :: maxn_twos = 15, maxn_threes = 3, &
                                                            maxn_fives = 2, maxn_sevens = 1, &
                                                            maxn_elevens = 1
      ! Exclusive upper bound on the returned lengths
      INTEGER, PARAMETER                                 :: maxn = 37748736
      ! Number of exponent combinations: an upper bound on the number of accepted lengths
      INTEGER, PARAMETER :: max_candidates = (maxn_twos + 1)*(maxn_threes + 1)*(maxn_fives + 1)* &
                                             (maxn_sevens + 1)*(maxn_elevens + 1)

      INTEGER                                            :: h, i, j, k, m, ndata, nmax, number
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: dlocal, idx

      ! Every exponent combination gives a distinct product (unique prime factorisation), so a
      ! single pass into a buffer with one slot per combination collects all lengths.
      ALLOCATE (dlocal(max_candidates))

      ndata = 0
      DO h = 0, maxn_twos
         ! Largest odd part that can still be multiplied by 2**h without integer overflow
         nmax = HUGE(0)/2**h
         DO i = 0, maxn_threes
            DO j = 0, maxn_fives
               DO k = 0, maxn_sevens
                  DO m = 0, maxn_elevens
                     number = (3**i)*(5**j)*(7**k)*(11**m)

                     IF (number > nmax) CYCLE

                     number = number*2**h
                     IF (number >= maxn) CYCLE

                     ndata = ndata + 1
                     dlocal(ndata) = number
                  END DO
               END DO
            END DO
         END DO
      END DO

      ! sortint also returns the permutation it applied; it is not needed here
      ALLOCATE (idx(ndata))
      CALL sortint(dlocal, ndata, idx)

      ndata = MIN(ndata, max_length)
      DATA(1:ndata) = dlocal(1:ndata)
      max_length = ndata

      DEALLOCATE (dlocal, idx)

   END SUBROUTINE fftw3_get_lengths

! **************************************************************************************************
!> \brief Sorts iarr(1:n) into ascending order in place and records the permutation applied.
!>        Quicksort with a median-of-three pivot and an explicit stack of pending sub-ranges;
!>        sub-ranges with ir - l < m are finished by straight insertion sort.
!> \param iarr on input the values to sort; on output the same values in ascending order
!> \param n number of elements to sort
!> \param index on output, index(i) is the original position of the value now stored in iarr(i)
! **************************************************************************************************
   SUBROUTINE sortint(iarr, n, index)

      INTEGER, INTENT(IN)                                :: n
      INTEGER, INTENT(INOUT)                             :: iarr(1:n)
      INTEGER, INTENT(OUT)                               :: INDEX(1:n)

      ! m: sub-ranges with ir - l < m are insertion-sorted; nstack: capacity of the range stack
      INTEGER, PARAMETER                                 :: m = 7, nstack = 50

      INTEGER                                            :: a, i, ib, ir, istack(1:nstack), itemp, &
                                                            j, jstack, k, l, temp

!------------------------------------------------------------------------------

      ! Start from the identity permutation; every move in iarr below is mirrored in index
      DO i = 1, n
         INDEX(i) = i
      END DO
      jstack = 0
      l = 1
      ir = n
      DO WHILE (.TRUE.)
      IF (ir - l < m) THEN
         ! Short sub-range iarr(l:ir): straight insertion sort
         DO j = l + 1, ir
            a = iarr(j)
            ib = INDEX(j)
            DO i = j - 1, 0, -1
               IF (i == 0) EXIT
               IF (iarr(i) <= a) EXIT
               iarr(i + 1) = iarr(i)
               INDEX(i + 1) = INDEX(i)
            END DO
            iarr(i + 1) = a
            INDEX(i + 1) = ib
         END DO
         ! Continue with the next pending sub-range, or stop when the stack is empty
         IF (jstack == 0) RETURN
         ir = istack(jstack)
         l = istack(jstack - 1)
         jstack = jstack - 2
      ELSE
         ! Median of three: move the middle element to l+1, then order the three candidates so
         ! that iarr(l+1) <= iarr(l) <= iarr(ir); iarr(l) becomes the pivot, and iarr(l+1)
         ! and iarr(ir) bound the partitioning scans below
         k = (l + ir)/2
         temp = iarr(k)
         iarr(k) = iarr(l + 1)
         iarr(l + 1) = temp
         itemp = INDEX(k)
         INDEX(k) = INDEX(l + 1)
         INDEX(l + 1) = itemp
         IF (iarr(l + 1) > iarr(ir)) THEN
            temp = iarr(l + 1)
            iarr(l + 1) = iarr(ir)
            iarr(ir) = temp
            itemp = INDEX(l + 1)
            INDEX(l + 1) = INDEX(ir)
            INDEX(ir) = itemp
         END IF
         IF (iarr(l) > iarr(ir)) THEN
            temp = iarr(l)
            iarr(l) = iarr(ir)
            iarr(ir) = temp
            itemp = INDEX(l)
            INDEX(l) = INDEX(ir)
            INDEX(ir) = itemp
         END IF
         IF (iarr(l + 1) > iarr(l)) THEN
            temp = iarr(l + 1)
            iarr(l + 1) = iarr(l)
            iarr(l) = temp
            itemp = INDEX(l + 1)
            INDEX(l + 1) = INDEX(l)
            INDEX(l) = itemp
         END IF
         ! Partition around the pivot a: scan inwards from both ends, swapping out-of-place pairs
         i = l + 1
         j = ir
         a = iarr(l)
         ib = INDEX(l)
         DO WHILE (.TRUE.)
            i = i + 1
            DO WHILE (iarr(i) < a)
               i = i + 1
            END DO
            j = j - 1
            DO WHILE (iarr(j) > a)
               j = j - 1
            END DO
            IF (j < i) EXIT
            temp = iarr(i)
            iarr(i) = iarr(j)
            iarr(j) = temp
            itemp = INDEX(i)
            INDEX(i) = INDEX(j)
            INDEX(j) = itemp
         END DO
         ! Place the pivot at its final position j
         iarr(l) = iarr(j)
         iarr(j) = a
         INDEX(l) = INDEX(j)
         INDEX(j) = ib
         ! Push the larger of the two sub-ranges and continue with the smaller one
         jstack = jstack + 2
         IF (jstack > nstack) CPABORT("Nstack too small in sortint")
         IF (ir - i + 1 >= j - l) THEN
            istack(jstack) = ir
            istack(jstack - 1) = i
            ir = j - 1
         ELSE
            istack(jstack) = j - 1
            istack(jstack - 1) = l
            l = i
         END IF
      END IF

      END DO

   END SUBROUTINE sortint

! **************************************************************************************************

! **************************************************************************************************
!> \brief Creates an FFTW guru plan for a batch of complex-to-complex transforms.
!> \param plan resulting FFTW plan (left untouched when FFTW3 is not available)
!> \param fft_rank number of transform dimensions used from dim_n/dim_istride/dim_ostride (at most 2)
!> \param dim_n transform lengths
!> \param dim_istride input strides of the transform dimensions
!> \param dim_ostride output strides of the transform dimensions
!> \param hm_rank number of "howmany" (batch) dimensions used from hm_n/hm_istride/hm_ostride (at most 2)
!> \param hm_n batch extents
!> \param hm_istride input strides of the batch dimensions
!> \param hm_ostride output strides of the batch dimensions
!> \param zin input array used for planning
!> \param zout output array used for planning
!> \param fft_direction transform sign passed to FFTW (e.g. FFTW_FORWARD)
!> \param fftw_plan_type FFTW planner flags (e.g. FFTW_ESTIMATE)
!> \param valid .TRUE. if FFTW returned a non-null plan; always .FALSE. without FFTW3
! **************************************************************************************************
   SUBROUTINE fftw3_create_guru_plan(plan, fft_rank, dim_n, &
                                     dim_istride, dim_ostride, hm_rank, &
                                     hm_n, hm_istride, hm_ostride, &
                                     zin, zout, fft_direction, fftw_plan_type, &
                                     valid)

      IMPLICIT NONE

      TYPE(C_PTR), INTENT(INOUT)                         :: plan
      INTEGER, INTENT(IN)                                :: fft_rank, hm_rank
      INTEGER, DIMENSION(2), INTENT(IN)                  :: dim_n, dim_istride, dim_ostride, &
                                                            hm_n, hm_istride, hm_ostride
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT)      :: zin, zout
      INTEGER, INTENT(IN)                                :: fft_direction, fftw_plan_type
      LOGICAL, INTENT(OUT)                               :: valid

#if defined(__FFTW3)
      INTEGER                                            :: idim
      TYPE(fftw_iodim), DIMENSION(2)                     :: howmany_dims, transform_dims

      ! Pack each (extent, input stride, output stride) triple into an FFTW iodim
      transform_dims = [(fftw_iodim(dim_n(idim), dim_istride(idim), dim_ostride(idim)), idim=1, 2)]
      howmany_dims = [(fftw_iodim(hm_n(idim), hm_istride(idim), hm_ostride(idim)), idim=1, 2)]

      plan = fftw_plan_guru_dft(fft_rank, transform_dims, hm_rank, howmany_dims, &
                                zin, zout, fft_direction, fftw_plan_type)

      ! FFTW signals a planning failure with a null plan
      valid = C_ASSOCIATED(plan)

#else
      MARK_USED(plan)
      MARK_USED(fft_rank)
      MARK_USED(dim_n)
      MARK_USED(dim_istride)
      MARK_USED(dim_ostride)
      MARK_USED(hm_rank)
      MARK_USED(hm_n)
      MARK_USED(hm_istride)
      MARK_USED(hm_ostride)
      MARK_USED(fft_direction)
      MARK_USED(fftw_plan_type)
      !MARK_USED does not work with assumed size arguments
      IF (.FALSE.) THEN; DO; IF (ABS(zin(1)) > ABS(zout(1))) EXIT; END DO; END IF
      valid = .FALSE.

#endif

   END SUBROUTINE fftw3_create_guru_plan

! **************************************************************************************************

! **************************************************************************************************
!> \brief Probes whether FFTW can create a guru plan (rank-1 transform, two batch dimensions)
!>        on a 1x1x1 array. fftw3_create_plan_3d only builds the hand-threaded plans when this
!>        returns .TRUE.; otherwise it falls back to a single FFTW3 threaded 3D plan.
!> \return .TRUE. if the probe plan could be created; always .FALSE. without FFTW3
! **************************************************************************************************
   FUNCTION fftw3_is_guru_supported() RESULT(guru_supported)

      IMPLICIT NONE

      LOGICAL :: guru_supported
#if defined(__FFTW3)
      COMPLEX(KIND=dp), DIMENSION(1, 1, 1)               :: probe
      INTEGER, DIMENSION(2)                              :: hm_istride, hm_n, hm_ostride, &
                                                            tr_istride, tr_n, tr_ostride
      TYPE(C_PTR)                                        :: probe_plan

      ! Trivial geometry: every extent and every stride equals one
      tr_n(:) = 1
      tr_istride(:) = 1
      tr_ostride(:) = 1
      hm_n(:) = 1
      hm_istride(:) = 1
      hm_ostride(:) = 1
      probe = z_zero

      CALL fftw3_create_guru_plan(probe_plan, 1, tr_n, tr_istride, tr_ostride, &
                                  2, hm_n, hm_istride, hm_ostride, probe, probe, &
                                  FFTW_FORWARD, FFTW_ESTIMATE, guru_supported)

      ! Only a plan that was actually created needs to be released again
      IF (guru_supported) CALL fftw_destroy_plan(probe_plan)

#else
      guru_supported = .FALSE.
#endif

   END FUNCTION fftw3_is_guru_supported

! **************************************************************************************************

! **************************************************************************************************
!> \brief Distributes nrows rows over nt threads so that the row counts differ by at most one.
!> \param nrows number of rows to distribute
!> \param nt number of threads
!> \param rows_per_thread rows handled by each of the first th_planA threads
!> \param rows_per_thread_r rows handled by each of the following th_planB threads (0 for an even split)
!> \param th_planA number of threads that handle rows_per_thread rows
!> \param th_planB number of threads that handle rows_per_thread_r rows
! **************************************************************************************************
   SUBROUTINE fftw3_compute_rows_per_th(nrows, nt, rows_per_thread, rows_per_thread_r, &
                                        th_planA, th_planB)

      INTEGER, INTENT(IN)                                :: nrows, nt
      INTEGER, INTENT(OUT)                               :: rows_per_thread, rows_per_thread_r, &
                                                            th_planA, th_planB

      INTEGER                                            :: leftover, share

      leftover = MOD(nrows, nt)
      share = nrows/nt

      IF (leftover /= 0) THEN
         ! The first "leftover" threads each take one extra row
         th_planA = leftover
         th_planB = nt - leftover
         rows_per_thread = share + 1
         rows_per_thread_r = share
      ELSE
         ! Even split: all threads get the same number of rows, no remainder group
         th_planA = nt
         th_planB = 0
         rows_per_thread = share
         rows_per_thread_r = 0
      END IF

   END SUBROUTINE fftw3_compute_rows_per_th

! **************************************************************************************************

! **************************************************************************************************
!> \brief Creates the pair of guru plans used for one stage of the hand-threaded 3D FFT:
!>        plan batches rows_per_th rows, plan_r batches rows_per_th_r rows.
!>        Aborts if FFTW cannot create either plan.
!> \param plan plan whose second batch dimension has extent rows_per_th
!> \param plan_r plan whose second batch dimension has extent rows_per_th_r
!> \param dim_n transform lengths
!> \param dim_istride input strides of the transform dimensions
!> \param dim_ostride output strides of the transform dimensions
!> \param hm_n batch extents; hm_n(2) is overwritten and equals rows_per_th_r on return
!> \param hm_istride input strides of the batch dimensions
!> \param hm_ostride output strides of the batch dimensions
!> \param input input array used for planning
!> \param output output array used for planning
!> \param fft_direction transform sign passed to FFTW
!> \param fftw_plan_type FFTW planner flags
!> \param rows_per_th number of rows batched by plan
!> \param rows_per_th_r number of rows batched by plan_r
! **************************************************************************************************
   SUBROUTINE fftw3_create_3d_plans(plan, plan_r, dim_n, dim_istride, dim_ostride, &
                                    hm_n, hm_istride, hm_ostride, &
                                    input, output, &
                                    fft_direction, fftw_plan_type, rows_per_th, &
                                    rows_per_th_r)

      TYPE(C_PTR), INTENT(INOUT)                         :: plan, plan_r
      INTEGER, INTENT(IN)                                :: dim_n(2), dim_istride(2), &
                                                            dim_ostride(2)
      ! Only hm_n is modified (its second entry selects the number of batched rows)
      INTEGER, INTENT(INOUT)                             :: hm_n(2)
      INTEGER, INTENT(IN)                                :: hm_istride(2), hm_ostride(2)
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT)      :: input, output
      INTEGER, INTENT(IN)                                :: fft_direction, fftw_plan_type
      INTEGER, INTENT(IN)                                :: rows_per_th, rows_per_th_r

      LOGICAL                                            :: valid

! First plans will have an additional row

      hm_n(2) = rows_per_th
      CALL fftw3_create_guru_plan(plan, 1, &
                                  dim_n, dim_istride, dim_ostride, &
                                  2, hm_n, hm_istride, hm_ostride, &
                                  input, output, &
                                  fft_direction, fftw_plan_type, valid)

      IF (.NOT. valid) THEN
         CPABORT("fftw3_create_plan")
      END IF

      !!!! Remainder
      hm_n(2) = rows_per_th_r
      CALL fftw3_create_guru_plan(plan_r, 1, &
                                  dim_n, dim_istride, dim_ostride, &
                                  2, hm_n, hm_istride, hm_ostride, &
                                  input, output, &
                                  fft_direction, fftw_plan_type, valid)
      IF (.NOT. valid) THEN
         CPABORT("fftw3_create_plan (remaining)")
      END IF

   END SUBROUTINE fftw3_create_3d_plans

! **************************************************************************************************

! **************************************************************************************************
!> \brief Creates the FFTW3 plan(s) for the 3D complex-to-complex transform described by plan.
!>        Either one FFTW 3D plan is created (plan%separated_plans = .FALSE.), or three pairs of
!>        1D guru plans (X, Y and Z stages) whose rows are split over the OpenMP threads
!>        (plan%separated_plans = .TRUE.).
!> \param plan FFT plan; n_3d, fsign and fft_in_place are read, the FFTW plan handles are set
!> \param zin input array used for planning
!> \param zout output array used for planning the single out-of-place 3D transform
!> \param plan_style 1 = FFTW_ESTIMATE, 2 = FFTW_MEASURE, 3 = FFTW_PATIENT, 4 = FFTW_EXHAUSTIVE
! **************************************************************************************************
   SUBROUTINE fftw3_create_plan_3d(plan, zin, zout, plan_style)

      TYPE(fft_plan_type), INTENT(INOUT)              :: plan
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT)      :: zin
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT)      :: zout
      INTEGER, INTENT(IN)                                :: plan_style
#if defined(__FFTW3)
      INTEGER                                            :: n1, n2, n3
      INTEGER                                            :: nt
      INTEGER                                            :: rows_per_th
      INTEGER                                            :: rows_per_th_r
      INTEGER                                            :: fft_direction
      INTEGER                                            :: th_planA, th_planB
      COMPLEX(KIND=dp), ALLOCATABLE                      :: tmp(:)

      ! GURU Interface
      INTEGER :: dim_n(2), dim_istride(2), dim_ostride(2), &
                 howmany_n(2), howmany_istride(2), howmany_ostride(2)

      INTEGER :: fftw_plan_type
      ! Map plan_style onto the FFTW planner flag
      SELECT CASE (plan_style)
      CASE (1)
         fftw_plan_type = FFTW_ESTIMATE
      CASE (2)
         fftw_plan_type = FFTW_MEASURE
      CASE (3)
         fftw_plan_type = FFTW_PATIENT
      CASE (4)
         fftw_plan_type = FFTW_EXHAUSTIVE
      CASE DEFAULT
         CPABORT("fftw3_create_plan_3d")
      END SELECT

      IF (plan%fsign == +1) THEN
         fft_direction = FFTW_FORWARD
      ELSE
         fft_direction = FFTW_BACKWARD
      END IF

      n1 = plan%n_3d(1)
      n2 = plan%n_3d(2)
      n3 = plan%n_3d(3)

      ! Number of threads in a parallel region (stays 1 without OpenMP)
      nt = 1
!$OMP PARALLEL DEFAULT(NONE) SHARED(nt)
!$OMP MASTER
!$    nt = omp_get_num_threads()
!$OMP END MASTER
!$OMP END PARALLEL

      IF ((.NOT. fftw3_is_guru_supported()) .OR. &
          (plan_style /= 1) .OR. &
          (n1 < 256 .AND. n2 < 256 .AND. n3 < 256 .AND. nt == 1)) THEN
         ! If the plan type is MEASURE, PATIENT and EXHAUSTIVE or
         ! the grid size is small (and we are single-threaded) then
         ! FFTW3 does a better job than handmade optimization
         ! so plan a single 3D FFT which will execute using all the threads

         plan%separated_plans = .FALSE.
!$       CALL fftw_plan_with_nthreads(nt)

         IF (plan%fft_in_place) THEN
            plan%fftw_plan = fftw_plan_dft_3d(n3, n2, n1, zin, zin, fft_direction, fftw_plan_type)
         ELSE
            plan%fftw_plan = fftw_plan_dft_3d(n3, n2, n1, zin, zout, fft_direction, fftw_plan_type)
         END IF
      ELSE
         ! Scratch array that only serves as the intermediate target while planning
         ALLOCATE (tmp(n1*n2*n3))
         ! ************************* PLANS WITH TRANSPOSITIONS ****************************
         !  In the cases described above, we manually thread each stage of the 3D FFT.
         !
         !  The following plans replace the 3D FFT call by running 1D FFTW across all
         !  3 directions of the array.
         !
         !  Output of FFTW is transposed to ensure that the next round of FFTW access
         !  contiguous information.
         !
         !  Assuming the input matrix is M(n3,n2,n1), FFTW/Transp are :
         !  M(n3,n2,n1) -> fftw(x) -> M(n3,n1,n2) -> fftw(y) -> M(n1,n2,n3) -> fftw(z) -> M(n1,n2,n3)
         !  Notice that last matrix is transposed in the Z axis. A DO-loop in the execute routine
         !  will perform the final transposition. Performance evaluation showed that using an external
         !  DO loop to do the final transposition performed better than directly transposing the output.
         !  However, this might vary depending on the compiler/platform, so a potential tuning spot
         !  is to perform the final transposition within the fftw library rather than using the external loop
         !  See comments below in Z-FFT for how to transpose the output to avoid the final DO loop.
         !
         !  Doc. for the Guru interface is in http://www.fftw.org/doc/Guru-Interface.html
         !
         !  OpenMP : Work is distributed on the Z plane.
         !           All transpositions are out-of-place to facilitate multi-threading
         !
         !!!! Plan for X : M(n3,n2,n1) -> fftw(x) -> M(n3,n1,n2)
         CALL fftw3_compute_rows_per_th(n3, nt, rows_per_th, rows_per_th_r, &
                                        th_planA, th_planB)

         dim_n(1) = n1
         dim_istride(1) = 1
         dim_ostride(1) = n2
         howmany_n(1) = n2
         howmany_n(2) = rows_per_th
         howmany_istride(1) = n1
         howmany_istride(2) = n1*n2
         howmany_ostride(1) = 1
         howmany_ostride(2) = n1*n2
         CALL fftw3_create_3d_plans(plan%fftw_plan_nx, plan%fftw_plan_nx_r, &
                                    dim_n, dim_istride, dim_ostride, howmany_n, &
                                    howmany_istride, howmany_ostride, &
                                    zin, tmp, &
                                    fft_direction, fftw_plan_type, rows_per_th, &
                                    rows_per_th_r)

         !!!! Plan for Y : M(n3,n1,n2) -> fftw(y) -> M(n1,n2,n3)
         CALL fftw3_compute_rows_per_th(n3, nt, rows_per_th, rows_per_th_r, &
                                        th_planA, th_planB)
         dim_n(1) = n2
         dim_istride(1) = 1
         dim_ostride(1) = n3
         howmany_n(1) = n1
         howmany_n(2) = rows_per_th
         howmany_istride(1) = n2
         howmany_istride(2) = n1*n2
         !!! transposed Z axis on output
         howmany_ostride(1) = n2*n3
         howmany_ostride(2) = 1

         CALL fftw3_create_3d_plans(plan%fftw_plan_ny, plan%fftw_plan_ny_r, &
                                    dim_n, dim_istride, dim_ostride, &
                                    howmany_n, howmany_istride, howmany_ostride, &
                                    tmp, zin, &
                                    fft_direction, fftw_plan_type, rows_per_th, &
                                    rows_per_th_r)

         !!!! Plan for Z : M(n1,n2,n3) -> fftw(z) -> M(n1,n2,n3)
         CALL fftw3_compute_rows_per_th(n1, nt, rows_per_th, rows_per_th_r, &
                                        th_planA, th_planB)
         dim_n(1) = n3
         dim_istride(1) = 1
         dim_ostride(1) = 1          ! To transpose: n2*n1
         howmany_n(1) = n2
         howmany_n(2) = rows_per_th
         howmany_istride(1) = n3
         howmany_istride(2) = n2*n3
         howmany_ostride(1) = n3     ! To transpose: n1
         howmany_ostride(2) = n2*n3  ! To transpose: 1

         CALL fftw3_create_3d_plans(plan%fftw_plan_nz, plan%fftw_plan_nz_r, &
                                    dim_n, dim_istride, dim_ostride, &
                                    howmany_n, howmany_istride, howmany_ostride, &
                                    zin, tmp, &
                                    fft_direction, fftw_plan_type, rows_per_th, &
                                    rows_per_th_r)

         plan%separated_plans = .TRUE.

         DEALLOCATE (tmp)
      END IF

#else
      MARK_USED(plan)
      MARK_USED(plan_style)
      !MARK_USED does not work with assumed size arguments
      IF (.FALSE.) THEN; DO; IF (ABS(zin(1)) > ABS(zout(1))) EXIT; END DO; END IF
#endif

   END SUBROUTINE fftw3_create_plan_3d

! **************************************************************************************************

! **************************************************************************************************
!> \brief Executes the share of a thread-split transform that belongs to thread tid.
!>        split_dim rows are distributed over nt threads as in fftw3_compute_rows_per_th:
!>        plan is executed on blocks of rows_per_thread rows, plan_r on the remainder blocks.
!>        Threads that end up with no rows do not call FFTW.
!> \param plan plan executed on blocks of rows_per_thread rows
!> \param plan_r plan executed on blocks of rows_per_thread_r rows
!> \param split_dim number of rows distributed over the threads
!> \param nt number of threads
!> \param tid id of the calling thread, counted from 0
!> \param input input array
!> \param istride distance in input between the starts of consecutive rows
!> \param output output array
!> \param ostride distance in output between the starts of consecutive rows
! **************************************************************************************************
   SUBROUTINE fftw3_workshare_execute_dft(plan, plan_r, split_dim, nt, tid, &
                                          input, istride, output, ostride)

      INTEGER, INTENT(IN)                           :: split_dim, nt, tid
      INTEGER, INTENT(IN)                           :: istride, ostride
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT) :: input, output
      TYPE(C_PTR), INTENT(IN)                       :: plan, plan_r
#if defined(__FFTW3)
      INTEGER                                     :: i_off, o_off
      INTEGER                                     :: th_planA, th_planB
      INTEGER :: rows_per_thread, rows_per_thread_r

      CALL fftw3_compute_rows_per_th(split_dim, nt, rows_per_thread, &
                                     rows_per_thread_r, &
                                     th_planA, th_planB)

      IF (th_planB > 0) THEN
         ! Uneven split: threads 0..th_planA-1 take rows_per_thread rows each (plan),
         ! the next th_planB threads take rows_per_thread_r rows each (plan_r)
         IF (tid < th_planA) THEN
            i_off = (tid)*(istride*(rows_per_thread)) + 1
            o_off = (tid)*(ostride*(rows_per_thread)) + 1
            IF (rows_per_thread > 0) THEN
               CALL fftw_execute_dft(plan, input(i_off), &
                                     output(o_off))
            END IF
         ELSE IF ((tid - th_planA) < th_planB) THEN

            i_off = (th_planA)*istride*(rows_per_thread) + &
                    (tid - th_planA)*istride*(rows_per_thread_r) + 1
            o_off = (th_planA)*ostride*(rows_per_thread) + &
                    (tid - th_planA)*ostride*(rows_per_thread_r) + 1

            ! With fewer rows than threads rows_per_thread_r is 0: these threads have no work,
            ! and their offset would lie beyond the start of the last row.
            IF (rows_per_thread_r > 0) THEN
               CALL fftw_execute_dft(plan_r, input(i_off), &
                                     output(o_off))
            END IF
         END IF

      ELSE
         ! Even split: every thread takes rows_per_thread rows with plan
         i_off = (tid)*(istride*(rows_per_thread)) + 1
         o_off = (tid)*(ostride*(rows_per_thread)) + 1

         IF (rows_per_thread > 0) THEN
            CALL fftw_execute_dft(plan, input(i_off), &
                                  output(o_off))
         END IF

      END IF
#else
      MARK_USED(plan)
      MARK_USED(plan_r)
      MARK_USED(split_dim)
      MARK_USED(nt)
      MARK_USED(tid)
      MARK_USED(istride)
      MARK_USED(ostride)
      !MARK_USED does not work with assumed size arguments
      IF (.FALSE.) THEN; DO; IF (ABS(input(1)) > ABS(output(1))) EXIT; END DO; END IF
#endif

   END SUBROUTINE fftw3_workshare_execute_dft

! **************************************************************************************************

! **************************************************************************************************
!> \brief Execute a 3D complex FFT with a plan created by fftw3_create_plan_3d.
!> \param plan  plan created by fftw3_create_plan_3d; the dimensions are read from plan%n_3d
!> \param scale factor applied to all n1*n2*n3 result values (skipped when equal to 1)
!> \param zin   input data; also receives the result when plan%fft_in_place is set
!> \param zout  receives the result unless plan%fft_in_place is set
!> \param stat  1 when the FFTW3 backend ran, 0 when CP2K was built without __FFTW3
! **************************************************************************************************
   SUBROUTINE fftw33d(plan, scale, zin, zout, stat)

      TYPE(fft_plan_type), INTENT(IN)                      :: plan
      REAL(KIND=dp), INTENT(IN)                            :: scale
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT), TARGET:: zin
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT), TARGET:: zout
      INTEGER, INTENT(OUT)                                 :: stat
#if defined(__FFTW3)
      ! xout aliases the unit-stride array that receives the result (zin or zout). CONTIGUOUS
      ! lets the compiler pass it directly to the assumed-size dummies of fftw_execute_dft and
      ! zdscal. Without it a runtime pack check, and possibly a copy-in/copy-out temporary, is
      ! emitted. Such a temporary would break the in == out aliasing an in-place plan relies on.
      COMPLEX(KIND=dp), CONTIGUOUS, POINTER                :: xout(:)
      COMPLEX(KIND=dp), ALLOCATABLE                        :: tmp1(:)
      INTEGER                                              :: n1, n2, n3
      INTEGER                                              :: tid, nt
      INTEGER                                              :: i, j, k

      n1 = plan%n_3d(1)
      n2 = plan%n_3d(2)
      n3 = plan%n_3d(3)

      stat = 1

      ! Point xout at the array that receives the result so both paths below share one code path.
      IF (plan%fft_in_place) THEN
         xout => zin(:n1*n2*n3)
      ELSE
         xout => zout(:n1*n2*n3)
      END IF

      IF (.NOT. plan%separated_plans) THEN
         ! A single (possibly multithreaded) FFTW plan covers the whole 3D transform.
         CALL fftw_execute_dft(plan%fftw_plan, zin, xout)
      ELSE
         ! Three stages of batched 1D transforms (see fftw3_create_plan_3d), each split over the
         ! OpenMP threads, using tmp1 as scratch for the intermediate layouts.
         ALLOCATE (tmp1(n1*n2*n3))
         !$OMP PARALLEL DEFAULT(NONE) PRIVATE(tid,nt,i,j,k) SHARED(zin,tmp1,n1,n2,n3,plan,xout)
         tid = 0
         nt = 1

!$       tid = omp_get_thread_num()
!$       nt = omp_get_num_threads()
         ! Stage 1: zin -> tmp1, rows divided along n3.
         CALL fftw3_workshare_execute_dft(plan%fftw_plan_nx, plan%fftw_plan_nx_r, &
                                          n3, nt, tid, &
                                          zin, n1*n2, tmp1, n1*n2)

         !$OMP BARRIER
         ! Stage 2: tmp1 -> xout, rows divided along n3.
         CALL fftw3_workshare_execute_dft(plan%fftw_plan_ny, plan%fftw_plan_ny_r, &
                                          n3, nt, tid, &
                                          tmp1, n1*n2, xout, 1)
         !$OMP BARRIER
         ! Stage 3: xout -> tmp1, rows divided along n1.
         CALL fftw3_workshare_execute_dft(plan%fftw_plan_nz, plan%fftw_plan_nz_r, &
                                          n1, nt, tid, &
                                          xout, n2*n3, tmp1, n2*n3)
         !$OMP BARRIER

         ! tmp1 now has k (extent n3) as its fastest index. Transpose it back into xout, whose
         ! fastest index is i (extent n1).
         !$OMP DO COLLAPSE(3)
         DO i = 1, n1
            DO j = 1, n2
               DO k = 1, n3
                  xout((i - 1) + (j - 1)*n1 + (k - 1)*n1*n2 + 1) = &
                     tmp1((k - 1) + (j - 1)*n3 + (i - 1)*n3*n2 + 1)
               END DO
            END DO
         END DO
         !$OMP END DO

         !$OMP END PARALLEL
         DEALLOCATE (tmp1)
      END IF

      IF (scale /= 1.0_dp) THEN
         CALL zdscal(n1*n2*n3, scale, xout, 1)
      END IF

#else
      MARK_USED(plan)
      MARK_USED(scale)
      !MARK_USED does not work with assumed size arguments
      IF (.FALSE.) THEN; DO; IF (ABS(zin(1)) > ABS(zout(1))) EXIT; END DO; END IF
      stat = 0

#endif

   END SUBROUTINE fftw33d

! **************************************************************************************************

! **************************************************************************************************
!> \brief Create the FFTW plan(s) for a batch of plan%m 1D transforms of length plan%n.
!>        With OpenMP the rows are divided among the threads. When they do not divide evenly,
!>        a second plan (plan%alt_fftw_plan) is made for the last working thread.
!> \param plan       plan to set up; reads m, n, fsign and trans, and writes the FFTW plan
!>                   handle(s) plus, with OpenMP, the per-thread row bookkeeping
!> \param zin        input array the plan is created for
!> \param zout       output array the plan is created for
!> \param plan_style 1 = FFTW_ESTIMATE, 2 = FFTW_MEASURE, 3 = FFTW_PATIENT, 4 = FFTW_EXHAUSTIVE
! **************************************************************************************************
   SUBROUTINE fftw3_create_plan_1dm(plan, zin, zout, plan_style)

      IMPLICIT NONE

      TYPE(fft_plan_type), INTENT(INOUT)              :: plan
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(IN)         :: zin
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(IN)         :: zout
      INTEGER, INTENT(IN)                                :: plan_style
#if defined(__FFTW3)
      INTEGER                                            :: planner_flags, team_size, rows_per_thread
      INTEGER                                            :: in_stride, in_dist, out_stride, out_dist
      INTEGER(KIND=C_INT)                                :: direction

      ! Map the CP2K plan style onto the FFTW planner flag.
      SELECT CASE (plan_style)
      CASE (1)
         planner_flags = FFTW_ESTIMATE
      CASE (2)
         planner_flags = FFTW_MEASURE
      CASE (3)
         planner_flags = FFTW_PATIENT
      CASE (4)
         planner_flags = FFTW_EXHAUSTIVE
      CASE DEFAULT
         CPABORT("fftw3_create_plan_1dm")
      END SELECT

      plan%separated_plans = .FALSE.

      ! Team size of a fresh parallel region (stays 1 without OpenMP).
      team_size = 1
!$OMP PARALLEL DEFAULT(NONE), SHARED(team_size)
!$OMP MASTER
!$    team_size = omp_get_num_threads()
!$OMP END MASTER
!$OMP END PARALLEL

      rows_per_thread = plan%m/team_size
!$    plan%num_threads_needed = team_size
      ! Fewer rows than threads: each working thread takes a single row.
!$    IF (plan%m < team_size) THEN
!$       rows_per_thread = 1
!$       plan%num_threads_needed = plan%m
!$    END IF
      ! Leftover rows are handled by an extra plan for the last working thread.
!$    IF (rows_per_thread*plan%num_threads_needed /= plan%m) plan%need_alt_plan = .TRUE.
!$    plan%num_rows = rows_per_thread

      ! Default layout: each row is contiguous and rows follow one another. A transposed layout
      ! instead strides the input (fsign == +1) or the output (fsign == -1) by plan%m.
      in_stride = 1
      in_dist = plan%n
      out_stride = 1
      out_dist = plan%n
      IF (plan%trans) THEN
         IF (plan%fsign == +1) THEN
            in_stride = plan%m
            in_dist = 1
         ELSE IF (plan%fsign == -1) THEN
            out_stride = plan%m
            out_dist = 1
         END IF
      END IF

      IF (plan%fsign == +1) THEN
         direction = FFTW_FORWARD
      ELSE
         direction = FFTW_BACKWARD
      END IF

      CALL dfftw_plan_many_dft(plan%fftw_plan, 1, plan%n, rows_per_thread, zin, 0, in_stride, in_dist, &
                               zout, 0, out_stride, out_dist, direction, planner_flags)

!$    IF (plan%need_alt_plan) THEN
!$       plan%alt_num_rows = plan%m - (plan%num_threads_needed - 1)*rows_per_thread
!$       CALL dfftw_plan_many_dft(plan%alt_fftw_plan, 1, plan%n, plan%alt_num_rows, zin, 0, in_stride, in_dist, &
!$                                zout, 0, out_stride, out_dist, direction, planner_flags)
!$    END IF

#else
      MARK_USED(plan)
      MARK_USED(plan_style)
      !MARK_USED does not work with assumed size arguments
      IF (.FALSE.) THEN; DO; IF (ABS(zin(1)) > ABS(zout(1))) EXIT; END DO; END IF
#endif

   END SUBROUTINE fftw3_create_plan_1dm

! **************************************************************************************************
!> \brief Destroy the FFTW plan handle(s) held by plan.
!> \param plan plan whose FFTW plans are released; the handles themselves are not reset
! **************************************************************************************************
   SUBROUTINE fftw3_destroy_plan(plan)

      TYPE(fft_plan_type), INTENT(INOUT)   :: plan

#if defined(__FFTW3)
      ! Extra plan for the last thread's leftover rows (see fftw3_create_plan_1dm).
!$    IF (plan%need_alt_plan) CALL fftw_destroy_plan(plan%alt_fftw_plan)

      IF (plan%separated_plans) THEN
         ! Separated 3D plans hold one plan per dimension plus an "_r" variant used by the
         ! threads that take the alternative number of rows.
         CALL fftw_destroy_plan(plan%fftw_plan_nx)
         CALL fftw_destroy_plan(plan%fftw_plan_ny)
         CALL fftw_destroy_plan(plan%fftw_plan_nz)
         CALL fftw_destroy_plan(plan%fftw_plan_nx_r)
         CALL fftw_destroy_plan(plan%fftw_plan_ny_r)
         CALL fftw_destroy_plan(plan%fftw_plan_nz_r)
      ELSE
         ! Single plan covering the whole transform.
         CALL fftw_destroy_plan(plan%fftw_plan)
      END IF

#else
      MARK_USED(plan)
#endif

   END SUBROUTINE fftw3_destroy_plan

! **************************************************************************************************
!> \brief Execute a batch of plan%m 1D FFTs of length plan%n planned by fftw3_create_plan_1dm.
!>        With OpenMP the rows are divided among the threads; the result is then scaled.
!> \param plan  plan created by fftw3_create_plan_1dm
!> \param zin   input data
!> \param zout  output data; scaled in place by scale
!> \param scale factor applied to the plan%n*plan%m output values (skipped when equal to 1)
!> \param stat  1 when the FFTW3 backend ran, 0 when CP2K was built without __FFTW3
! **************************************************************************************************
   SUBROUTINE fftw31dm(plan, zin, zout, scale, stat)
      TYPE(fft_plan_type), INTENT(IN)                    :: plan
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(IN), TARGET :: zin
      COMPLEX(KIND=dp), DIMENSION(*), INTENT(INOUT), &
         TARGET                                          :: zout
      REAL(KIND=dp), INTENT(IN)                          :: scale
      INTEGER, INTENT(OUT)                               :: stat

#if defined(__FFTW3)
      INTEGER                                            :: in_offset, my_id, num_rows, out_offset, &
                                                            scal_offset
      TYPE(C_PTR)                                        :: fftw_plan

      ! Set once, before the parallel region. This way stat is defined even when no thread
      ! passes the plan%num_threads_needed test (plan%m == 0), and no thread writes it concurrently.
      stat = 1

      my_id = 0
      num_rows = plan%m

!$OMP PARALLEL DEFAULT(NONE), &
!$OMP          PRIVATE(my_id,num_rows,in_offset,out_offset,scal_offset,fftw_plan), &
!$OMP          SHARED(zin,zout), &
!$OMP          SHARED(plan,scale)
!$    my_id = omp_get_thread_num()

!$    IF (my_id < plan%num_threads_needed) THEN

         fftw_plan = plan%fftw_plan

         in_offset = 1
         out_offset = 1
         scal_offset = 1

         ! Each working thread owns plan%num_rows rows. In a transposed layout those rows are
         ! interleaved with stride plan%m, so the offset is a row index rather than a block start.
!$       in_offset = 1 + plan%num_rows*my_id*plan%n
!$       out_offset = 1 + plan%num_rows*my_id*plan%n
!$       IF (plan%fsign == +1 .AND. plan%trans) THEN
!$          in_offset = 1 + plan%num_rows*my_id
!$       ELSEIF (plan%fsign == -1 .AND. plan%trans) THEN
!$          out_offset = 1 + plan%num_rows*my_id
!$       END IF
!$       scal_offset = 1 + plan%n*plan%num_rows*my_id
         ! The last working thread takes the leftover rows with the alternative plan.
!$       IF (plan%need_alt_plan .AND. my_id == plan%num_threads_needed - 1) THEN
!$          num_rows = plan%alt_num_rows
!$          fftw_plan = plan%alt_fftw_plan
!$       ELSE
!$          num_rows = plan%num_rows
!$       END IF

         ! NOTE(review): 1-element sections (not scalar elements) presumably keep the rank of the
         ! actual arguments consistent for these implicit-interface externals; FFTW/BLAS access
         ! memory beyond the single element, so these must remain contiguous sections.
         CALL dfftw_execute_dft(fftw_plan, zin(in_offset:in_offset), zout(out_offset:out_offset))
!$    END IF

      ! With a transposed output the rows of different threads interleave in zout, so all
      ! transforms must finish before any thread scales its contiguous chunk.
      ! All threads need to meet at this barrier.
!$OMP BARRIER

!$    IF (my_id < plan%num_threads_needed) THEN
         IF (scale /= 1.0_dp) CALL zdscal(plan%n*num_rows, scale, zout(scal_offset:scal_offset), 1)
!$    END IF

!$OMP END PARALLEL

#else
      MARK_USED(plan)
      MARK_USED(scale)
      !MARK_USED does not work with assumed size arguments
      IF (.FALSE.) THEN; DO; IF (ABS(zin(1)) > ABS(zout(1))) EXIT; END DO; END IF
      stat = 0
#endif

   END SUBROUTINE fftw31dm

   END MODULE fftw3_lib
