c     Brown University, CFM
c
c     2D FFT routines

c**************************************************************************
c     Regroup equal-sized blocks between the exchange layout (b) and
c     the transform layout (c).
c
c     In b there are dir_procs groups, one after another; group i holds
c     loc_trans consecutive blocks of blocksize elements.  In c there
c     are loc_trans consecutive segments of length
c     dir_procs*blocksize; segment j is made of block j of every group,
c     in group order.
c
c     dir = 'F' copies b -> c; any other value copies c -> b.  dir is a
c     length-1 dummy, so only the first character of the actual
c     argument is seen ('Forward' / 'Backward' are accepted).

      subroutine rearrange (b, c, dir_procs, dir, blocksize, loc_trans)

      implicit none

      double complex b(*), c(*)
      integer blocksize, dir_procs, loc_trans
      character dir
      integer iproc, iblk, boff, coff

      do iproc = 0, dir_procs - 1
         do iblk = 0, loc_trans - 1
c           1-based start of block iblk of group iproc in b and in c
            boff = (iproc * loc_trans + iblk) * blocksize + 1
            coff = (iblk * dir_procs + iproc) * blocksize + 1
            if (dir .eq. 'F') then
               call zcopy (blocksize, b(boff), 1, c(coff), 1)
            else
               call zcopy (blocksize, c(coff), 1, b(boff), 1)
            endif
         enddo
      enddo

      return
      end

c**************************************************************************
c     Block packing for the sends (before the FFT) in the row-based
c     algorithm.
c
c     The rows of a(mp,nq) are taken as q bands of mpq rows.  Band k,
c     column by column, is copied to the k-th chunk of mpq*nq
c     consecutive elements of b.
c     NOTE(review): each band is copied with zcopy from a(first,col),
c     so q*mpq <= mp (presumably q*mpq = mp) is expected -- confirm.

      subroutine bpack (a, b, mp, nq, q, mpq)

      implicit none

      integer mp, nq, q, mpq
      double complex a(mp,nq), b(*)
      integer band, col, first, pos

c     pos walks through b: one mpq-long piece per (band, column) pair
      pos = 1
      do band = 1, q
         first = (band - 1) * mpq + 1
         do col = 1, nq
            call zcopy (mpq, a(first, col), 1, b(pos), 1)
            pos = pos + mpq
         enddo
      enddo

      return
      end

c**************************************************************************
c     Block unpacking for the receives (after the FFT) in the row-based
c     algorithm; the inverse of bpack.
c
c     The k-th chunk of mpq*nq consecutive elements of b is copied,
c     column by column, into rows (k-1)*mpq+1 .. k*mpq of a(mp,nq).
c     NOTE(review): as in bpack, q*mpq <= mp is expected -- confirm.

      subroutine bunpack (b, a, mp, nq, q, mpq)

      implicit none

      integer mp, nq, q, mpq
      double complex a(mp,nq), b(*)
      integer band, col, first, pos

c     pos walks through b: one mpq-long piece per (band, column) pair
      pos = 1
      do band = 1, q
         first = (band - 1) * mpq + 1
         do col = 1, nq
            call zcopy (mpq, b(pos), 1, a(first, col), 1)
            pos = pos + mpq
         enddo
      enddo

      return
      end


c**************************************************************************
c     Packing for the block-based algorithm: a(mp,nq) is copied into b
c     in row-major order, b((i-1)*nq + j) = a(i,j).
c
c     unit_stride only chooses the loop order; the result is the same.
c     'L'  : a is read with unit stride, b is written with stride nq.
c     other: b is written with unit stride, a is read with stride mp.

      subroutine rpack (a, b, mp, nq, unit_stride)

      implicit none

      integer mp, nq
      double complex a(mp,nq), b(*)
      character unit_stride
      integer row, col, pos

      if (unit_stride .ne. 'L') then
c        row-major sweep: b is filled sequentially
         pos = 0
         do row = 1, mp
            do col = 1, nq
               pos = pos + 1
               b(pos) = a(row, col)
            enddo
         enddo
      else
c        column-major sweep: a is read sequentially
         do col = 1, nq
            pos = col
            do row = 1, mp
               b(pos) = a(row, col)
               pos = pos + nq
            enddo
         enddo
      endif

      return
      end

