Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Corrector stencil 60 #350

Merged
merged 11 commits into from
Jan 19, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataclasses import dataclass
from typing import Final, Optional

from gt4py.next import Field, as_field
from gt4py.next import as_field
from gt4py.next.common import Field
from gt4py.next.ffront.fbuiltins import int32
from gt4py.next.program_processors.runners.gtfn import run_gtfn
Expand Down Expand Up @@ -126,6 +126,9 @@
from icon4py.model.atmosphere.dycore.mo_solve_nonhydro_stencil_59 import (
mo_solve_nonhydro_stencil_59,
)
from icon4py.model.atmosphere.dycore.mo_solve_nonhydro_stencil_60 import (
mo_solve_nonhydro_stencil_60,
)
from icon4py.model.atmosphere.dycore.mo_solve_nonhydro_stencil_65 import (
mo_solve_nonhydro_stencil_65,
)
Expand Down Expand Up @@ -440,7 +443,7 @@ def init(
else:
self.jk_start = 0

en_smag_fac_for_zero_nshift.with_backend(run_gtfn)(
en_smag_fac_for_zero_nshift.with_backend(backend)(
self.vertical_params.vct_a,
self.config.divdamp_fac,
self.config.divdamp_fac2,
Expand Down Expand Up @@ -474,8 +477,6 @@ def _allocate_local_fields(self):
self.z_grad_rth_3 = _allocate(CellDim, KDim, grid=self.grid)
self.z_grad_rth_4 = _allocate(CellDim, KDim, grid=self.grid)
self.z_dexner_dz_c_2 = _allocate(CellDim, KDim, grid=self.grid)
# TODO (magdalena) missing stencil_60 in corrector remove! this is a field from the diagnostics!
self.exner_dyn_incr = _allocate(CellDim, KDim, grid=self.grid)
self.z_hydro_corr = _allocate(EdgeDim, KDim, grid=self.grid)
self.z_vn_avg = _allocate(EdgeDim, KDim, grid=self.grid)
self.z_theta_v_fl_e = _allocate(EdgeDim, KDim, grid=self.grid)
Expand Down Expand Up @@ -509,16 +510,17 @@ def time_step(
prep_adv: PrepAdvection,
divdamp_fac_o2: float,
dtime: float,
idyn_timestep: float,
l_recompute: bool,
l_init: bool,
nnow: int,
nnew: int,
lclean_mflx: bool,
lprep_adv: bool,
at_first_substep: bool,
at_last_substep: bool,
):
log.info(
f"running timestep: dtime = {dtime}, dyn_timestep = {idyn_timestep}, init = {l_init}, recompute = {l_recompute}, prep_adv = {lprep_adv} "
f"running timestep: dtime = {dtime}, init = {l_init}, recompute = {l_recompute}, prep_adv = {lprep_adv} clean_mflx={lclean_mflx} "
)
start_cell_lb = self.grid.get_start_index(
CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim)
Expand Down Expand Up @@ -550,15 +552,13 @@ def time_step(
prognostic_state=prognostic_state_ls,
z_fields=self.intermediate_fields,
dtime=dtime,
idyn_timestep=idyn_timestep,
l_recompute=l_recompute,
l_init=l_init,
at_first_substep=at_first_substep,
nnow=nnow,
nnew=nnew,
)
log.info(
f"running corrector step: dtime = {dtime}, dyn_timestep = {idyn_timestep}, prep_adv = {lprep_adv}, divdamp_fac_od = {divdamp_fac_o2} "
)

self.run_corrector_step(
diagnostic_state_nh=diagnostic_state_nh,
prognostic_state=prognostic_state_ls,
Expand All @@ -570,10 +570,9 @@ def time_step(
nnow=nnow,
lclean_mflx=lclean_mflx,
lprep_adv=lprep_adv,
at_last_substep=at_last_substep,
)
log.info(
f"running corrector step: dtime = {dtime}, dyn_timestep = {idyn_timestep}, prep_adv = {lprep_adv}, divdamp_fac_od = {divdamp_fac_o2} "
)

start_cell_lb = self.grid.get_start_index(
CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim)
)
Expand Down Expand Up @@ -633,14 +632,14 @@ def run_predictor_step(
prognostic_state: list[PrognosticState],
z_fields: IntermediateFields,
dtime: float,
idyn_timestep: float,
l_recompute: bool,
l_init: bool,
at_first_substep: bool,
nnow: int,
nnew: int,
):
log.info(
f"running predictor step: dtime = {dtime}, dyn_timestep = {idyn_timestep}, init = {l_init}, recompute = {l_recompute} "
f"running predictor step: dtime = {dtime}, init = {l_init}, recompute = {l_recompute} "
)
if l_init or l_recompute:
if self.config.itime_scheme == 4 and not l_init:
Expand Down Expand Up @@ -871,7 +870,7 @@ def run_predictor_step(
horizontal_start=start_vertex_lb_plus1,
horizontal_end=end_vertex_local_minus1,
vertical_start=0,
vertical_end=self.grid.num_levels, # UBOUND(p_cell_in,2)
vertical_end=self.grid.num_levels,
offset_provider={
"V2C": self.grid.get_offset_provider("V2C"),
},
Expand Down Expand Up @@ -1375,10 +1374,10 @@ def run_predictor_step(
offset_provider={"Koff": KDim},
)

if idyn_timestep == 1:
if at_first_substep:
mo_solve_nonhydro_stencil_59.with_backend(backend)(
exner=prognostic_state[nnow].exner,
exner_dyn_incr=self.exner_dyn_incr,
exner_dyn_incr=diagnostic_state_nh.exner_dyn_incr,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=self.vertical_params.kstart_moist,
Expand Down Expand Up @@ -1437,7 +1436,14 @@ def run_corrector_step(
nnow: int,
lclean_mflx: bool,
lprep_adv: bool,
at_last_substep: bool,
):
log.info(
f"running corrector step: dtime = {dtime}, prep_adv = {lprep_adv}, divdamp_fac_o2 = {divdamp_fac_o2} clean_mfxl= {lclean_mflx} "
)

# TODO (magdalena) is it correct to to use a config parameter here? the actual number of substeps can vary dynmically...
# should this config parameter exist at all in SolveNonHydro?
# Inverse value of ndyn_substeps for tracer advection precomputations
r_nsubsteps = 1.0 / self.config.ndyn_substeps_var

Expand All @@ -1446,7 +1452,7 @@ def run_corrector_step(
# Coefficient for reduced fourth-order divergence d
scal_divdamp_o2 = divdamp_fac_o2 * self.cell_params.mean_cell_area

_calculate_divdamp_fields.with_backend(run_gtfn)(
_calculate_divdamp_fields.with_backend(backend)(
self.enh_divdamp_fac,
int32(self.config.divdamp_order),
self.cell_params.mean_cell_area,
Expand Down Expand Up @@ -1799,7 +1805,7 @@ def run_corrector_step(
offset_provider={},
)
if not self.l_vert_nested:
mo_solve_nonhydro_stencil_46.with_backend(run_gtfn)(
mo_solve_nonhydro_stencil_46.with_backend(backend)(
w_nnew=prognostic_state[nnew].w,
z_contr_w_fl_l=z_fields.z_contr_w_fl_l,
horizontal_start=start_cell_nudging,
Expand Down Expand Up @@ -1946,7 +1952,19 @@ def run_corrector_step(
vertical_end=self.grid.num_levels,
offset_provider={},
)
# TODO (magdalena) stencil_60 is missing here?
if at_last_substep:
mo_solve_nonhydro_stencil_60.with_backend(backend)(
exner=prognostic_state[nnew].exner,
ddt_exner_phy=diagnostic_state_nh.ddt_exner_phy,
exner_dyn_incr=diagnostic_state_nh.exner_dyn_incr,
ndyn_substeps_var=float(self.config.ndyn_substeps_var),
dtime=dtime,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=self.vertical_params.kstart_moist,
vertical_end=int32(self.grid.num_levels),
offset_provider={},
)

if lprep_adv:
if lclean_mflx:
Expand Down
41 changes: 30 additions & 11 deletions model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def test_nonhydro_predictor_step(
sp_v = savepoint_velocity_init
dtime = sp_v.get_metadata("dtime").get("dtime")
recompute = sp_v.get_metadata("recompute").get("recompute")
dyn_timestep = sp.get_metadata("dyn_timestep").get("dyn_timestep")
linit = sp_v.get_metadata("linit").get("linit")

nnow = 0
Expand Down Expand Up @@ -161,9 +160,9 @@ def test_nonhydro_predictor_step(
prognostic_state=prognostic_state_ls,
z_fields=solve_nonhydro.intermediate_fields,
dtime=dtime,
idyn_timestep=dyn_timestep,
l_recompute=recompute,
l_init=linit,
at_first_substep=(jstep_init == 0),
nnow=nnow,
nnew=nnew,
)
Expand Down Expand Up @@ -490,6 +489,7 @@ def create_vertical_params(damping_height, grid_savepoint):
def test_nonhydro_corrector_step(
istep_init,
istep_exit,
jstep_init,
step_date_init,
step_date_exit,
icon_grid,
Expand Down Expand Up @@ -579,6 +579,7 @@ def test_nonhydro_corrector_step(
nnow=nnow,
lclean_mflx=clean_mflx,
lprep_adv=lprep_adv,
at_last_substep=jstep_init == (ndyn_substeps - 1),
)
if icon_grid.limited_area:
assert dallclose(solve_nonhydro._bdy_divdamp.asnumpy(), sp.bdy_divdamp().asnumpy())
Expand Down Expand Up @@ -652,6 +653,12 @@ def test_nonhydro_corrector_step(
savepoint_nonhydro_exit.vn_traj().asnumpy(),
rtol=5e-7, # TODO (magdalena) was rtol=1e-10 for local experiment only
)
# stencil 60 only relevant for last substep
assert dallclose(
diagnostic_state_nh.exner_dyn_incr.asnumpy(),
savepoint_nonhydro_exit.exner_dyn_incr().asnumpy(),
atol=1e-14,
)


@pytest.mark.datatest
Expand Down Expand Up @@ -702,7 +709,6 @@ def test_run_solve_nonhydro_single_step(
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
dyn_timestep = sp_v.get_metadata("dyn_timestep").get("dyn_timestep")

diagnostic_state_nh = construct_diagnostics(sp, sp_v)

Expand Down Expand Up @@ -734,13 +740,14 @@ def test_run_solve_nonhydro_single_step(
prep_adv=prep_adv,
divdamp_fac_o2=initial_divdamp_fac,
dtime=dtime,
idyn_timestep=dyn_timestep,
l_recompute=recompute,
l_init=linit,
nnew=nnew,
nnow=nnow,
lclean_mflx=clean_mflx,
lprep_adv=lprep_adv,
at_first_substep=jstep_init == 0,
at_last_substep=jstep_init == (ndyn_substeps - 1),
)
prognostic_state_nnew = prognostic_state_ls[1]
assert dallclose(
Expand All @@ -767,6 +774,12 @@ def test_run_solve_nonhydro_single_step(
atol=8e-14,
)

assert dallclose(
diagnostic_state_nh.exner_dyn_incr.asnumpy(),
savepoint_nonhydro_exit.exner_dyn_incr().asnumpy(),
atol=1e-14,
)


@pytest.mark.datatest
@pytest.mark.parametrize("experiment", [REGIONAL_EXPERIMENT])
Expand All @@ -790,9 +803,9 @@ def test_run_solve_nonhydro_multi_step(
savepoint_nonhydro_exit,
savepoint_nonhydro_step_exit,
experiment,
ndyn_substeps,
):
nsubsteps = grid_savepoint.get_metadata("nsteps").get("nsteps")
config = construct_config(experiment, nsubsteps)
config = construct_config(experiment, ndyn_substeps=ndyn_substeps)
sp = savepoint_nonhydro_init
sp_step_exit = savepoint_nonhydro_step_exit
nonhydro_params = NonHydrostaticParams(config)
Expand All @@ -809,8 +822,6 @@ def test_run_solve_nonhydro_multi_step(
nnew = 1
recompute = sp_v.get_metadata("recompute").get("recompute")
linit = sp_v.get_metadata("linit").get("linit")
dyn_timestep = sp_v.get_metadata("dyn_timestep").get("dyn_timestep")

diagnostic_state_nh = construct_diagnostics(sp, sp_v)

prognostic_state_ls = create_prognostic_states(sp)
Expand All @@ -834,25 +845,28 @@ def test_run_solve_nonhydro_multi_step(
owner_mask=grid_savepoint.c_owner_mask(),
)

for i_substep in range(nsubsteps):
for i_substep in range(ndyn_substeps):
is_first_substep = i_substep == 0
is_last_substep = i_substep == (ndyn_substeps - 1)
solve_nonhydro.time_step(
diagnostic_state_nh=diagnostic_state_nh,
prognostic_state_ls=prognostic_state_ls,
prep_adv=prep_adv,
divdamp_fac_o2=sp.divdamp_fac_o2(),
dtime=dtime,
idyn_timestep=dyn_timestep,
l_recompute=recompute,
l_init=linit,
nnew=nnew,
nnow=nnow,
lclean_mflx=clean_mflx,
lprep_adv=lprep_adv,
at_first_substep=is_first_substep,
at_last_substep=is_last_substep,
)
linit = False
recompute = False
clean_mflx = False
if i_substep != nsubsteps - 1:
if not is_last_substep:
ntemp = nnow
nnow = nnew
nnew = ntemp
Expand Down Expand Up @@ -924,6 +938,11 @@ def test_run_solve_nonhydro_multi_step(
savepoint_nonhydro_exit.vn_new().asnumpy(),
atol=5e-13,
)
assert dallclose(
diagnostic_state_nh.exner_dyn_incr.asnumpy(),
savepoint_nonhydro_exit.exner_dyn_incr().asnumpy(),
atol=1e-14,
)


@pytest.mark.datatest
Expand Down
2 changes: 2 additions & 0 deletions model/common/src/icon4py/model/common/grid/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
C2E2CDim,
C2E2CODim,
C2EDim,
C2VDim,
CECDim,
CEDim,
CellDim,
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(self):
"E2C2V": (self._get_offset_provider, E2C2VDim, EdgeDim, VertexDim),
"V2E": (self._get_offset_provider, V2EDim, VertexDim, EdgeDim),
"V2C": (self._get_offset_provider, V2CDim, VertexDim, CellDim),
"C2V": (self._get_offset_provider, C2VDim, CellDim, VertexDim),
"E2ECV": (self._get_offset_provider_for_sparse_fields, E2C2VDim, EdgeDim, ECVDim),
"C2CEC": (self._get_offset_provider_for_sparse_fields, C2E2CDim, CellDim, CECDim),
"C2CE": (self._get_offset_provider_for_sparse_fields, C2EDim, CellDim, CEDim),
Expand Down
16 changes: 13 additions & 3 deletions model/driver/src/icon4py/model/driver/dycore_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def _validate_config(self):
def _not_first_step(self):
self._do_initial_stabilization = False

def _is_last_substep(self, step_nr: int):
return step_nr == (self.n_substeps_var - 1)

@staticmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this staticmethod and not the other one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the other accesses the class content, so you need the `self

def _is_first_substep(step_nr: int):
return step_nr == 0

def _next_simulation_date(self):
self._simulation_date += timedelta(seconds=self.run_config.dtime)

Expand Down Expand Up @@ -241,27 +248,30 @@ def _do_dyn_substepping(
do_clean_mflx = True
for dyn_substep in range(self._n_substeps_var):
log.info(
f"simulation date : {self._simulation_date} sub timestep : {dyn_substep}, initial_stabilization : {self._do_initial_stabilization}, nnow: {self._now}, nnew : {self._next}"
f"simulation date : {self._simulation_date} substep / n_substeps : {dyn_substep} / "
f"{self.n_substeps_var} , initial_stabilization : {self._do_initial_stabilization}, "
f"nnow: {self._now}, nnew : {self._next}"
)
self.solve_nonhydro.time_step(
solve_nonhydro_diagnostic_state,
prognostic_state_list,
prep_adv=prep_adv,
divdamp_fac_o2=inital_divdamp_fac_o2,
dtime=self._substep_timestep,
idyn_timestep=dyn_substep,
l_recompute=do_recompute,
l_init=self._do_initial_stabilization,
nnew=self._next,
nnow=self._now,
lclean_mflx=do_clean_mflx,
lprep_adv=do_prep_adv,
at_first_substep=self._is_first_substep(dyn_substep),
at_last_substep=self._is_last_substep(dyn_substep),
)

do_recompute = False
do_clean_mflx = False

if dyn_substep != self._n_substeps_var - 1:
if not self._is_last_substep(dyn_substep):
self._swap()

self._not_first_step()
Expand Down