#!/bin/bash
#
# Generates hfx_contract_block.F in the current directory.
#
# For the block dimensions listed in "sizes" the generated dispatcher
# (contract_block) selects a specialized routine in which some dimensions are
# literal constants.  Dimensions are fixed from left to right (ma_max, mb_max,
# mc_max, md_max): ma_max and mb_max are always fixed, mc_max and md_max only
# as long as the product of all fixed dimensions stays below "limit".
# Dimensions that are not listed in "sizes" are routed to block_default.
#
# Scratch data lives in a private temporary directory that is removed on exit,
# so nothing besides hfx_contract_block.F is created or deleted in the
# current directory.
#

set -euo pipefail

sizes=(1 2 3 4 5 6 7 9 10 11 15)
limit=19

# Names of the four block dimensions, in the order in which they get fixed.
readonly dims=(ma_max mb_max mc_max md_max)

# Trailing arguments shared by every generated routine and every CALL.
readonly call_tail="kbd,kbc,kad,kac,pbd,pbc,pad,pac,prim,scale"
readonly default_call="CALL block_default(ma_max,mb_max,mc_max,md_max,${call_tail})"

# Declarations and loop nest shared by all generated routines.  The quoted
# delimiter keeps the shell from expanding anything inside the Fortran text.
# read returns non-zero at end of input (there is no NUL delimiter), hence
# the "|| true".
IFS= read -r -d '' block_body <<'EOF' || true
      REAL(KIND=dp) :: kbd(mb_max*md_max),kbc(mb_max*mc_max),kad(ma_max*md_max),kac(ma_max*mc_max), &
                       pbd(mb_max*md_max),pbc(mb_max*mc_max),pad(ma_max*md_max),pac(ma_max*mc_max), &
                       prim(ma_max*mb_max*mc_max*md_max)
      REAL(KIND=dp) :: scale,tmp,ks_bd,ks_bc,p_bd,p_bc
      INTEGER :: p_index,ma,mb,mc,md
  
      kbd(1:mb_max*md_max) = 0.0_dp
      kbc(1:mb_max*mc_max) = 0.0_dp
      kad(1:ma_max*md_max) = 0.0_dp
      kac(1:ma_max*mc_max) = 0.0_dp
      p_index = 0
      DO md = 1,md_max
        DO mc = 1,mc_max
          DO mb = 1,mb_max
            ks_bd = 0.0_dp
            ks_bc = 0.0_dp
            p_bd = pbd((md-1)*mb_max+mb)
            p_bc = pbc((mc-1)*mb_max+mb)
            DO ma = 1,ma_max
              p_index=p_index+1
              tmp = scale*prim(p_index)
              ks_bc = ks_bc + tmp* pad((md-1)*ma_max+ma)
              ks_bd = ks_bd + tmp* pac((mc-1)*ma_max+ma)
              kad((md-1)*ma_max+ma) =  kad((md-1)*ma_max+ma)-tmp*p_bc
              kac((mc-1)*ma_max+ma) = kac((mc-1)*ma_max+ma)-tmp*p_bd
            END DO
            kbd((md-1)*mb_max+mb) = kbd((md-1)*mb_max+mb) - ks_bd
            kbc((mc-1)*mb_max+mb) = kbc((mc-1)*mb_max+mb) - ks_bc
          END DO
        END DO
      END DO
EOF
block_body=${block_body%$'\n'}
readonly block_body

