Skip to content

Commit

Permalink
TSMP-PDAF: import changes from eCLM
Browse files Browse the repository at this point in the history
- include preprocessor variable `USE_PDAF`
- include OASIS code using preprocessor variable `USE_OASIS`
  • Loading branch information
jjokella committed Jun 19, 2024
1 parent e51fd8b commit a05d2e7
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 6 deletions.
61 changes: 57 additions & 4 deletions src/drivers/mct/main/cime_comp_mod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ module cime_comp_mod
use mct_mod ! mct_ wrappers for mct lib
use perf_mod
use ESMF
#if defined(USE_OASIS)
use mod_oasis
#endif

!----------------------------------------------------------------------------
! component model interfaces (init, run, final methods)
Expand Down Expand Up @@ -177,7 +180,9 @@ module cime_comp_mod

implicit none

!private
#ifndef USE_PDAF
private
#endif

public cime_pre_init1, cime_pre_init2, cime_init, cime_run, cime_final
public timing_dir, mpicom_GLOID
Expand Down Expand Up @@ -580,39 +585,72 @@ module cime_comp_mod
!*******************************************************************************
!===============================================================================

#ifdef USE_PDAF
subroutine cime_pre_init1(esmf_log_option, pdaf_comm, pdaf_id, pdaf_max)
#else
subroutine cime_pre_init1(esmf_log_option)
#endif
use shr_pio_mod, only : shr_pio_init1, shr_pio_init2
use seq_comm_mct, only: num_inst_driver
!----------------------------------------------------------
!| Initialize MCT and MPI communicators and IO
!----------------------------------------------------------

character(CS), intent(out) :: esmf_log_option ! For esmf_logfile_kind
#ifdef USE_PDAF
integer, optional, intent(in) :: pdaf_comm
integer, optional, intent(in) :: pdaf_id
integer, optional, intent(in) :: pdaf_max
#endif

integer, dimension(num_inst_total) :: comp_id, comp_comm, comp_comm_iam
logical :: comp_iamin(num_inst_total)
character(len=seq_comm_namelen) :: comp_name(num_inst_total)
integer :: it
integer :: driver_id
integer :: driver_comm
#if defined(USE_OASIS)
integer :: oas_comp_id
#endif

