Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove zero stencils from dycore/state_utils/utils.py #425

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

import icon4py.model.atmosphere.dycore.nh_solve.solve_nonhydro_program as nhsolve_prog
import icon4py.model.common.constants as constants
from icon4py.model.atmosphere.dycore.set_cell_kdim_field_to_zero_wp import (
set_cell_kdim_field_to_zero_wp,
)

from icon4py.model.atmosphere.dycore.accumulate_prep_adv_fields import (
accumulate_prep_adv_fields,
)
Expand Down Expand Up @@ -136,8 +140,6 @@
_allocate_indices,
_calculate_divdamp_fields,
compute_z_raylfac,
set_zero_c_k,
set_zero_e_k,
)
from icon4py.model.atmosphere.dycore.update_dynamical_exner_time_increment import (
update_dynamical_exner_time_increment,
Expand Down Expand Up @@ -244,7 +246,6 @@ def __init__(
rayleigh_type: int = 2,
rayleigh_coeff: float = 0.05,
divdamp_order: int = 24, # the ICON default is 4,
idiv_method: int = 1,
is_iau_active: bool = False,
iau_wgt_dyn: float = 0.0,
divdamp_type: int = 3,
Expand Down Expand Up @@ -328,9 +329,6 @@ def __init__(
#: IAU weight for dynamics fields
self.iau_wgt_dyn: float = iau_wgt_dyn

#: from mo_dynamics_nml.f90
self.idiv_method: int = idiv_method

self._validate()

def _validate(self):
Expand All @@ -345,9 +343,6 @@ def _validate(self):
if self.itime_scheme != 4:
raise NotImplementedError("itime_scheme can only be 4")

if self.idiv_method != 1:
raise NotImplementedError("idiv_method can only be 1")

if self.divdamp_order != 24:
raise NotImplementedError("divdamp_order can only be 24")

Expand Down Expand Up @@ -709,6 +704,9 @@ def run_predictor_step(
start_edge_lb_plus4 = self.grid.get_start_index(
EdgeDim, HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 4
)
start_edge_local_minus2 = self.grid.get_start_index(
EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - 2
)
end_edge_local_minus2 = self.grid.get_end_index(
EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - 2
)
Expand Down Expand Up @@ -898,50 +896,26 @@ def run_predictor_step(
},
)
if self.config.iadv_rhotheta <= 2:
tmp_0_0 = self.grid.get_start_index(EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - 2)
offset = 2 if self.config.idiv_method == 1 else 3
tmp_0_1 = self.grid.get_end_index(
EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - offset
)

set_zero_e_k.with_backend(backend)(
field=z_fields.z_rho_e,
horizontal_start=tmp_0_0,
horizontal_end=tmp_0_1,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

set_zero_e_k.with_backend(backend)(
field=z_fields.z_theta_v_e,
horizontal_start=tmp_0_0,
horizontal_end=tmp_0_1,
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=z_fields.z_rho_e,
edge_kdim_field_to_zero_wp_2=z_fields.z_theta_v_e,
horizontal_start=start_edge_local_minus2,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

# initialize also nest boundary points with zero
if self.grid.limited_area:
set_zero_e_k.with_backend(backend)(
field=z_fields.z_rho_e,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_local_minus1,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

set_zero_e_k.with_backend(backend)(
field=z_fields.z_theta_v_e,
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=z_fields.z_rho_e,
edge_kdim_field_to_zero_wp_2=z_fields.z_theta_v_e,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_local_minus1,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if self.config.iadv_rhotheta == 2:
# Compute upwind-biased values for rho and theta starting from centered differences
# Note: the length of the backward trajectory should be 0.5*dtime*(vn,vt) in order to arrive
Expand Down Expand Up @@ -1134,20 +1108,19 @@ def run_predictor_step(
},
)

if self.config.idiv_method == 1:
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

nhsolve_prog.predictor_stencils_35_36.with_backend(backend)(
vn=prognostic_state[nnew].vn,
Expand Down Expand Up @@ -1203,22 +1176,21 @@ def run_predictor_step(
},
)

if self.config.idiv_method == 1:
compute_divergence_of_fluxes_of_rho_and_theta.with_backend(backend)(
geofac_div=self.interpolation_state.geofac_div,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
z_flxdiv_mass=self.z_flxdiv_mass,
z_flxdiv_theta=self.z_flxdiv_theta,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={
"C2E": self.grid.get_offset_provider("C2E"),
"C2CE": self.grid.get_offset_provider("C2CE"),
},
)
compute_divergence_of_fluxes_of_rho_and_theta.with_backend(backend)(
geofac_div=self.interpolation_state.geofac_div,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
z_flxdiv_mass=self.z_flxdiv_mass,
z_flxdiv_theta=self.z_flxdiv_theta,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={
"C2E": self.grid.get_offset_provider("C2E"),
"C2CE": self.grid.get_offset_provider("C2CE"),
},
)

nhsolve_prog.stencils_43_44_45_45b.with_backend(backend)(
z_w_expl=z_fields.z_w_expl,
Expand Down Expand Up @@ -1686,50 +1658,48 @@ def run_corrector_step(
},
)

if self.config.idiv_method == 1:
log.debug("corrector: start stencil 32")
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
log.debug("corrector: start stencil 32")
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2, # TODO: (halungge) this is actually the second halo line
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if lprep_adv: # Preparations for tracer advection
log.debug("corrector: doing prep advection")
if lclean_mflx:
log.debug("corrector: start stencil 33")
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=prep_adv.vn_traj,
edge_kdim_field_to_zero_wp_2=prep_adv.mass_flx_me,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_end,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)
log.debug(f"corrector: start stencil 34")
accumulate_prep_adv_fields.with_backend(backend)(
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
vn_traj=prep_adv.vn_traj,
mass_flx_me=prep_adv.mass_flx_me,
r_nsubsteps=r_nsubsteps,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if lprep_adv: # Preparations for tracer advection
log.debug("corrector: doing prep advection")
if lclean_mflx:
log.debug("corrector: start stencil 33")
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=prep_adv.vn_traj,
edge_kdim_field_to_zero_wp_2=prep_adv.mass_flx_me,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_end,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)
log.debug(f"corrector: start stencil 34")
accumulate_prep_adv_fields.with_backend(backend)(
z_vn_avg=self.z_vn_avg,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
vn_traj=prep_adv.vn_traj,
mass_flx_me=prep_adv.mass_flx_me,
r_nsubsteps=r_nsubsteps,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if self.config.idiv_method == 1:
# verified for e-9
log.debug(f"corrector: start stencile 41")
compute_divergence_of_fluxes_of_rho_and_theta.with_backend(backend)(
Expand Down Expand Up @@ -1961,7 +1931,7 @@ def run_corrector_step(
r_nsubsteps=r_nsubsteps,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_start=1,
vertical_end=self.grid.num_levels,
offset_provider={},
)
Expand All @@ -1982,8 +1952,8 @@ def run_corrector_step(
if lprep_adv:
if lclean_mflx:
log.debug(f"corrector set prep_adv.mass_flx_ic to zero")
set_zero_c_k.with_backend(backend)(
field=prep_adv.mass_flx_ic,
set_cell_kdim_field_to_zero_wp.with_backend(backend)(
field_to_zero_wp=prep_adv.mass_flx_ic,
horizontal_start=start_cell_lb,
horizontal_end=end_cell_nudging,
vertical_start=0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,18 @@
from icon4py.model.atmosphere.dycore.set_cell_kdim_field_to_zero_vp import (
_set_cell_kdim_field_to_zero_vp,
)
from icon4py.model.atmosphere.dycore.set_cell_kdim_field_to_zero_wp import (
_set_cell_kdim_field_to_zero_wp,
)
from icon4py.model.atmosphere.dycore.set_lower_boundary_condition_for_w_and_contravariant_correction import (
_set_lower_boundary_condition_for_w_and_contravariant_correction,
)
from icon4py.model.atmosphere.dycore.set_theta_v_prime_ic_at_lower_boundary import (
_set_theta_v_prime_ic_at_lower_boundary,
)
from icon4py.model.atmosphere.dycore.state_utils.utils import _set_zero_c_k, _set_zero_e_k
from icon4py.model.atmosphere.dycore.state_utils.utils import (
_broadcast_zero_to_three_edge_kdim_fields_wp,
)
from icon4py.model.atmosphere.dycore.update_densety_exener_wind import _update_densety_exener_wind
from icon4py.model.atmosphere.dycore.update_wind import _update_wind
from icon4py.model.common.dimension import CEDim, CellDim, ECDim, EdgeDim, KDim
Expand All @@ -92,19 +97,12 @@ def init_test_fields(
indices_cells_2: int32,
nlev: int32,
):
_set_zero_e_k(
out=z_rho_e,
domain={EdgeDim: (indices_edges_1, indices_edges_2), KDim: (0, nlev)},
)
_set_zero_e_k(
out=z_theta_v_e,
_broadcast_zero_to_three_edge_kdim_fields_wp(
out=(z_rho_e, z_theta_v_e, z_graddiv_vn),
domain={EdgeDim: (indices_edges_1, indices_edges_2), KDim: (0, nlev)},
)
_set_zero_e_k(
out=z_graddiv_vn,
domain={EdgeDim: (indices_edges_1, indices_edges_2), KDim: (0, nlev)},
)
_set_zero_c_k(

_set_cell_kdim_field_to_zero_wp(
out=z_dwdz_dd,
domain={CellDim: (indices_cells_1, indices_cells_2), KDim: (0, nlev)},
)
Expand All @@ -125,7 +123,7 @@ def _predictor_stencils_2_3(
_extrapolate_temporally_exner_pressure(exner_exfac, exner, exner_ref_mc, exner_pr),
(z_exner_ex_pr, exner_pr),
)
z_exner_ex_pr = where(k_field == nlev, _set_zero_c_k(), z_exner_ex_pr)
z_exner_ex_pr = where(k_field == nlev, _set_cell_kdim_field_to_zero_wp(), z_exner_ex_pr)

return z_exner_ex_pr, exner_pr

Expand Down
Loading
Loading