c**************************************************************************
c     Unpacking for the block-based algorithm; the inverse of rpack:
c     a(i,j) = b((i-1)*nq + j).
c
c     unit_stride only chooses the loop order; the result is the same.
c     'L'  : b is read with unit stride, a is written with stride mp.
c     other: a is written with unit stride, b is read with stride nq.
c     NOTE(review): in rpack 'L' gives a the unit stride, here it gives
c     b the unit stride -- confirm this asymmetry is intended.

      subroutine runpack (b, a, mp, nq, unit_stride)

      implicit none

      integer mp, nq
      double complex a(mp,nq), b(*)
      character unit_stride
      integer row, col, pos

      if (unit_stride .eq. 'L') then
c        row-major sweep: b is read sequentially
         pos = 0
         do row = 1, mp
            do col = 1, nq
               pos = pos + 1
               a(row, col) = b(pos)
            enddo
         enddo
      else
c        column-major sweep: a is written sequentially
         do col = 1, nq
            pos = col
            do row = 1, mp
               a(row, col) = b(pos)
               pos = pos + nq
            enddo
         enddo
      endif

      return
      end

c**************************************************************************
c     Error checking routine
c     call as check_error (MPI_COMM_WORLD, ierror, 'MPI function name')
c
c     If ierror is not MPI_SUCCESS, the error code and the routine name
c     are written to unit 6 and MPI_ABORT is called on comm with ierror
c     as the error code; the STOP is only reached if MPI_ABORT returns.
c
c     funct has assumed length, so a name of any length (such as the
c     17-character example above) may be passed.
c
c     NOTE: with MPI's default error handler (MPI_ERRORS_ARE_FATAL) a
c     failing call normally aborts before returning, so this check is
c     rarely triggered.

      subroutine check_error (comm, ierror, funct)

      implicit none

      include 'mpif.h'

      integer comm, ierror, lierror
      character*(*) funct

      if (ierror .ne. MPI_SUCCESS) then
         write (6,*) 'Error number ', ierror, ' in ', funct
         write (6,*) 'Exiting'
         call MPI_ABORT (comm, ierror, lierror)
         stop
      endif

      return
      end

c**************************************************************************
c     Array initialization routine
c
c     Sets the leading mp x nq block of a (stored column by column with
c     leading dimension lda) to value.  Entries in rows mp+1 .. lda of
c     each column are left untouched.  On myproc = 0, a(1) is then
c     overwritten with 1 (this is done even when mp or nq is 0).
c     NOTE(review): presumably this builds an impulse input for the FFT
c     test -- confirm with the driver.

      subroutine initl (a, lda, mp, nq, value, myproc)

      implicit none

c     arguments
      double complex a(*), value
      integer lda, mp, nq, myproc
c     local
      integer col, row, base

      do col = 1, nq
c        offset of column col in the flat storage
         base = (col - 1) * lda
         do row = 1, mp
            a(base + row) = value
         enddo
      enddo

      if (myproc .eq. 0) then
         a(1) = (1.0d0, 0.0d0)
      endif

      return
      end

c**************************************************************************
c     Checking of results
c
c     Compares the leading mp x nq block of a (stored column by column
c     with leading dimension lda) against value.  Every element that
c     does not match is written to unit 6 together with its local
c     indices and myproc.
c
c     An element matches when |a - value| <= tol * max(1, |value|).  A
c     tolerance is used because the data are the output of floating
c     point FFTs.  The test is written as .not.(... .le. ...) so that
c     NaN entries (for which every comparison is false) are reported.

      subroutine fft_res (a, lda, mp, nq, value, myproc)

      implicit none

c     arguments
      double complex a(*), value
      integer lda, mp, nq, myproc
c     local
      double precision tol
      parameter (tol = 1.0d-10)
      integer i, j, k
      double precision lim

      lim = tol * max(1.0d0, abs(value))
      do j = 1, nq
         do i = 1, mp
            k = (j - 1) * lda + i
            if (.not. (abs(a(k) - value) .le. lim)) write (6,*)
     $           'In processor ',myproc, ' element(', i, ',', j, ') = ',
     $           a(k)
         enddo
      enddo

      return
      end

c**************************************************************************
c     Initialization for the block-based algorithm.
c
c     Calls ESSL DCFT in initialization mode (init = 1) twice, with
c     isign = 1, scale = 1.0 and naux as the size of every work array:
c       - auxc for nqp sequences of length m (stride 1, distance m),
c       - auxr for mpq sequences of length n (stride 1, distance n).
c     These are the same shapes as the computing DCFT calls made in
c     essl_2dfft.

      subroutine init_essl_2dfft (b, m, n, mpq, nqp, auxr, auxc, aux2,
     $     naux)

      implicit none

      double complex b(*)
      double precision auxr(*), auxc(*), aux2(*)
      integer m, n, mpq, nqp, naux
      integer iinit, isgn
      double precision fscale
      parameter (iinit = 1, isgn = 1, fscale = 1.0d0)

