Back to CFM home             Brown University



2D FFT program

      program dfft2d
c
c     Benchmark driver that times several 2-D complex FFT variants
c     and prints the achieved Mflop/s for each one:
c       - essl_2dfft with copy flags 'L' and 'S' (always run)
c       - row_essl_2dfft 'N'/'T' and PESSL PDCFT2 (only if p = 1)
c       - ESSL DCFT / DCFT2 (only when running on one process)
c     The sizes m, n, the process grid p x q, ndims and iter are not
c     set in this file; presumably they are defined as parameters
c     by the include file '2dfftsizes' -- confirm there.
c

      implicit none

      include 'mpif.h'

      double precision mflops, time, flops
      double precision auxr, auxc, aux2
      double complex a, b, c
      double complex in, out
      integer comm_row, comm_col, comm_cart, ierror, nprocs, myproc
      integer iter, m, n, ndims, p, q, mp, nq, mpq, nqp, maxr, maxc
      integer i, j
      integer dims, naux
      logical periods, hor_dims, ver_dims
      integer ctxt, ip(40)
      integer ldb
      real atime(2), dtime_


      include '2dfftsizes'

c     local block sizes derived from the global sizes and the grid
      parameter (mp = m / p, nq = n / q, mpq = mp / q, nqp = nq / p)
      parameter (maxr = max (m, n) + 10, maxc = max (mpq, nqp) + 10)
      parameter (naux = 140000)
      parameter (in = (0.0d0, 0.0d0), out = (1.0d0, 0.0d0))

      dimension a(mp,nq), b(maxr*maxc), c(maxr*maxc)
      dimension auxr(naux), auxc(naux), aux2(naux)
      dimension dims(ndims), periods(ndims)
      dimension hor_dims(ndims), ver_dims(ndims)

      data (dims(i), i=1,2) / p, q /
      data (periods(i), i=1,2) /.false., .false. /
      data (hor_dims(i), i=1,2) /.false., .true. /
      data (ver_dims(i), i=1,2) /.true., .false. /

      call MPI_INIT (ierror)
      call MPI_COMM_SIZE (MPI_COMM_WORLD, nprocs, ierror)
      call MPI_COMM_RANK (MPI_COMM_WORLD, myproc, ierror)

c     build the p x q cartesian grid and its row/column communicators
      call MPI_DIMS_CREATE (nprocs, ndims, dims, ierror)
      call MPI_CART_CREATE (MPI_COMM_WORLD, ndims, dims, periods,
     $     .true., comm_cart, ierror)
      call MPI_CART_SUB (comm_cart, hor_dims, comm_row, ierror)
      call MPI_CART_SUB (comm_cart, ver_dims, comm_col, ierror)

c     nominal operation count for all iter transforms, using
c     5 * m * n * log2 (m * n) flops per 2-D FFT.  It is evaluated
c     entirely in double precision so that the product 5 * m * n
c     cannot overflow a default integer for large grids.
      flops = 5.0d0 * dble (m) * dble (n) * dble (iter)
     $     * (log (dble (n)) + log (dble (m))) / log (2.0d0)

c     initialize a
      call initl (a, mp, mp, nq, in, myproc)

c     initialize ffts
      call init_essl_2dfft (c, m, n, mpq, nqp, auxr, auxc, aux2, naux)

c**************************************************************************
c     essl_2dfft, flag 'L' (reported as unit-stride loads)
c**************************************************************************

      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

c     only process 0 keeps the wall-clock time and reports
      if (myproc .eq. 0) time = MPI_WTIME ()

      do i = 1, iter
         call essl_2dfft (a, b, c, m, n, mp, nq, mpq, nqp, auxc, auxr,
     $        aux2, naux, comm_row, comm_col, p, q, 'L')
      enddo

      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

      if (myproc .eq. 0) then
         time = MPI_WTIME () - time
         mflops = 1.0d-6 * flops / time
         write (6,*) '2dfft achieves ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         write (6,*) 'Using unit-stride loads for copying'
      endif