#######################################
# Writes one contraction routine to stdout.
# Arguments: $1..$4 - literal values for ma_max, mb_max, mc_max, md_max.
#            Trailing arguments may be omitted; the corresponding dimensions
#            then stay dummy arguments of the routine.  Without any argument
#            the generic routine block_default is written.
# Outputs:   Fortran source of the routine.
#######################################
emit_block() {
  local -a values=("${1:-}" "${2:-}" "${3:-}" "${4:-}")
  local name=block dummies='' body=$block_body
  local i

  for i in 0 1 2 3; do
    if [[ -n ${values[i]} ]]; then
      # Fixed dimension: part of the routine name, literal in the body.
      name+="_${values[i]}"
      body=${body//${dims[i]}/${values[i]}}
    else
      # Runtime dimension: stays a dummy argument.
      dummies+="${dims[i]},"
    fi
  done
  if [[ -z ${values[0]} ]]; then
    name=block_default
  fi

  printf '  SUBROUTINE %s(%s%s)\n' "$name" "$dummies" "$call_tail"
  if [[ -n $dummies ]]; then
    printf '      INTEGER :: %s\n' "${dummies%,}"
  fi
  printf '%s\n' "$body"
  printf '  END SUBROUTINE %s\n' "$name"
}

#######################################
# Appends one line to the dispatcher code (file descriptor 3).
# Arguments: $1 - indentation depth in units of two blanks
#            $2 - text of the line
#######################################
caller_line() {
  printf '%*s%s\n' "$((2 * $1))" '' "$2" >&3
}

#######################################
# Closes a SELECT CASE opened at the given depth with a CASE DEFAULT branch
# that calls block_default.
# Arguments: $1 - indentation depth of the SELECT CASE being closed
#######################################
caller_default() {
  caller_line "$1" "CASE DEFAULT "
  caller_line "$(($1 + 1))" "$default_call"
  caller_line "$1" "END SELECT"
}

#######################################
# Generates the nested SELECT CASE dispatcher (file descriptor 3) and the
# routines it calls (file descriptor 4).  Every routine is written exactly
# once, the first time a CALL to it is generated.
# Globals:   sizes, limit, call_tail (read)
#######################################
generate() {
  local ma mb mc md have2 have3

  emit_block >&4

  caller_line 2 "SELECT CASE(ma_max) "
  for ma in "${sizes[@]}"; do
    caller_line 2 "CASE($ma) "
    caller_line 3 "SELECT CASE(mb_max) "
    for mb in "${sizes[@]}"; do
      caller_line 3 "CASE($mb) "
      if ((ma * mb < limit)); then
        have2=0
        caller_line 4 "SELECT CASE(mc_max) "
        for mc in "${sizes[@]}"; do
          caller_line 4 "CASE($mc) "
          if ((ma * mb * mc < limit)); then
            have3=0
            caller_line 5 "SELECT CASE(md_max) "
            for md in "${sizes[@]}"; do
              caller_line 5 "CASE($md) "
              if ((ma * mb * mc * md < limit)); then
                caller_line 6 "CALL block_${ma}_${mb}_${mc}_${md}(${call_tail})"
                emit_block "$ma" "$mb" "$mc" "$md" >&4
              else
                caller_line 6 "CALL block_${ma}_${mb}_${mc}(md_max,${call_tail})"
                if ((have3 == 0)); then
                  have3=1
                  emit_block "$ma" "$mb" "$mc" >&4
                fi
              fi
            done
            caller_default 5
          else
            caller_line 5 "CALL block_${ma}_${mb}(mc_max,md_max,${call_tail})"
            if ((have2 == 0)); then
              have2=1
              emit_block "$ma" "$mb" >&4
            fi
          fi
        done
        caller_default 4
      else
        caller_line 4 "CALL block_${ma}_${mb}(mc_max,md_max,${call_tail})"
        emit_block "$ma" "$mb" >&4
      fi
    done
    caller_default 3
  done
  caller_default 2
}

#
# actual code generation
#

workdir=$(mktemp -d)
cleanup() { rm -rf -- "$workdir"; }
trap cleanup EXIT

generate 3>"$workdir/caller" 4>"$workdir/functions"

#
# assemble final file
#

{
  cat <<'EOF'
!-----------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations         !
!   Copyright (C) 2000 - 2013  CP2K developers group                          !
!-----------------------------------------------------------------------------!
! *****************************************************************************
!> \brief routines to contract density matrix blocks with the for center
!>        integrals to yield the Kohn-Sham matrix. The specialized routines
!>        are about 1.2-2.0 as fast as the default one.
!> \par History
!>      10.2009 created [Joost VandeVondele]
!> \author Joost VandeVondele
! *****************************************************************************
MODULE hfx_contract_block
  USE kinds,                           ONLY: dp

  IMPLICIT NONE
  PRIVATE
  CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'hfx_contract_block'
  PUBLIC :: contract_block
CONTAINS
  SUBROUTINE contract_block(ma_max,mb_max,mc_max,md_max,kbd,kbc,kad,kac,pbd,pbc,pad,pac,prim,scale)
    INTEGER                                  :: ma_max, mb_max, mc_max, md_max
    REAL(KIND=dp) :: kbd(mb_max*md_max), kbc(mb_max*mc_max), &
      kad(ma_max*md_max), kac(ma_max*mc_max), pbd(mb_max*md_max), &
      pbc(mb_max*mc_max), pad(ma_max*md_max), pac(ma_max*mc_max), &
      prim(ma_max*mb_max*mc_max*md_max), scale

#if defined (__LIBINT)
EOF
  cat "$workdir/caller"
  cat <<'EOF'
#else
    RETURN
#endif
  END SUBROUTINE contract_block

#if defined (__LIBINT)
EOF
  cat "$workdir/functions"
  cat <<'EOF'
#endif
END MODULE hfx_contract_block
EOF
} > hfx_contract_block.F