#ifndef USE_PDAF
call mpi_init(ierr)
call shr_mpi_chkerr(ierr,subname//' mpi_init')
#endif

#if defined(USE_OASIS)
call oasis_init_comp (oas_comp_id, 'eCLM', ierr)
if (ierr /= 0) then
call oasis_abort(oas_comp_id, 'cime_pre_init1', 'oasis_init_comp error')
end if
call oasis_get_localcomm(global_comm, ierr)
if (ierr /= 0) then
call oasis_abort(oas_comp_id, 'cime_pre_init1', 'oasis_get_localcomm error')
end if
#else
#ifdef USE_PDAF
if (present(pdaf_comm)) then
global_comm = pdaf_comm
else
! call mpi_init(ierr)
! call shr_mpi_chkerr(ierr,subname//' mpi_init')
call mpi_comm_dup(MPI_COMM_WORLD, global_comm, ierr)
call shr_mpi_chkerr(ierr,subname//' mpi_comm_dup')
end if
#else
call mpi_comm_dup(MPI_COMM_WORLD, global_comm, ierr)
call shr_mpi_chkerr(ierr,subname//' mpi_comm_dup')
#endif
#endif

comp_comm = MPI_COMM_NULL
time_brun = mpi_wtime()

!--- Initialize multiple driver instances, if requested ---
#ifdef USE_PDAF
call cime_cpl_init(global_comm, driver_comm, num_inst_driver, driver_id, &
pdaf_id, pdaf_max)
#else
call cime_cpl_init(global_comm, driver_comm, num_inst_driver, driver_id)
#endif

call shr_pio_init1(num_inst_total,NLFileName, driver_comm)
!
Expand All @@ -624,9 +662,11 @@ subroutine cime_pre_init1(esmf_log_option, pdaf_comm, pdaf_id, pdaf_max)
if (num_inst_driver > 1) then
call seq_comm_init(global_comm, driver_comm, NLFileName, drv_comm_ID=driver_id)
write(cpl_inst_tag,'("_",i4.4)') driver_id
#ifdef USE_PDAF
else if (present(pdaf_id) .and. present(pdaf_max)) then
call seq_comm_init(global_comm, driver_comm, NLFileName, &
pdaf_id=pdaf_id, pdaf_max=pdaf_max)
#endif
else
call seq_comm_init(global_comm, driver_comm, NLFileName)
cpl_inst_tag = ''
Expand Down Expand Up @@ -4145,6 +4185,10 @@ subroutine cime_final()
mpicom=mpicom_GLOID)
endif

#if defined(USE_OASIS)
call oasis_terminate (ierr)
call shr_mpi_chkerr(ierr,subname//' oasis_terminate')
#endif
call t_finalizef()

end subroutine cime_final
Expand Down Expand Up @@ -4215,8 +4259,12 @@ subroutine cime_comp_barriers(mpicom, timer)
endif
end subroutine cime_comp_barriers

#ifdef USE_PDAF
subroutine cime_cpl_init(comm_in, comm_out, num_inst_driver, id, &
pdaf_id, pdaf_max)
#else
subroutine cime_cpl_init(comm_in, comm_out, num_inst_driver, id)
#endif
!-----------------------------------------------------------------------
!
! Initialize multiple coupler instances, if requested
Expand All @@ -4229,8 +4277,9 @@ subroutine cime_cpl_init(comm_in, comm_out, num_inst_driver, id, &
integer , intent(out) :: comm_out
integer , intent(out) :: num_inst_driver
integer , intent(out) :: id ! instance ID, starts from 1
#ifdef USE_PDAF
integer , intent(in), optional :: pdaf_id, pdaf_max

#endif
!
! Local variables
!
Expand Down Expand Up @@ -4272,10 +4321,14 @@ subroutine cime_cpl_init(comm_in, comm_out, num_inst_driver, id, &
' : Total PE number must be a multiple of coupler instance number')
end if

#ifdef USE_PDAF
if (pdaf_max > 1) then
call mpi_comm_split(comm_in, pdaf_id, 0, comm_out, ierr)
call shr_mpi_chkerr(ierr,subname//' mpi_comm_split')
else if (num_inst_driver == 1) then
#else
if (num_inst_driver == 1) then
#endif
call mpi_comm_dup(comm_in, comm_out, ierr)
call shr_mpi_chkerr(ierr,subname//' mpi_comm_dup')
else
Expand Down
41 changes: 39 additions & 2 deletions src/drivers/mct/shr/seq_comm_mct.F90
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,12 @@ integer function seq_comm_get_ncomps()
seq_comm_get_ncomps = ncomps
end function seq_comm_get_ncomps

subroutine seq_comm_init(global_comm_in,driver_comm_in,nmlfile,drv_comm_id,&
#ifdef USE_PDAF
subroutine seq_comm_init(global_comm_in, driver_comm_in, nmlfile, drv_comm_id,&
pdaf_id, pdaf_max)
#else
subroutine seq_comm_init(global_comm_in, driver_comm_in, nmlfile, drv_comm_id)
#endif
!----------------------------------------------------------
!
! Arguments
Expand All @@ -208,8 +212,10 @@ subroutine seq_comm_init(global_comm_in,driver_comm_in,nmlfile,drv_comm_id,&
integer, intent(in) :: driver_comm_in
character(len=*), intent(IN) :: nmlfile
integer, intent(in), optional :: drv_comm_id
#ifdef USE_PDAF
integer, intent(in), optional :: pdaf_id
integer, intent(in), optional :: pdaf_max
#endif
!
! Local variables
!
Expand Down Expand Up @@ -407,6 +413,7 @@ subroutine seq_comm_init(global_comm_in,driver_comm_in,nmlfile,drv_comm_id,&
call mpi_bcast(pelist, size(pelist), MPI_INTEGER, 0, DRIVER_COMM, ierr)
call seq_comm_setcomm(CPLID,pelist,nthreads=cpl_nthreads,iname='CPL')

#ifdef USE_PDAF
call comp_comm_init(driver_comm, atm_rootpe, atm_nthreads, atm_layout, atm_ntasks, atm_pestride, num_inst_atm, &
CPLID, ATMID, CPLATMID, ALLATMID, CPLALLATMID, 'ATM', count, drv_comm_id, pdaf_id, pdaf_max)
call comp_comm_init(driver_comm, lnd_rootpe, lnd_nthreads, lnd_layout, lnd_ntasks, lnd_pestride, num_inst_lnd, &
Expand All @@ -423,6 +430,24 @@ subroutine seq_comm_init(global_comm_in,driver_comm_in,nmlfile,drv_comm_id,&
CPLID, WAVID, CPLWAVID, ALLWAVID, CPLALLWAVID, 'WAV', count, drv_comm_id, pdaf_id, pdaf_max)
call comp_comm_init(driver_comm, esp_rootpe, esp_nthreads, esp_layout, esp_ntasks, esp_pestride, num_inst_esp, &
CPLID, ESPID, CPLESPID, ALLESPID, CPLALLESPID, 'ESP', count, drv_comm_id, pdaf_id, pdaf_max)
#else
call comp_comm_init(driver_comm, atm_rootpe, atm_nthreads, atm_layout, atm_ntasks, atm_pestride, num_inst_atm, &
CPLID, ATMID, CPLATMID, ALLATMID, CPLALLATMID, 'ATM', count, drv_comm_id)
call comp_comm_init(driver_comm, lnd_rootpe, lnd_nthreads, lnd_layout, lnd_ntasks, lnd_pestride, num_inst_lnd, &
CPLID, LNDID, CPLLNDID, ALLLNDID, CPLALLLNDID, 'LND', count, drv_comm_id)
call comp_comm_init(driver_comm, ice_rootpe, ice_nthreads, ice_layout, ice_ntasks, ice_pestride, num_inst_ice, &
CPLID, ICEID, CPLICEID, ALLICEID, CPLALLICEID, 'ICE', count, drv_comm_id)
call comp_comm_init(driver_comm, ocn_rootpe, ocn_nthreads, ocn_layout, ocn_ntasks, ocn_pestride, num_inst_ocn, &
CPLID, OCNID, CPLOCNID, ALLOCNID, CPLALLOCNID, 'OCN', count, drv_comm_id)
call comp_comm_init(driver_comm, rof_rootpe, rof_nthreads, rof_layout, rof_ntasks, rof_pestride, num_inst_rof, &
CPLID, ROFID, CPLROFID, ALLROFID, CPLALLROFID, 'ROF', count, drv_comm_id)
call comp_comm_init(driver_comm, glc_rootpe, glc_nthreads, glc_layout, glc_ntasks, glc_pestride, num_inst_glc, &
CPLID, GLCID, CPLGLCID, ALLGLCID, CPLALLGLCID, 'GLC', count, drv_comm_id)
call comp_comm_init(driver_comm, wav_rootpe, wav_nthreads, wav_layout, wav_ntasks, wav_pestride, num_inst_wav, &
CPLID, WAVID, CPLWAVID, ALLWAVID, CPLALLWAVID, 'WAV', count, drv_comm_id)
call comp_comm_init(driver_comm, esp_rootpe, esp_nthreads, esp_layout, esp_ntasks, esp_pestride, num_inst_esp, &
CPLID, ESPID, CPLESPID, ALLESPID, CPLALLESPID, 'ESP', count, drv_comm_id)
#endif

if (count /= ncomps) then
write(logunit,*) trim(subname),' ERROR in ID count ',count,ncomps
Expand Down Expand Up @@ -496,10 +521,16 @@ subroutine seq_comm_init(global_comm_in,driver_comm_in,nmlfile,drv_comm_id,&

end subroutine seq_comm_init

#ifdef USE_PDAF
subroutine comp_comm_init(driver_comm, comp_rootpe, comp_nthreads, comp_layout, &
comp_ntasks, comp_pestride, num_inst_comp, &
CPLID, COMPID, CPLCOMPID, ALLCOMPID, CPLALLCOMPID, name, count, drv_comm_id, &
pdaf_id, pdaf_max)
#else
subroutine comp_comm_init(driver_comm, comp_rootpe, comp_nthreads, comp_layout, &
comp_ntasks, comp_pestride, num_inst_comp, &
CPLID, COMPID, CPLCOMPID, ALLCOMPID, CPLALLCOMPID, name, count, drv_comm_id)
#endif
integer, intent(in) :: driver_comm
integer, intent(in) :: comp_rootpe
integer, intent(in) :: comp_nthreads
Expand All @@ -514,8 +545,10 @@ subroutine comp_comm_init(driver_comm, comp_rootpe, comp_nthreads, comp_layout,
integer, intent(out) :: CPLALLCOMPID
integer, intent(inout) :: count
integer, intent(in), optional :: drv_comm_id
#ifdef USE_PDAF
integer, intent(in), optional :: pdaf_id
integer, intent(in), optional :: pdaf_max
#endif
character(len=*), intent(in) :: name

character(len=*), parameter :: subname = "comp_comm_init"
Expand Down Expand Up @@ -578,11 +611,15 @@ subroutine comp_comm_init(driver_comm, comp_rootpe, comp_nthreads, comp_layout,
pelist(2,1) = cmax(n)
pelist(3,1) = cstr(n)
endif

call mpi_bcast(pelist, size(pelist), MPI_INTEGER, 0, DRIVER_COMM, ierr)

#ifdef USE_PDAF
if (present(pdaf_id) .and. present(pdaf_max)) then
call seq_comm_setcomm(COMPID(n),pelist,nthreads=comp_nthreads,iname=name,inst=pdaf_id,tinst=pdaf_max)
else if (present(drv_comm_id)) then
#else
if (present(drv_comm_id)) then
#endif
call seq_comm_setcomm(COMPID(n),pelist,nthreads=comp_nthreads,iname=name,inst=drv_comm_id)
else
call seq_comm_setcomm(COMPID(n),pelist,nthreads=comp_nthreads,iname=name,inst=n,tinst=num_inst_comp)
Expand Down

0 comments on commit a05d2e7

Please sign in to comment.