c     check the result
c$$$      call fft_res (a, mp, mp, nq, out, myproc)
      if (myproc .eq. 0) write (6,*)
     $     '************************************'

c**************************************************************************
c     essl_2dfft, flag 'S' (reported as unit-stride stores)
c**************************************************************************

c     initialize a
      call initl (a, mp, mp, nq, in, myproc)

      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

      if (myproc .eq. 0) time = MPI_WTIME ()

      do i = 1, iter
         call essl_2dfft (a, b, c, m, n, mp, nq, mpq, nqp, auxc, auxr,
     $        aux2, naux, comm_row, comm_col, p, q, 'S')
      enddo

      call MPI_BARRIER (MPI_COMM_WORLD, ierror)

      if (myproc .eq. 0) then
         time = MPI_WTIME () - time
         mflops = 1.0d-6 * flops / time
         write (6,*) '2dfft achieves ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         write (6,*) 'Using unit-stride stores for copying'
      endif
c     check the result
c$$$      call fft_res (a, mp, mp, nq, out, myproc)
      if (myproc .eq. 0) write (6,*)
     $     '************************************'

c**************************************************************************
c**************************************************************************
c     Row Algorithms
c**************************************************************************
c**************************************************************************

c     the row algorithms and PDCFT2 are only exercised when p = 1
      if (p .eq. 1) then

c     initialize ffts
         call init_row_essl_2dfft (b, m, n, mpq, nqp, auxr, auxc,
     $        aux2, naux, 'N')

c     initialize a
         call initl (a, mp, mp, nq, in, myproc)

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) time = MPI_WTIME ()

         do i = 1, iter
            call row_essl_2dfft (a, b, c, m, n, mp, nq, mpq, nqp, auxc,
     $           auxr, aux2, naux, comm_row, comm_col, p, q, 'N')
         enddo

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) then
            time = MPI_WTIME () - time
            mflops = 1.0d-6 * flops / time
            write (6,*) 'row_2dfft achieves ', mflops, ' Mflop/s'
            write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         endif
c     check the result
c$$$         call fft_res (a, mp, mp, nq, out, myproc)
         if (myproc .eq. 0) write (6,*)
     $        '************************************'

c     initialize ffts
         call init_row_essl_2dfft (b, m, n, mpq, nqp, auxr, auxc,
     $        aux2, naux, 'T')

c     initialize a
         call initl (a, mp, mp, nq, in, myproc)

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) time = MPI_WTIME ()

         do i = 1, iter
            call row_essl_2dfft (a, b, c, m, n, mp, nq, mpq, nqp, auxc,
     $           auxr, aux2, naux, comm_row, comm_col, p, q, 'T')
         enddo

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) then
            time = MPI_WTIME () - time
            mflops = 1.0d-6 * flops / time
            write (6,*) '"transpose" row_2dfft achieves ', mflops,
     $           ' Mflop/s'
            write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         endif
c     check the result
c$$$         call fft_res (b, n, n, mpq, out, myproc)
         if (myproc .eq. 0) write (6,*)
     $        '************************************'


c**************************************************************************
c**************************************************************************
c     PESSL PDCFT2
c**************************************************************************
c**************************************************************************


c     obtain a BLACS context and map it onto the p x q grid
         call BLACS_GET (0, 0 , ctxt)
         call BLACS_GRIDINIT (ctxt, 'Row-major', p, q)

c**************************************************************************
c     in-place Normal PDCFT2
c**************************************************************************
c     initialize a
         call initl (a, mp, mp, nq, in, myproc)

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

c     set the values for the ip array
         ip(1) = 1
         ip(2) = 1
         ip(20) = 0
         ip(21) = 0

         if (myproc .eq. 0) time = MPI_WTIME ()

c     input and output are the same array
         do i = 1, iter
            call PDCFT2 (a, a, m, n, 1, 1.0d0, ctxt, ip)
         enddo

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) then
            time = MPI_WTIME () - time
            mflops = 1.0d-6 * flops / time
            write (6,*) 'in-place pdcft2 achieves ', mflops, ' Mflop/s'
            write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         endif