c     nqp transforms of length m -> auxc
      call DCFT(iinit, b, 1, m, b, 1, m, m, nqp, isgn, fscale,
     $     auxc, naux, aux2, naux)
c     mpq transforms of length n -> auxr
      call DCFT(iinit, b, 1, n, b, 1, n, n, mpq, isgn, fscale,
     $     auxr, naux, aux2, naux)

      return
      end

c**************************************************************************
c     Initialization for the row-based algorithm.
c
c     Calls ESSL DCFT in initialization mode (init = 1), with
c     isign = 1, scale = 1.0 and naux as the size of every work array:
c       - auxc for nqp sequences of length m (stride 1, distance m),
c       - auxr for mpq sequences of length n read with stride mpq and
c         distance 1.  The output layout depends on transpose:
c           'N'  : output strides mpq / 1 (same layout as the input),
c           other: output strides 1 / n (contiguous sequences).
c     These match the computing DCFT calls in row_essl_2dfft.

      subroutine init_row_essl_2dfft (b, m, n, mpq, nqp, auxr, auxc,
     $     aux2, naux, transpose)

      implicit none

      double complex b(*)
      double precision auxr(*), auxc(*), aux2(*)
      integer m, n, mpq, nqp, naux
      character transpose
      integer iinit, isgn, incy1, incy2
      double precision fscale
      parameter (iinit = 1, isgn = 1, fscale = 1.0d0)

c     nqp transforms of length m -> auxc
      call DCFT(iinit, b, 1, m, b, 1, m, m, nqp, isgn, fscale,
     $     auxc, naux, aux2, naux)

c     choose the output strides of the length-n transforms
      if (transpose .eq. 'N') then
         incy1 = mpq
         incy2 = 1
      else
         incy1 = 1
         incy2 = n
      endif

c     mpq transforms of length n -> auxr
      call DCFT(iinit, b, mpq, 1, b, incy1, incy2, n, mpq, isgn,
     $     fscale, auxr, naux, aux2, naux)

      return
      end

c**************************************************************************
c     block-based FFT routine
c
c     2D FFT of the local block a(mp,nq), computed with ESSL DCFT
c     (isign = 1, scale = 1.0) and written back into a.  The aux arrays
c     must have been prepared by init_essl_2dfft with the same sizes.
c     b and c are work buffers and are overwritten.
c
c     Phase 1, length-m transforms:
c       p > 1: MPI_ALLTOALL over comm_col (mp*nqp elements per rank),
c              rearrange into nqp consecutive length-m sequences in c,
c              transform, rearrange back and return with MPI_ALLTOALL.
c       p = 1: transform a directly.
c     Phase 2, length-n transforms:
c       rpack a row-major into b, MPI_ALLTOALL over comm_row
c       (nq*mpq elements per rank), rearrange into mpq consecutive
c       length-n sequences in b, transform, rearrange back, return with
c       MPI_ALLTOALL and unpack into a with runpack.  unit_stride is
c       passed through to rpack/runpack.
c
c     NOTE(review): the layout presumably requires m = p*mp,
c     n = q*nq, nq = p*nqp and mp = q*mpq -- confirm with callers.
c     NOTE(review): ierror from the MPI calls is not checked; with the
c     default MPI error handler a failing call aborts the job.

      subroutine essl_2dfft (a, b, c, m, n, mp, nq, mpq, nqp, auxc,
     $     auxr, aux2, naux, comm_row, comm_col, p, q, unit_stride)

      implicit none

      include 'mpif.h'

      integer m, n, mp, nq, mpq, nqp, naux
      integer comm_row, comm_col, p, q
      double complex a(mp,nq), b(*), c(*)
      double precision auxr(*), auxc(*), aux2(*)
      character unit_stride
      integer ierror, exsize

c     ---- phase 1: length-m transforms -------------------------------
      if (p .ne. 1) then
         exsize = mp * nqp
         call MPI_ALLTOALL(a, exsize, MPI_DOUBLE_COMPLEX, b, exsize,
     $        MPI_DOUBLE_COMPLEX, comm_col, ierror)

c        b -> c: nqp consecutive sequences of length m
         call rearrange (b, c, p, 'Forward', mp, nqp)

         call DCFT(0, c, 1, m, c, 1, m, m, nqp, 1, 1.0d0, auxc, naux,
     $        aux2, naux)

