c     Back to CFM home             Brown University



c     Matrix-vector multiply routines

c**************************************************************************
c     Checking the result
c     This assumes a particular initialization of the array A
c     in mxv.f
c
c     Verifies the mp entries of c held by process prank.  Entry i is
c     expected to equal
c         n*(n+1)/2 + (prank*mp + i - 1)*n
c     i.e. sum_{j=1..n} (g - 1 + j) with g = prank*mp + i.
c     NOTE(review): presumably this matches a(g,j) = g + j - 1 and
c     b = 1 in mxv.f -- confirm against that file.
c
c     Arguments:
c       c     - computed result slice, length mp
c       mp    - number of entries of c checked on this process
c       n     - length of the summation (matrix columns)
c       prank - process rank; offset prank*mp assumes 0-based ranks
c
c     On the first mismatch larger than 1d-15 the offending entry is
c     printed and the program stops with stop code 1 (nonzero, so
c     launchers and test harnesses see the failure).

      subroutine check_mxv (c, mp, n, prank)

      implicit none

      integer mp, n, prank
      double precision c(*)
      integer i
      double precision difr, correct, dn

c     Evaluate the reference in double precision: the integer form
c     n*(n+1) overflows a default 32-bit integer once n > 46340.
c     For integer-valued results below 2**53 this is still exact.
      dn = dble(n)
      do i = 1, mp
         correct = dn * (dn + 1d0) / 2d0
     $        + (dble(prank) * dble(mp) + dble(i - 1)) * dn
         difr = c(i) - correct
         if (abs(difr) .gt. 1d-15) then
c           Close the literal before the continuation line; a string
c           split across a fixed-form continuation absorbs the blank
c           padding up to column 72 into the message.
            write (6,*) 'c(', i, ') = ', c(i), ' instead of ',
     $           correct, ' - difference is ', difr
            write (6,*) 'Error in mxv! Exiting!'
            stop 1
         endif
      enddo
      return
      end

c**************************************************************************
c     Error checking routine
c     call as check_error (MPI_COMM_WORLD, ierror, 'MPI function name')
c     useless as MPICH handles errors internally
c     NOTE(review): presumably because the default MPI error handler
c     aborts before a failing call returns; this check can only fire
c     if the communicator's handler was changed to return errors.
c
c     If ierror is nonzero: prints the error number and the name of the
c     failing routine, then calls MPI_ABORT on comm with ierror as the
c     error code.  The trailing stop is kept as a fallback in case
c     MPI_ABORT returns.  If ierror is zero, returns immediately.

      subroutine check_error (comm, ierror, funct)

      implicit none

      include 'mpif.h'

      integer comm, ierror, lierror
c     Assumed length: callers pass literals of arbitrary length (the
c     usage example above is 17 characters).  A fixed character*20
c     dummy would be longer than such actuals, which the standard
c     does not permit and which reads past the actual argument.
      character*(*) funct

      if (ierror .ne. 0) then
         write (6,*) 'Error number ', ierror, ' in ', funct
         write (6,*) 'Exiting'
         call MPI_ABORT (comm, ierror, lierror)
         stop
      endif
      return
      end


c**************************************************************************
c     Fortran loop based matrix-vector product
c
c     Computes the local product lc = a * b(1:n) with hand-unrolled
c     loops, then sums lc over all processes of comm with
c     MPI_ALLREDUCE (MPI_SUM), leaving the summed vector in gc on
c     every process.
c       m, n  - rows / columns of the local block a (leading dim m)
c       a     - local m x n block
c       b     - vector segment of length n
c       lc    - output: local partial product, length m
c       gc    - output: lc summed over comm, length m
c       comm  - communicator for the reduction (per the comment
c               below, a processor-row communicator)
c     Assumes n >= 1: column 1 and b(1) are read unconditionally.
c     The MPI error code in ierror is not checked.


      subroutine mxv (m, n, a, b, lc, gc, comm) 

      implicit none

      include 'mpif.h'
      
      double precision a(m,*), b(*), lc(*), gc(*)
      integer m, n, comm, i, j, ierror

c     local matrix-vector product (doubly unrolled loops)
c     Rows are processed in pairs (i, i+1).  The upper bound m - 1
c     keeps i + 1 <= m; with the old bound m, an odd m read a(m+1,j)
c     and wrote lc(m+1), out of bounds.  The odd last row is handled
c     by the cleanup loop.
c     startup loop: column 1 initializes lc
      do i = 1, m - 1, 2
         lc(i) = a(i, 1) * b(1)
         lc(i + 1) = a(i + 1, 1) * b(1)
      enddo