c     check the result
c$$$         call fft_res (a, mp, mp, nq, out, myproc)
         if (myproc .eq. 0) write (6,*)
     $        '************************************'

c**************************************************************************
c     Normal -> Normal PDCFT2
c**************************************************************************
c     initialize a
         call initl (a, mp, mp, nq, in, myproc)

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

c     set the values for the ip array
         ip(1) = 1
         ip(2) = 1
         ip(20) = 0
         ip(21) = 0

         if (myproc .eq. 0) time = MPI_WTIME ()

         do i = 1, iter
            call PDCFT2 (a, b, m, n, 1, 1.0d0, ctxt, ip)
         enddo

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) then
            time = MPI_WTIME () - time
            mflops = 1.0d-6 * flops / time
            write (6,*) 'pdcft2 achieves ', mflops, ' Mflop/s'
            write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         endif
c     check the result
c$$$         call fft_res (b, mp, mp, nq, out, myproc)
         if (myproc .eq. 0) write (6,*)
     $        '************************************'

c**************************************************************************
c     Normal -> Transposed PDCFT2
c**************************************************************************
c     initialize a
         call initl (a, mp, mp, nq, in, myproc)

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

c     default values for the ip array
         ip(1) = 0

         if (myproc .eq. 0) time = MPI_WTIME ()

         do i = 1, iter
            call PDCFT2 (a, b, m, n, 1, 1.0d0, ctxt, ip)
         enddo

         call MPI_BARRIER (MPI_COMM_WORLD, ierror)

         if (myproc .eq. 0) then
            time = MPI_WTIME () - time
            mflops = 1.0d-6 * flops / time
            write (6,*) '"transpose" pdcft2 achieves ', mflops,
     $           ' Mflop/s'
            write (6,*) '(', iter, ' iterations in ', time, ' secs)'
         endif
c     check the result
c$$$         call fft_res (b, nq, nq, mp, out, myproc)
         if (myproc .eq. 0) write (6,*)
     $        '************************************'

c     release the grid, then leave BLACS with a nonzero argument:
c     this program still calls MPI_FINALIZE itself at the end, so
c     BLACS must not shut down the message-passing layer here.
         call BLACS_GRIDEXIT (ctxt)
         call BLACS_EXIT (1)
      endif

c**************************************************************************
c**************************************************************************
c     Single CPU 2-D FFTs
c**************************************************************************
c**************************************************************************

c     serial ESSL variants, run only with a single MPI process.
c     In each test the call whose first argument is 1 is the set-up
c     call and the calls whose first argument is 0 are the timed
c     transforms.  dtime_ is called right before and right after
c     each loop and the second return value is used as the loop
c     time (assumes dtime_ returns the time elapsed since its
c     previous call -- confirm in the compiler documentation).
      if (nprocs .eq. 1) then

c**************************************************************************
c     in-place Dual DCFTs
c**************************************************************************
c     initialize a
         call initl (a, m, m, n, in, myproc)
c     initialize DCFT2
         call DCFT (1, a, 1, m, a, 1, m, m, n, 1, 1.0d0, auxc, naux,
     $        aux2, naux)
         call DCFT (1, a, m, 1, a, m, 1, n, m, 1, 1.0d0, auxr, naux,
     $        aux2, naux)
c     timing loop
         time = dtime_(atime)
         do i = 1, iter
            call DCFT (0, a, 1, m, a, 1, m, m, n, 1, 1.0d0, auxc,
     $           naux, aux2, naux)
            call DCFT (0, a, m, 1, a, m, 1, n, m, 1, 1.0d0, auxr,
     $           naux, aux2, naux)
         enddo
         time = dtime_(atime)
c     performance
         mflops = 1.0d-6 * flops / time
         write (6,*) 'in-place dual dcft achieve ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
c     check the result
c$$$         call fft_res (a, m, m, n, out, myproc)
         write (6,*) '************************************'

c**************************************************************************
c     Normal Dual DCFTs
c**************************************************************************
c     initialize a
         call initl (a, m, m, n, in, myproc)