c        c -> b: back to one block per rank
         call rearrange (b, c, p, 'Backward', mp, nqp)

         call MPI_ALLTOALL(b, exsize, MPI_DOUBLE_COMPLEX, a, exsize,
     $        MPI_DOUBLE_COMPLEX, comm_col, ierror)
      else
c        single rank in comm_col: transform the local columns in place
         call DCFT(0, a, 1, m, a, 1, m, m, nqp, 1, 1.0d0, auxc, naux,
     $        aux2, naux)
      endif

c      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

c     ---- phase 2: length-n transforms -------------------------------
      exsize = nq * mpq

      call rpack (a, b, mp, nq, unit_stride)

      call MPI_ALLTOALL(b, exsize, MPI_DOUBLE_COMPLEX, c, exsize,
     $     MPI_DOUBLE_COMPLEX, comm_row, ierror)

c     c -> b: mpq consecutive sequences of length n
      call rearrange (c, b, q, 'Forward', nq, mpq)

      call DCFT(0, b, 1, n, b, 1, n, n, mpq, 1, 1.0d0, auxr, naux,
     $     aux2, naux)

c     b -> c: back to one block per rank
      call rearrange (c, b, q, 'Backward', nq, mpq)

      call MPI_ALLTOALL(c, exsize, MPI_DOUBLE_COMPLEX, b, exsize,
     $     MPI_DOUBLE_COMPLEX, comm_row, ierror)

      call runpack (b, a, mp, nq, unit_stride)

c      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

      return
      end

c**************************************************************************
c     row-based FFT routine
c
c     2D FFT of the local block a(mp,nq) with ESSL DCFT (isign = 1,
c     scale = 1.0).  The aux arrays must have been prepared by
c     init_row_essl_2dfft with transpose equal to y_orient.  b and c
c     are work buffers and are overwritten.
c
c     1. nqp length-m transforms of a in place (stride 1, distance m).
c     2. bpack splits the rows of a into q bands; MPI_ALLTOALL over
c        comm_row (nq*mpq elements per rank) leaves in c mpq sequences
c        of length n with stride mpq.
c     3. Length-n transforms of c:
c        y_orient = 'N': in place in c; the data are then sent back
c                        with MPI_ALLTOALL and unpacked into a.
c        otherwise     : written to b as mpq consecutive length-n
c                        sequences; nothing is sent back, so a only
c                        holds the result of step 1.
c
c     NOTE(review): step 1 transforms nqp columns and uses distance m
c     although a is declared (mp,nq); presumably mp = m and nq = nqp
c     in the row-based layout -- confirm.  comm_col and p are unused.
c     NOTE(review): ierror from the MPI calls is not checked.

      subroutine row_essl_2dfft (a, b, c, m, n, mp, nq, mpq, nqp, auxc,
     $     auxr, aux2, naux, comm_row, comm_col, p, q, y_orient)

      implicit none

      include 'mpif.h'

      integer m, n, mp, nq, mpq, nqp, naux
      integer comm_row, comm_col, p, q
      double complex a(mp,nq), b(*), c(*)
      double precision auxr(*), auxc(*), aux2(*)
      character y_orient
      integer ierror, exsize

c     step 1: length-m transforms of the local columns
      call DCFT(0, a, 1, m, a, 1, m, m, nqp, 1, 1.0d0, auxc, naux,
     $     aux2, naux)

c     step 2: distribute row bands across comm_row
      exsize = nq * mpq

      call bpack (a, b, mp, nq, q, mpq)

      call MPI_ALLTOALL(b, exsize, MPI_DOUBLE_COMPLEX, c, exsize,
     $     MPI_DOUBLE_COMPLEX, comm_row, ierror)

c     step 3: length-n transforms
      if (y_orient .eq. 'N') then
         call DCFT(0, c, mpq, 1, c, mpq, 1, n, mpq, 1, 1.0d0, auxr,
     $        naux, aux2, naux)

         call MPI_ALLTOALL(c, exsize, MPI_DOUBLE_COMPLEX, b, exsize,
     $        MPI_DOUBLE_COMPLEX, comm_row, ierror)

         call bunpack (b, a, mp, nq, q, mpq)
      else
c        result stays in b, stored as consecutive length-n sequences
         call DCFT(0, c, mpq, 1, b, 1, n, n, mpq, 1, 1.0d0, auxr,
     $        naux, aux2, naux)
      endif

c      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

      return
      end