c     main loop: accumulate columns 2..n (inner loop runs down a
c     column, matching Fortran's column-major storage)
      do j = 2, n
         do i = 1, m - 1, 2
            lc(i) = lc(i) + a(i, j) * b(j)
            lc(i + 1) = lc(i + 1) + a(i + 1, j) * b(j)
         enddo
      enddo
c     cleanup loop: last row when m is odd
      if (mod (m, 2) .eq. 1) then
         lc(m) = a(m, 1) * b(1)
         do j = 2, n
            lc(m) = lc(m) + a(m, j) * b(j)
         enddo
      endif
c     Sum local vectors across processor rows to get global vectors
      call MPI_ALLREDUCE (lc, gc, m, MPI_DOUBLE_PRECISION, MPI_SUM,
     $     comm, ierror)

      return 
      end


c**************************************************************************
c     ESSL based matrix-vector product
c
c     Same contract as the loop-based mxv: the local product
c     lc = a * b(1:n) is computed by the BLAS routine dgemv, then
c     summed over comm into gc with MPI_ALLREDUCE (MPI_SUM).
c       m, n  - rows / columns of the local block a (leading dim m)
c       a     - local m x n block
c       b     - vector segment of length n
c       lc    - output: local partial product, length m
c       gc    - output: lc summed over comm, length m
c       comm  - communicator for the reduction
c     The MPI error code in ierror is not checked.


      subroutine blas_mxv (m, n, a, b, lc, gc, comm) 

      implicit none

      include 'mpif.h'

      double precision a(m,*), b(*), lc(*), gc(*)
      integer m, n, comm, ierror

c     local matrix-vector product: lc = 1.0 * a * b + 0.0 * lc
c     DGEMV rejects LDA < max(1, M), so an empty local block (m = 0)
c     would abort in XERBLA; for m = 0 no element of a is referenced,
c     so passing 1 is safe.  For m >= 1 this is just m.
      call dgemv ('N', m, n, 1.0d0, a, max (1, m), b, 1, 0.0d0,
     $     lc, 1)
c     Sum local vectors across processor rows to get global vectors
      call MPI_ALLREDUCE (lc, gc, m, MPI_DOUBLE_PRECISION, MPI_SUM,
     $     comm, ierror)

      return 
      end


c**************************************************************************
c     PESSL based matrix-vector product
c
c     Thin wrapper around the parallel BLAS routine PDGEMV:
c         c := 1.0 * A * b + 0.0 * c
c     for the distributed mglobal x nglobal matrix A.  Every operand
c     starts at global position (1,1) and vectors use increment 1.
c     desca, descb, descc are the array descriptors of A, b and c;
c     desca(8) holds the local leading dimension of a, matching the
c     layout written by descinit in this file.


      subroutine pblas_mxv (mglobal, nglobal, a, desca, b, descb, c,
     $     descc)

      implicit none

      integer mglobal, nglobal
      integer desca(*), descb(*), descc(*)
      double precision a(desca(8), *)
      double precision b(*), c(*)
      double precision one, zero
      parameter (one = 1.0d0, zero = 0.0d0)

c     Argument groups: op, sizes, alpha, A(ia,ja,desc),
c     x(ix,jx,desc,incx), beta, y(iy,jy,desc,incy)
      call PDGEMV ('No Transpose',
     $     mglobal, nglobal, one,
     $     a, 1, 1, desca,
     $     b, 1, 1, descb, 1,
     $     zero,
     $     c, 1, 1, descc, 1)
      return
      end

c**************************************************************************
c     SCALAPACK support routine 
c
c     Fills the 8-entry array descriptor used in this file:
c       desc(1) = m      desc(2) = n      desc(3) = mb    desc(4) = nb
c       desc(5) = rsrc   desc(6) = csrc   desc(7) = ctxt  desc(8) = lda
c     Every entry is always assigned.  Out-of-range inputs are clamped
c     to the nearest legal value (as ScaLAPACK's own DESCINIT does)
c     instead of leaving desc(k) undefined:
c       m, n       >= 0   (zero extents are legal: empty matrix)
c       mb, nb     >= 1
c       rsrc, csrc >= 0
c     ctxt and lda are stored unchanged.
c     NOTE(review): this routine shares its name with ScaLAPACK's
c     DESCINIT, which takes an extra INFO argument (and, in current
c     releases, a 9-entry descriptor); linking both will clash --
c     confirm which one the build resolves.

      subroutine descinit (desc, m, n, mb, nb, rsrc, csrc, ctxt, lda)
      
      implicit none

      integer desc(8), m, n, mb, nb, rsrc, csrc, ctxt, lda

c     minimal error checking: clamp instead of skipping the store, so
c     no descriptor entry is ever left with a stale or garbage value
      desc(1) = max (0, m)
      desc(2) = max (0, n)
      desc(3) = max (1, mb)
      desc(4) = max (1, nb)
      desc(5) = max (0, rsrc)
      desc(6) = max (0, csrc)

      desc(7) = ctxt
      desc(8) = lda
      return
      end