module vector_type
       implicit none

       ! Real vector of logical length m; its elements live in vData(1:m).
       ! vData is a pointer component, so intrinsic assignment (a = b) copies
       ! the pointer rather than the elements. Nothing in this module ever
       ! deallocates vData: every vector returned by the operators below owns
       ! fresh storage, and callers that want it back must deallocate(v%vData).
       type vector
          integer :: m = 0
          real, dimension(:), pointer :: vData => null()
       end type vector

       ! v1 * v2 -> real dot product; v * c and c * v -> scaled vector.
       interface operator(*)
          module procedure multiplyv
          module procedure multiplycv
          module procedure multiplyvc
       end interface

       ! v / c -> element-wise division by a scalar.
       interface operator(/)
          module procedure dividev
       end interface

       interface operator(+)
          module procedure addv
       end interface

       interface operator(-)
          module procedure subtractv
       end interface

       interface init
          module procedure initv
       end interface

       interface dump
          module procedure dumpv
       end interface

contains

       ! Allocate v1 with m elements, all set to zero.
       subroutine initv(v1, m)
              integer, intent(in)        :: m
              type(vector), intent(out)  :: v1

              v1%m = m
              allocate(v1%vData(m))
              ! Zero-fill so the early-return error paths of the operators
              ! below hand back defined values, not uninitialised memory.
              v1%vData = 0.0
       end subroutine initv

       ! Dot product of v1 and v2. Reports an error and returns 0 when the
       ! lengths differ.
       function multiplyv(v1, v2)
              type(vector), intent(in) :: v1, v2
              real :: multiplyv

              multiplyv = 0.0

              if (v1%m /= v2%m) then
                 write(*,*) "Error in vector_type:multiply"
                 return
              end if

              multiplyv = dot_product(v1%vData(1:v1%m), v2%vData(1:v2%m))
       end function multiplyv

       ! v1 scaled by c, returned as a newly allocated vector.
       function multiplyvc(v1, c)
              type(vector), intent(in)   :: v1
              real, intent(in)           :: c
              type(vector)               :: multiplyvc

              call init(multiplyvc, v1%m)
              multiplyvc%vData = v1%vData(1:v1%m) * c
       end function multiplyvc

       ! c * v1: same result as v1 * c, so it simply delegates.
       function multiplycv(c, v1)
              type(vector), intent(in)   :: v1
              real, intent(in)           :: c
              type(vector)               :: multiplycv

              multiplycv = multiplyvc(v1, c)
       end function multiplycv

       ! v1 divided element-wise by c. c == 0 is not checked; the outcome is
       ! whatever the processor does for real division by zero (under IEEE
       ! arithmetic presumably Inf/NaN, unless traps are enabled).
       function dividev(v1, c)
              type(vector), intent(in)   :: v1
              real, intent(in)           :: c
              type(vector)               :: dividev

              call init(dividev, v1%m)
              dividev%vData = v1%vData(1:v1%m) / c
       end function dividev

       ! Element-wise v1 + v2. On a length mismatch an error is reported
       ! and a zero vector of length v1%m is returned.
       function addv(v1, v2)
              type(vector), intent(in)   :: v1, v2
              type(vector)               :: addv

              call init(addv, v1%m)

              if (v1%m /= v2%m) then
                 write(*,*) "Error in vector_type:add"
                 return
              end if

              addv%vData = v1%vData(1:v1%m) + v2%vData(1:v2%m)
       end function addv

       ! Element-wise v1 - v2. On a length mismatch an error is reported
       ! and a zero vector of length v1%m is returned.
       function subtractv(v1, v2)
              type(vector), intent(in)   :: v1, v2
              type(vector)               :: subtractv

              call init(subtractv, v1%m)

              if (v1%m /= v2%m) then
                 write(*,*) "Error in vector_type:subtract"
                 return
              end if

              subtractv%vData = v1%vData(1:v1%m) - v2%vData(1:v2%m)
       end function subtractv

       ! Print the elements of v1 on one line, then end that line.
       ! E9.2E2 leaves room for a sign and a two-digit exponent; a narrower
       ! field such as E6.2E1 prints asterisks for every negative value.
       subroutine dumpv(v1)
              type(vector), intent(in)   :: v1
              integer                    :: i

              do i = 1, v1%m
                 write(*, "(e9.2e2,' ')", advance='no') v1%vData(i)
              end do
              write(*, '(a)') ''
       end subroutine dumpv
end module vector_type

