Skip to content

Commit

Permalink
backend edits
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Oct 9, 2024
1 parent 8d14f80 commit 936c7d8
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import gt4py.next as gtx

import icon4py.model.common.states.prognostic_state as prognostics
from gt4py.next.backend import Backend
from icon4py.model.atmosphere.diffusion import diffusion_utils, diffusion_states
from icon4py.model.atmosphere.diffusion.diffusion_utils import (
copy_field,
Expand Down Expand Up @@ -348,7 +349,8 @@ class Diffusion:
"""Class that configures diffusion and does one diffusion step."""

def __init__(
self, exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange()
self, exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
backend: Backend = None
):
self._exchange = exchange
self._initialized = False
Expand All @@ -372,6 +374,16 @@ def __init__(
self.edge_params: Optional[geometry.EdgeParams] = None
self.cell_params: Optional[geometry.CellParams] = None
self._horizontal_start_index_w_diffusion: gtx.int32 = gtx.int32(0)
self.mo_intp_rbf_rbf_vec_interpol_vertex = mo_intp_rbf_rbf_vec_interpol_vertex.with_backend(backend)
self.calculate_nabla2_and_smag_coefficients_for_vn = calculate_nabla2_and_smag_coefficients_for_vn.with_backend(backend)
self.calculate_diagnostic_quantities_for_turbulence = calculate_diagnostic_quantities_for_turbulence.with_backend(backend)
self.apply_diffusion_to_vn = apply_diffusion_to_vn.with_backend(backend)
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence = apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence.with_backend(backend)
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools = calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_backend(backend)
self.calculate_nabla2_for_theta = calculate_nabla2_for_theta.with_backend(backend)
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points = truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_backend(backend)
self.update_theta_and_exner = update_theta_and_exner.with_backend(backend)


def init(
self,
Expand Down Expand Up @@ -611,7 +623,7 @@ def _do_diffusion_step(
scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})

log.debug("rbf interpolation 1: start")
mo_intp_rbf_rbf_vec_interpol_vertex(
self.mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in=prognostic_state.vn,
ptr_coeff_1=self.interpolation_state.rbf_coeff_1,
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
Expand All @@ -631,7 +643,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of vn - end")

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
calculate_nabla2_and_smag_coefficients_for_vn(
self.calculate_nabla2_and_smag_coefficients_for_vn(
diff_multfac_smag=self.diff_multfac_smag,
tangent_orientation=self.edge_params.tangent_orientation,
inv_primal_edge_length=self.edge_params.inverse_primal_edge_lengths,
Expand Down Expand Up @@ -662,7 +674,7 @@ def _do_diffusion_step(
log.debug(
"running stencils 02 03 (calculate_diagnostic_quantities_for_turbulence): start"
)
calculate_diagnostic_quantities_for_turbulence(
self.calculate_diagnostic_quantities_for_turbulence(
kh_smag_ec=self.kh_smag_ec,
vn=prognostic_state.vn,
e_bln_c_s=self.interpolation_state.e_bln_c_s,
Expand All @@ -689,7 +701,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("2nd rbf interpolation: start")
mo_intp_rbf_rbf_vec_interpol_vertex(
self.mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in=self.z_nabla2_e,
ptr_coeff_1=self.interpolation_state.rbf_coeff_1,
ptr_coeff_2=self.interpolation_state.rbf_coeff_2,
Expand All @@ -709,7 +721,7 @@ def _do_diffusion_step(
log.debug("communication rbf extrapolation of z_nable2_e - end")

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
apply_diffusion_to_vn(
self.apply_diffusion_to_vn(
u_vert=self.u_vert,
v_vert=self.v_vert,
primal_normal_vert_v1=self.edge_params.primal_normal_vert[0],
Expand Down Expand Up @@ -743,7 +755,7 @@ def _do_diffusion_step(
# TODO (magdalena) get rid of this copying. So far passing an empty buffer instead did not verify?
copy_field(prognostic_state.w, self.w_tmp, offset_provider={})

apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
self.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
area=self.cell_params.area,
geofac_n2s=self.interpolation_state.geofac_n2s,
geofac_grg_x=self.interpolation_state.geofac_grg_x,
Expand Down Expand Up @@ -776,7 +788,7 @@ def _do_diffusion_step(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)

calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
theta_v=prognostic_state.theta_v,
theta_ref_mc=self.metric_state.theta_ref_mc,
thresh_tdiff=self.thresh_tdiff,
Expand All @@ -793,7 +805,7 @@ def _do_diffusion_step(
)

log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
calculate_nabla2_for_theta(
self.calculate_nabla2_for_theta(
kh_smag_e=self.kh_smag_e,
inv_dual_edge_length=self.edge_params.inverse_dual_edge_lengths,
theta_v=prognostic_state.theta_v,
Expand All @@ -810,7 +822,7 @@ def _do_diffusion_step(
"running stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): start"
)
if self.config.apply_zdiffusion_t:
truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
mask=self.metric_state.mask_hdiff,
zd_vertoffset=self.metric_state.zd_vertoffset,
zd_diffcoef=self.metric_state.zd_diffcoef,
Expand All @@ -830,7 +842,7 @@ def _do_diffusion_step(
"running fused stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): end"
)
log.debug("running stencil 16 (update_theta_and_exner): start")
update_theta_and_exner(
self.update_theta_and_exner(
z_temp=self.z_temp,
area=self.cell_params.area,
theta_v=prognostic_state.theta_v,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _apply_diffusion_to_vn(
return vn


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def apply_diffusion_to_vn(
u_vert: fa.VertexKField[vpfloat],
v_vert: fa.VertexKField[vpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
return w, dwdx, dwdy


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence(
area: fa.CellField[wpfloat],
geofac_n2s: Field[[CellDim, C2E2CODim], wpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _calculate_diagnostic_quantities_for_turbulence(
return div_ic_vp, hdef_ic_vp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_diagnostic_quantities_for_turbulence(
kh_smag_ec: fa.EdgeKField[vpfloat],
vn: fa.EdgeKField[wpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
return kh_smag_e_vp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools(
theta_v: fa.CellKField[wpfloat],
theta_ref_mc: fa.CellKField[vpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _calculate_nabla2_and_smag_coefficients_for_vn(
return kh_smag_e_vp, astype(kh_smag_ec_wp, vpfloat), z_nabla2_e_wp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_nabla2_and_smag_coefficients_for_vn(
diff_multfac_smag: Field[[dims.KDim], vpfloat],
tangent_orientation: fa.EdgeField[wpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _calculate_nabla2_for_theta(
return z_temp


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def calculate_nabla2_for_theta(
kh_smag_e: fa.EdgeKField[float],
inv_dual_edge_length: fa.EdgeField[float],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
return astype(z_temp_wp, vpfloat)


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def truly_horizontal_diffusion_nabla_of_theta_over_steep_points(
mask: fa.CellKField[bool],
zd_vertoffset: Field[[dims.CECDim, dims.KDim], int32],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _update_theta_and_exner(
return theta_v, exner


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def update_theta_and_exner(
z_temp: fa.CellKField[vpfloat],
area: fa.CellField[wpfloat],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def test_diffusion_init(
stretch_factor,
damping_height,
ndyn_substeps,
backend
):
config = construct_diffusion_config(experiment, ndyn_substeps=ndyn_substeps)
additional_parameters = diffusion.DiffusionParams(config)
Expand Down Expand Up @@ -149,7 +150,7 @@ def test_diffusion_init(
edge_params = grid_savepoint.construct_edge_geometry()
cell_params = grid_savepoint.construct_cell_geometry()

diffusion_granule = diffusion.Diffusion()
diffusion_granule = diffusion.Diffusion(backend=backend)
diffusion_granule.init(
grid=icon_grid,
config=config,
Expand Down Expand Up @@ -252,6 +253,7 @@ def test_verify_diffusion_init_against_savepoint(
stretch_factor,
damping_height,
ndyn_substeps,
backend
):
config = construct_diffusion_config(experiment, ndyn_substeps=ndyn_substeps)
additional_parameters = diffusion.DiffusionParams(config)
Expand Down Expand Up @@ -289,7 +291,7 @@ def test_verify_diffusion_init_against_savepoint(
edge_params = grid_savepoint.construct_edge_geometry()
cell_params = grid_savepoint.construct_cell_geometry()

diffusion_granule = diffusion.Diffusion()
diffusion_granule = diffusion.Diffusion(backend=backend)
diffusion_granule.init(
icon_grid,
config,
Expand Down Expand Up @@ -326,6 +328,7 @@ def test_run_diffusion_single_step(
stretch_factor,
damping_height,
ndyn_substeps,
backend
):
dtime = savepoint_diffusion_init.get_metadata("dtime").get("dtime")
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
Expand Down Expand Up @@ -375,7 +378,7 @@ def test_run_diffusion_single_step(
config = construct_diffusion_config(experiment, ndyn_substeps)
additional_parameters = diffusion.DiffusionParams(config)

diffusion_granule = diffusion.Diffusion()
diffusion_granule = diffusion.Diffusion(backend=backend)
diffusion_granule.init(
grid=icon_grid,
config=config,
Expand Down Expand Up @@ -413,6 +416,7 @@ def test_run_diffusion_initial_step(
metrics_savepoint,
grid_savepoint,
icon_grid,
backend
):
dtime = savepoint_diffusion_init.get_metadata("dtime").get("dtime")
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
Expand Down Expand Up @@ -458,7 +462,7 @@ def test_run_diffusion_initial_step(
config = construct_diffusion_config(experiment, ndyn_substeps=2)
additional_parameters = diffusion.DiffusionParams(config)

diffusion_granule = diffusion.Diffusion()
diffusion_granule = diffusion.Diffusion(backend=backend)
diffusion_granule.init(
grid=icon_grid,
config=config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _mo_intp_rbf_rbf_vec_interpol_vertex(
return p_u_out, p_v_out


@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
@program(grid_type=GridType.UNSTRUCTURED)
def mo_intp_rbf_rbf_vec_interpol_vertex(
p_e_in: fa.EdgeKField[wpfloat],
ptr_coeff_1: Field[[dims.VertexDim, V2EDim], wpfloat],
Expand Down

0 comments on commit 936c7d8

Please sign in to comment.