Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add functionality to return a zero_field for optional fields on the serializer #450

Merged
merged 4 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_run_diffusion_single_step(
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
interpolation_state = construct_interpolation_state(interpolation_savepoint)
metric_state = construct_metric_state(metrics_savepoint)
diagnostic_state = construct_diagnostics(diffusion_savepoint_init, grid_savepoint)
diagnostic_state = construct_diagnostics(diffusion_savepoint_init)
prognostic_state = diffusion_savepoint_init.construct_prognostics()
vertical_params = VerticalModelParams(
vct_a=grid_savepoint.vct_a(),
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_run_diffusion_initial_step(
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
interpolation_state = construct_interpolation_state(interpolation_savepoint)
metric_state = construct_metric_state(metrics_savepoint)
diagnostic_state = construct_diagnostics(diffusion_savepoint_init, grid_savepoint)
diagnostic_state = construct_diagnostics(diffusion_savepoint_init)
prognostic_state = diffusion_savepoint_init.construct_prognostics()
vct_a = grid_savepoint.vct_a()
vertical_params = VerticalModelParams(vct_a=vct_a, rayleigh_damping_height=damping_height)
Expand Down
11 changes: 4 additions & 7 deletions model/atmosphere/diffusion/tests/diffusion_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
DiffusionInterpolationState,
DiffusionMetricState,
)
from icon4py.model.common.dimension import CEDim, CellDim, KDim
from icon4py.model.common.dimension import CEDim
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, dallclose, zero_field
from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, dallclose
from icon4py.model.common.test_utils.serialbox_utils import (
IconDiffusionExitSavepoint,
IconDiffusionInitSavepoint,
IconGridSavepoint,
InterpolationSavepoint,
MetricSavepoint,
)
Expand Down Expand Up @@ -167,11 +166,9 @@ def construct_metric_state(savepoint: MetricSavepoint) -> DiffusionMetricState:

def construct_diagnostics(
savepoint: IconDiffusionInitSavepoint,
grid_savepoint: IconGridSavepoint,
) -> DiffusionDiagnosticState:
grid = grid_savepoint.construct_icon_grid(on_gpu=False)
dwdx = savepoint.dwdx() if savepoint.dwdx() else zero_field(grid, CellDim, KDim)
dwdy = savepoint.dwdy() if savepoint.dwdy() else zero_field(grid, CellDim, KDim)
dwdx = savepoint.dwdx()
dwdy = savepoint.dwdy()
return DiffusionDiagnosticState(
hdef_ic=savepoint.hdef_ic(),
div_ic=savepoint.div_ic(),
Expand Down
48 changes: 28 additions & 20 deletions model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,34 @@
log = logging.getLogger(__name__)


def optionally_registered(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
name = func.__name__
return func(self, *args, **kwargs)
except serialbox.SerialboxError:
log.warning(f"{name}: field not registered in savepoint {self.savepoint.metainfo}")
return None

return wrapper


class IconSavepoint:
def __init__(self, sp: ser.Savepoint, ser: ser.Serializer, size: dict):
self.savepoint = sp
self.serializer = ser
self.sizes = size
self.log = logging.getLogger((__name__))

def optionally_registered(*dims):
def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
name = func.__name__
return func(self, *args, **kwargs)
except serialbox.SerialboxError:
log.warning(
f"{name}: field not registered in savepoint {self.savepoint.metainfo}"
)
if dims:
shp = tuple(self.sizes[d] for d in dims)
return as_field(dims, np.zeros(shp))
else:
return None

return wrapper

return decorator

def log_meta_info(self):
self.log.info(self.savepoint.metainfo)

Expand Down Expand Up @@ -438,7 +446,7 @@ def geofac_grg(self):
(CellDim, C2E2CODim), grg[:num_cells, :, 1]
)

@optionally_registered
@IconSavepoint.optionally_registered()
def zd_intcoef(self):
return self._get_field("vcoef", CellDim, C2E2CDim, KDim)

Expand Down Expand Up @@ -566,7 +574,7 @@ def ddxn_z_full(self):
def ddxt_z_full(self):
return self._get_field("ddxt_z_full", EdgeDim, KDim)

@optionally_registered
@IconSavepoint.optionally_registered(CellDim, KDim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing that same input argument dims has to be in both the optionally_registered and 'get_field' does not look super nice to me. But I can't think of a better way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I first wanted to get those args from the decorated function, but that is not so easy because it is pass to the get_field and not the decorated function.

def mask_hdiff(self):
return self._get_field("mask_hdiff", CellDim, KDim, dtype=bool)

Expand All @@ -587,11 +595,11 @@ def wgtfacq_e_dsl(
ar = np.pad(ar[:, ::-1], ((0, 0), (k, 0)), "constant", constant_values=(0.0,))
return self._get_field_from_ndarray(ar, EdgeDim, KDim)

@optionally_registered
@IconSavepoint.optionally_registered(CellDim, KDim)
def zd_diffcoef(self):
return self._get_field("zd_diffcoef", CellDim, KDim)

@optionally_registered
@IconSavepoint.optionally_registered()
def zd_intcoef(self):
return self._read_and_reorder_sparse_field("vcoef")

Expand All @@ -611,7 +619,7 @@ def _linearize_first_2dims(
assert old_shape[1] == sparse_size
return as_field(target_dims, data.reshape(old_shape[0] * old_shape[1], old_shape[2]))

@optionally_registered
@IconSavepoint.optionally_registered()
def zd_vertoffset(self):
return self._read_and_reorder_sparse_field("zd_vertoffset")

Expand All @@ -629,11 +637,11 @@ def hdef_ic(self):
def div_ic(self):
return self._get_field("div_ic", CellDim, KDim)

@optionally_registered
@IconSavepoint.optionally_registered(CellDim, KDim)
def dwdx(self):
return self._get_field("dwdx", CellDim, KDim)

@optionally_registered
@IconSavepoint.optionally_registered(CellDim, KDim)
def dwdy(self):
return self._get_field("dwdy", CellDim, KDim)

Expand Down
Loading