c     Stride setting
         call stride (n, m, ldb, 'Z', 0)
         write (6,*) 'The leading dimension is ', ldb, ' for  m = ', m
c     initialize DCFT2
         call DCFT (1, a, 1, m, a, 1, m, m, n, 1, 1.0d0, auxc, naux,
     $        aux2, naux)
         call DCFT (1, a, m, 1, b, ldb, 1, n, m, 1, 1.0d0, auxr, naux,
     $        aux2, naux)
c     timing loop
         time = dtime_(atime)
         do i = 1, iter
            call DCFT (0, a, 1, m, a, 1, m, m, n, 1, 1.0d0, auxc,
     $           naux, aux2, naux)
            call DCFT (0, a, m, 1, b, ldb, 1, n, m, 1, 1.0d0, auxr,
     $           naux, aux2, naux)
         enddo
         time = dtime_(atime)
c     performance
         mflops = 1.0d-6 * flops / time
         write (6,*) 'dual dcft achieve ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
c     check the result
c$$$         call fft_res (b, ldb, m, n, out, myproc)
         write (6,*) '************************************'

c**************************************************************************
c     in-place Normal DCFT2
c**************************************************************************
c     initialize a
         call initl (a, m, m, n, in, myproc)
c     initialize DCFT2
         call DCFT2 (1, a, 1, m, a, 1, m, m, n, 1, 1.0d0, auxc, naux,
     $        aux2, naux)
c     timing loop
         time = dtime_(atime)
         do i = 1, iter
            call DCFT2 (0, a, 1, m, a, 1, m, m, n, 1, 1.0d0, auxc,
     $           naux, aux2, naux)
         enddo
         time = dtime_(atime)
c     performance
         mflops = 1.0d-6 * flops / time
         write (6,*) 'in-place dcft2 achieves ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
c     check the result
c$$$         call fft_res (a, m, m, n, out, myproc)
         write (6,*) '************************************'

c**************************************************************************
c     Normal -> Normal DCFT2
c**************************************************************************
c     initialize a
         call initl (a, m, m, n, in, myproc)
c     stride setting
         call stride (n, m, ldb, 'Z', 0)
         write (6,*) 'The leading dimension is ', ldb, ' for  m = ', m
c     initialize DCFT2
         call DCFT2 (1, a, 1, m, b, 1, ldb, m, n, 1, 1.0d0, auxc, naux,
     $        aux2, naux)
c     timing loop
         time = dtime_(atime)
         do i = 1, iter
            call DCFT2 (0, a, 1, m, b, 1, ldb, m, n, 1, 1.0d0, auxc,
     $           naux, aux2, naux)
         enddo
         time = dtime_(atime)
c     performance
         mflops = 1.0d-6 * flops / time
         write (6,*) 'dcft2 achieves ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
c     check the result
c$$$         call fft_res (b, ldb, m, n, out, myproc)
         write (6,*) '************************************'

c**************************************************************************
c     Normal -> Transposed DCFT2
c**************************************************************************
c     a is not re-initialized here: the previous test wrote its
c     result to b, not to a
c     stride setting
         call stride (m, n, ldb, 'Z', 0)
         write (6,*) 'The leading dimension is ', ldb, ' for  n = ', n
c     initialize DCFT2
         call DCFT2 (1, a, 1, m, b, ldb, 1, m, n, 1, 1.0d0, auxc, naux,
     $        aux2, naux)
c     timing loop
         time = dtime_(atime)
         do i = 1, iter
            call DCFT2 (0, a, 1, m, b, ldb, 1, m, n, 1, 1.0d0, auxc,
     $           naux, aux2, naux)
         enddo
         time = dtime_(atime)
c     performance
         mflops = 1.0d-6 * flops / time
         write (6,*) '"transpose" dcft2 achieves ', mflops, ' Mflop/s'
         write (6,*) '(', iter, ' iterations in ', time, ' secs)'
c     check the result
c$$$         call fft_res (b, ldb, n, m, out, myproc)
         write (6,*) '************************************'

      endif

      call MPI_FINALIZE (ierror)
      end