module matrix_type
       use vector_type
       implicit none

       ! Real matrix with m rows and n columns stored in mData(1:m, 1:n).
       ! mData is a pointer component, so intrinsic assignment (a = b) copies
       ! the pointer rather than the elements. Matrices returned by the
       ! functions below own freshly allocated storage that this module never
       ! frees; callers that want it back must deallocate(x%mData).
       type matrix
          integer :: m = 0, n = 0
          real, pointer, dimension(:,:) :: mData => null()
       end type matrix

       ! A * B -> matrix product.
       interface operator(*)
          module procedure multiplym
       end interface

       interface operator(+)
          module procedure addm
       end interface

       interface operator(-)
          module procedure subtractm
       end interface

       ! A ** (-1) -> inverse of a square matrix A.
       interface operator(**)
          module procedure inverse
       end interface

       interface init
          module procedure initm
       end interface

       interface row
          module procedure rowm
       end interface

       interface column
          module procedure columnm
       end interface

       interface dump
          module procedure dumpm
       end interface

contains

       ! Allocate m1 as an m x n matrix with every entry set to zero.
       subroutine initm(m1, m, n)
              integer, intent(in)        :: m, n
              type(matrix), intent(out)  :: m1

              m1%m = m
              m1%n = n

              allocate(m1%mData(m, n))
              ! Zero-fill so early-return error paths yield defined values.
              m1%mData = 0.0
       end subroutine initm

       ! Copy row m of m1 into a new vector of length m1%n. An out-of-range
       ! row index is reported and yields a zero vector.
       function rowm(m1, m)
              type(vector) :: rowm
              type(matrix), intent(in) :: m1
              integer, intent(in) :: m

              call init(rowm, m1%n)

              if (m < 1 .or. m > m1%m) then
                 write(*,*) "Error in matrix_type:row"
                 rowm%vData = 0.0
                 return
              end if

              rowm%vData = m1%mData(m, 1:m1%n)
       end function rowm

       ! Copy column n of m1 into a new vector of length m1%m. An
       ! out-of-range column index is reported and yields a zero vector.
       function columnm(m1, n)
              type(vector) :: columnm
              type(matrix), intent(in) :: m1
              integer, intent(in) :: n

              call init(columnm, m1%m)

              if (n < 1 .or. n > m1%n) then
                 write(*,*) "Error in matrix_type:column"
                 columnm%vData = 0.0
                 return
              end if

              columnm%vData = m1%mData(1:m1%m, n)
       end function columnm

       ! Overwrite row m of m1 with v1. v1 must have length m1%n and m must
       ! be a valid row index; otherwise an error is reported and m1 is left
       ! unchanged.
       subroutine setrow(m1, m, v1)
              type(matrix), intent(inout)   :: m1
              integer, intent(in)           :: m
              type(vector), intent(in)      :: v1

              if (v1%m /= m1%n .or. m < 1 .or. m > m1%m) then
                 write(*,*) "Error in matrix_type:setrow"
                 return
              end if

              m1%mData(m, 1:m1%n) = v1%vData(1:v1%m)
       end subroutine setrow

       ! Overwrite column n of m1 with v1. v1 must have length m1%m and n
       ! must be a valid column index; otherwise an error is reported and m1
       ! is left unchanged.
       subroutine setcolumn(m1, n, v1)
              type(matrix), intent(inout)   :: m1
              integer, intent(in)           :: n
              type(vector), intent(in)      :: v1

              if (v1%m /= m1%m .or. n < 1 .or. n > m1%n) then
                 write(*,*) "Error in matrix_type:setcolumn"
                 return
              end if

              m1%mData(1:m1%m, n) = v1%vData(1:v1%m)
       end subroutine setcolumn

       ! Matrix product m1 * m2 (m1%m x m2%n). If the inner dimensions
       ! disagree, an error is reported and a zero matrix is returned.
       function multiplym(m1, m2)
              type(matrix), intent(in)   :: m1, m2
              type(matrix)               :: multiplym

              call init(multiplym, m1%m, m2%n)

              if (m1%n /= m2%m) then
                 write(*,*) "Error, matrix_type:multiply"
                 return
              end if

              ! matmul avoids allocating (and leaking) a row and a column
              ! vector for every output element.
              multiplym%mData = matmul(m1%mData(1:m1%m, 1:m1%n), &
                                       m2%mData(1:m2%m, 1:m2%n))
       end function multiplym

       ! Element-wise m1 + m2. On a shape mismatch an error is reported and
       ! a zero matrix shaped like m1 is returned.
       function addm(m1, m2)
              type(matrix), intent(in)   :: m1, m2
              type(matrix)               :: addm

              call init(addm, m1%m, m1%n)

              if (m1%m /= m2%m .or. m1%n /= m2%n) then
                 write(*,*) "Error, matrix_type:add"
                 return
              end if

              addm%mData = m1%mData(1:m1%m, 1:m1%n) + m2%mData(1:m2%m, 1:m2%n)
       end function addm

       ! Element-wise m1 - m2. On a shape mismatch an error is reported and
       ! a zero matrix shaped like m1 is returned.
       function subtractm(m1, m2)
              type(matrix), intent(in)   :: m1, m2
              type(matrix)               :: subtractm

              call init(subtractm, m1%m, m1%n)

              if (m1%m /= m2%m .or. m1%n /= m2%n) then
                 write(*,*) "Error, matrix_type:subtract"
                 return
              end if

              subtractm%mData = m1%mData(1:m1%m, 1:m1%n) - m2%mData(1:m2%m, 1:m2%n)
       end function subtractm

       ! Inverse of a square matrix, reached as m1**(-1); n must be -1.
       ! If m1 is not square, n /= -1, or m1 is singular, an error is
       ! printed and a zero matrix shaped like m1 is returned.
       function inverse(m1, n)
              type(matrix), intent(in)   :: m1
              integer, intent(in)        :: n
              type(matrix)               :: inverse
              type(matrix)               :: temp
              integer                    :: i, k, info

              call init(inverse, m1%m, m1%n)

              if (m1%m /= m1%n .or. n /= -1) then
                 write(*,*) "Error, matrix_type:inverse"
                 return
              end if

              ! Build the augmented matrix [m1 | I]; Gauss-Jordan reduces it
              ! to [I | inv(m1)].
              k = m1%n
              call init(temp, k, 2*k)
              temp%mData(1:k, 1:k) = m1%mData(1:k, 1:k)
              do i = 1, k
                 temp%mData(i, k + i) = 1.0
              end do

              call gauss_jordan(temp, info)
              ! On failure keep the zero result rather than copying a
              ! partially reduced matrix.
              if (info == 0) inverse%mData = temp%mData(1:k, k+1:2*k)

              ! temp is private scratch; release it.
              deallocate(temp%mData)
       end function inverse

       ! Print m1 one row per line, followed by a blank line.
       subroutine dumpm(m1)
              type(matrix), intent(in)   :: m1
              type(vector)               :: cur_row
              integer                    :: i

              do i = 1, m1%m
                 cur_row = row(m1, i)
                 call dump(cur_row)
                 ! row() allocated fresh storage; free it after printing.
                 deallocate(cur_row%vData)
              end do
              write(*, '(a)') ''
       end subroutine dumpm

       ! Transpose of m1 (an m1%n x m1%m matrix).
       function trans(m1)
              type(matrix), intent(in)   :: m1
              type(matrix)               :: trans

              call init(trans, m1%n, m1%m)
              trans%mData = transpose(m1%mData(1:m1%m, 1:m1%n))
       end function trans

       ! Reduce m1 in place to reduced row echelon form using Gauss-Jordan
       ! elimination with partial pivoting. m1 needs at least as many
       ! columns as rows (typically an augmented matrix [A | B]). On success
       ! the leading m1%m x m1%m block becomes the identity and the remaining
       ! columns hold inv(A)*B.
       !
       ! info (optional, out):
       !    0  success
       !    j  (> 0) column j has no nonzero pivot candidate, so A is
       !       singular; m1 is left partially reduced
       !   -1  m1 has fewer columns than rows
       subroutine gauss_jordan(m1, info)
              type(matrix), intent(inout)    :: m1
              integer, intent(out), optional :: info
              real                           :: swap(m1%n)
              real                           :: pivot, factor
              integer                        :: i, j, p, nr, nc

              nr = m1%m
              nc = m1%n
              if (present(info)) info = 0

              if (nc < nr) then
                 write(*,*) "Error, matrix_type:gauss_jordan"
                 if (present(info)) info = -1
                 return
              end if

              do j = 1, nr
                 ! Partial pivoting: move the largest-magnitude candidate in
                 ! column j (rows j..nr) into row j for numerical stability.
                 p = maxloc(abs(m1%mData(j:nr, j)), dim=1) + j - 1

                 ! The largest candidate is exactly zero only when every
                 ! candidate is zero, i.e. the leading block is singular.
                 if (m1%mData(p, j) == 0.0) then
                    write(*,*) "Matrix not invertable"
                    if (present(info)) info = j
                    return
                 end if

                 if (p /= j) then
                    swap = m1%mData(j, 1:nc)
                    m1%mData(j, 1:nc) = m1%mData(p, 1:nc)
                    m1%mData(p, 1:nc) = swap
                 end if

                 ! Normalise the pivot row so that the pivot becomes 1.
                 pivot = m1%mData(j, j)
                 m1%mData(j, 1:nc) = m1%mData(j, 1:nc) / pivot

                 ! Clear column j in every other row (above and below), which
                 ! replaces the separate back-substitution pass.
                 do i = 1, nr
                    if (i == j) cycle
                    factor = m1%mData(i, j)
                    if (factor /= 0.0) then
                       m1%mData(i, 1:nc) = m1%mData(i, 1:nc) - factor * m1%mData(j, 1:nc)
                    end if
                 end do
              end do
       end subroutine gauss_jordan
end module matrix_type
