Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype fields factory #533

Draft
wants to merge 50 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
51bd673
WIP
halungge Jun 27, 2024
2270522
add backend to metric_fields stencils
halungge Jun 27, 2024
0bf8d18
ugly version that works for gtfn programs
halungge Jun 27, 2024
b78b24f
use operator.add instead of lambda
halungge Jul 1, 2024
faa931a
merge main
halungge Aug 8, 2024
5836a32
reduce dependencies, move ProgramFieldProvider out of Factory
halungge Aug 8, 2024
6f3e6c6
rename fields
halungge Aug 9, 2024
6d99a17
Merge branch 'main' into prototype_fields_factory
halungge Aug 14, 2024
21c744b
move factory.py to states package
halungge Aug 15, 2024
bf7dc7e
remove duplicated computation of wgtfacq_c_dsl
halungge Aug 15, 2024
d07fef2
fix type annotations for arrays
halungge Aug 15, 2024
8bb63f6
add type annotations to compute_vwind_impl_wgt.py
halungge Aug 15, 2024
a9b0b54
FieldProvider for numpy functions (WIP I)
halungge Aug 16, 2024
dc809e8
merge main
halungge Aug 16, 2024
ffb4661
first version for numpy functions
halungge Aug 16, 2024
9f042b1
fix: move _unallocated to ProgramFieldProvider
halungge Aug 20, 2024
809f060
move joint functionality into FieldProvider
halungge Aug 20, 2024
bcd65b5
- switch to device dependent import in compute_wgtfacq.py
halungge Aug 20, 2024
52a837d
add type annotation to connectivity
halungge Aug 21, 2024
72e742b
handle numpy field with connectivity
halungge Aug 21, 2024
d93c570
merge main
halungge Aug 27, 2024
fba0891
add type to get_processor_properties argument
halungge Aug 28, 2024
c2c250a
add c_lin_e metadata
halungge Aug 28, 2024
04645e0
start_index, end_index abstraction for vertical (WIP)
halungge Aug 28, 2024
306b761
basic sample of factory.
halungge Aug 28, 2024
cec01f9
fix with_allocator function
halungge Aug 29, 2024
57bf95d
update with upstrean
nfarabullini Sep 3, 2024
966abbd
update with upstrean
nfarabullini Sep 3, 2024
aa2c402
ran pre-commit and made fixes
nfarabullini Sep 3, 2024
afe3f47
small edit
nfarabullini Sep 3, 2024
df8ba00
Merge branch 'main' into prototype_fields_factory
halungge Sep 5, 2024
8f8d8de
using domains for the compute domain in factory
halungge Sep 5, 2024
e1ec531
add docstring to Providers
halungge Sep 10, 2024
62c21ae
separate vertical and horizontal connectivities
halungge Sep 13, 2024
f98f8dc
pre-commit
halungge Sep 13, 2024
110bec6
Merge branch 'main' into prototype_fields_factory
halungge Sep 17, 2024
8cdcea6
merge main
halungge Sep 19, 2024
1e4a20f
Merge branch 'main' into prototype_fields_factory
halungge Sep 24, 2024
e697cc8
Merge branch 'main' into prototype_fields_factory
halungge Sep 27, 2024
8c7b782
Merge branch 'main' into prototype_fields_factory
halungge Oct 1, 2024
e417fd1
add types for metadata attributes
halungge Oct 2, 2024
75bda6d
fix int32 issues (ad hoc fix)
halungge Oct 2, 2024
f978d72
rename providers, fixes in FieldProvider Protocol
halungge Oct 2, 2024
e635e3d
add FieldSource Protocol
halungge Oct 3, 2024
fef2ced
fix doc strings, pre-commit
halungge Oct 3, 2024
1d7abf5
Merge branch 'main' into prototype_fields_factory
halungge Oct 4, 2024
02cce48
add documentation
halungge Oct 4, 2024
70063a7
move FieldSource protocol
halungge Oct 10, 2024
19ceae8
add return type to FieldSource.get(...)
halungge Oct 11, 2024
ae937d4
Split factory argument in FieldProvider to several protocols
halungge Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def get_runtype(with_mpi: bool = False) -> RunType:


@functools.singledispatch
def get_processor_properties(runtime) -> ProcessProperties:
def get_processor_properties(runtime: RunType) -> ProcessProperties:
raise TypeError(f"Cannot define ProcessProperties for ({type(runtime)})")


Expand Down
7 changes: 6 additions & 1 deletion model/common/src/icon4py/model/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ class InvalidConfigError(Exception):
pass


class IncompleteSetupError(Exception):
def __init__(self, msg):
super().__init__(f"{msg}")


class IncompleteStateError(Exception):
def __init__(self, field_name):
super().__init__(f"Field '{field_name}' is missing in state.")
super().__init__(f"Field '{field_name}' is missing.")


class IconGridError(RuntimeError):
Expand Down
14 changes: 8 additions & 6 deletions model/common/src/icon4py/model/common/grid/icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def n_shift(self):
def lvert_nest(self):
return True if self.config.lvertnest else False

def start_index(self, domain: h_grid.Domain):
def start_index(self, domain: h_grid.Domain) -> gtx.int32:
"""
Use to specify lower end of domains of a field for field_operators.

Expand All @@ -177,10 +177,11 @@ def start_index(self, domain: h_grid.Domain):
"""
if domain.local:
# special treatment because this value is not set properly in the underlying data.
return 0
return self._start_indices[domain.dim][domain()].item()
return gtx.int32(0)
# ndarray.item() does not respect the dtype of the array, returns a copy of the value _as the default python type_
return gtx.int32(self._start_indices[domain.dim][domain()])

def end_index(self, domain: h_grid.Domain):
def end_index(self, domain: h_grid.Domain) -> gtx.int32:
"""
Use to specify upper end of domains of a field for field_operators.

Expand All @@ -189,5 +190,6 @@ def end_index(self, domain: h_grid.Domain):
"""
if domain.zone == h_grid.Zone.INTERIOR and not self.limited_area:
# special treatment because this value is not set properly in the underlying data, for a global grid
return self.size[domain.dim]
return self._end_indices[domain.dim][domain()].item()
return gtx.int32(self.size[domain.dim])
# ndarray.item() does not respect the dtype of the array, returns a copy of the value _as the default python builtin type_
return gtx.int32(self._end_indices[domain.dim][domain()].item())
31 changes: 19 additions & 12 deletions model/common/src/icon4py/model/common/grid/vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import gt4py.next as gtx

import icon4py.model.common.states.metadata as data
from icon4py.model.common import dimension as dims, exceptions, field_type_aliases as fa
from icon4py.model.common.settings import xp

Expand Down Expand Up @@ -58,6 +59,14 @@ def _validate(self):
), f"{self.marker} needs to be combined with positive offest, but offset = {self.offset}"


def domain(dim: gtx.Dimension):
def _domain(marker: Zone):
assert dim.kind == gtx.DimensionKind.VERTICAL, "Only vertical dimensions are supported"
return Domain(dim, marker)

return _domain


@dataclasses.dataclass(frozen=True)
class VerticalGridConfig:
"""
Expand Down Expand Up @@ -159,19 +168,17 @@ def __str__(self) -> str:
return "\n".join(vertical_params_properties)

@property
def metadata_interface_physical_height(self) -> dict:
return dict(
standard_name="model_interface_height",
long_name="height value of half levels without topography",
units="m",
positive="up",
icon_var_name="vct_a",
)
def metadata_interface_physical_height(self):
return data.attrs["model_interface_height"]

@property
def num_levels(self):
return self.config.num_levels

def index(self, domain: Domain) -> gtx.int32:
match domain.marker:
case Zone.TOP:
index = gtx.int32(0)
index = 0
case Zone.BOTTOM:
index = self._bottom_level(domain)
case Zone.MOIST:
Expand All @@ -187,10 +194,10 @@ def index(self, domain: Domain) -> gtx.int32:
assert (
0 <= index <= self._bottom_level(domain)
), f"vertical index {index} outside of grid levels for {domain.dim}"
return index
return gtx.int32(index)

def _bottom_level(self, domain: Domain) -> gtx.int32:
return gtx.int32(self.size(domain.dim))
def _bottom_level(self, domain: Domain) -> int:
return self.size(domain.dim)

@property
def interface_physical_height(self) -> fa.KField[float]:
Expand Down
30 changes: 15 additions & 15 deletions model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
from icon4py.model.common.settings import xp


def _compute_z1_z2_z3(
z_ifc: np.ndarray, i1: int, i2: int, i3: int, i4: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
z_ifc: xp.ndarray, i1: int, i2: int, i3: int, i4: int
) -> tuple[xp.ndarray, xp.ndarray, xp.ndarray]:
z1 = 0.5 * (z_ifc[:, i2] - z_ifc[:, i1])
z2 = 0.5 * (z_ifc[:, i2] + z_ifc[:, i3]) - z_ifc[:, i1]
z3 = 0.5 * (z_ifc[:, i3] + z_ifc[:, i4]) - z_ifc[:, i1]
return z1, z2, z3


def compute_wgtfacq_c_dsl(
z_ifc: np.ndarray,
z_ifc: xp.ndarray,
nlev: int,
) -> np.ndarray:
) -> xp.ndarray:
"""
Compute weighting factor for quadratic interpolation to surface.

Expand All @@ -31,8 +31,8 @@ def compute_wgtfacq_c_dsl(
Returns:
Field[CellDim, KDim] (full levels)
"""
wgtfacq_c = np.zeros((z_ifc.shape[0], nlev + 1))
wgtfacq_c_dsl = np.zeros((z_ifc.shape[0], nlev))
wgtfacq_c = xp.zeros((z_ifc.shape[0], nlev + 1))
wgtfacq_c_dsl = xp.zeros((z_ifc.shape[0], nlev))
z1, z2, z3 = _compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3)

wgtfacq_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3)
Expand All @@ -47,10 +47,10 @@ def compute_wgtfacq_c_dsl(


def compute_wgtfacq_e_dsl(
e2c,
z_ifc: np.ndarray,
c_lin_e: np.ndarray,
wgtfacq_c_dsl: np.ndarray,
e2c: xp.ndarray,
z_ifc: xp.ndarray,
c_lin_e: xp.ndarray,
wgtfacq_c_dsl: xp.ndarray,
n_edges: int,
nlev: int,
):
Expand All @@ -67,8 +67,8 @@ def compute_wgtfacq_e_dsl(
Returns:
Field[EdgeDim, KDim] (full levels)
"""
wgtfacq_e_dsl = np.zeros(shape=(n_edges, nlev + 1))
z_aux_c = np.zeros((z_ifc.shape[0], 6))
wgtfacq_e_dsl = xp.zeros(shape=(n_edges, nlev + 1))
z_aux_c = xp.zeros((z_ifc.shape[0], 6))
z1, z2, z3 = _compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3)
z_aux_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3)
z_aux_c[:, 1] = (z1 - wgtfacq_c_dsl[:, nlev - 3] * (z1 - z3)) / (z1 - z2)
Expand All @@ -79,8 +79,8 @@ def compute_wgtfacq_e_dsl(
z_aux_c[:, 4] = (z1 - z_aux_c[:, 5] * (z1 - z3)) / (z1 - z2)
z_aux_c[:, 3] = 1.0 - (z_aux_c[:, 4] + z_aux_c[:, 5])

c_lin_e = c_lin_e[:, :, np.newaxis]
z_aux_e = np.sum(c_lin_e * z_aux_c[e2c], axis=1)
c_lin_e = c_lin_e[:, :, xp.newaxis]
z_aux_e = xp.sum(c_lin_e * z_aux_c[e2c], axis=1)

wgtfacq_e_dsl[:, nlev] = z_aux_e[:, 0]
wgtfacq_e_dsl[:, nlev - 1] = z_aux_e[:, 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
where,
)

from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common import dimension as dims, field_type_aliases as fa, settings
from icon4py.model.common.dimension import (
C2E,
C2E2C,
Expand Down Expand Up @@ -61,7 +61,7 @@ class MetricsConfig:
exner_expol: Final[wpfloat] = 0.3333333333333


@program(grid_type=GridType.UNSTRUCTURED)
@program(grid_type=GridType.UNSTRUCTURED, backend=settings.backend)
def compute_z_mc(
z_ifc: fa.CellKField[wpfloat],
z_mc: fa.CellKField[wpfloat],
Expand Down Expand Up @@ -109,7 +109,7 @@ def _compute_ddqz_z_half(
return ddqz_z_half


@program(grid_type=GridType.UNSTRUCTURED, backend=None)
@program(grid_type=GridType.UNSTRUCTURED, backend=settings.backend)
def compute_ddqz_z_half(
z_ifc: fa.CellKField[wpfloat],
z_mc: fa.CellKField[wpfloat],
Expand Down
68 changes: 68 additions & 0 deletions model/common/src/icon4py/model/common/metrics/metrics_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# ICON4Py - ICON inspired code in Python and GT4Py
#
# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pathlib

import icon4py.model.common.states.factory as factory
from icon4py.model.common import dimension as dims
from icon4py.model.common.decomposition import definitions as decomposition
from icon4py.model.common.grid import horizontal
from icon4py.model.common.metrics import metric_fields as mf
from icon4py.model.common.test_utils import datatest_utils as dt_utils, serialbox_utils as sb


# we need to register a couple of fields from the serializer. Those should get replaced one by one.

dt_utils.TEST_DATA_ROOT = pathlib.Path(__file__).parent / "testdata"
properties = decomposition.get_processor_properties(decomposition.get_runtype(with_mpi=False))
path = dt_utils.get_ranked_data_path(dt_utils.SERIALIZED_DATA_PATH, properties)

data_provider = sb.IconSerialDataProvider(
"icon_pydycore", str(path.absolute()), False, mpi_rank=properties.rank
)

# z_ifc (computable from vertical grid for model without topography)
metrics_savepoint = data_provider.from_metrics_savepoint()

# interpolation fields also for now passing as precomputed fields
interpolation_savepoint = data_provider.from_interpolation_savepoint()
# can get geometry fields as pre computed fields from the grid_savepoint
grid_savepoint = data_provider.from_savepoint_grid()
#######

# start build up factory:


interface_model_height = metrics_savepoint.z_ifc()
c_lin_e = interpolation_savepoint.c_lin_e()

fields_factory = factory.FieldsFactory()

# used for vertical domain below: should go away once vertical grid provids start_index and end_index like interface
grid = grid_savepoint.global_grid_params

fields_factory.register_provider(
factory.PrecomputedFieldProvider(
{
"height_on_interface_levels": interface_model_height,
"cell_to_edge_interpolation_coefficient": c_lin_e,
}
)
)
height_provider = factory.ProgramFieldProvider(
func=mf.compute_z_mc,
domain={
dims.CellDim: (
horizontal.HorizontalMarkerIndex.local(dims.CellDim),
horizontal.HorizontalMarkerIndex.end(dims.CellDim),
),
dims.KDim: (0, grid.num_levels),
},
fields={"z_mc": "height"},
deps={"z_ifc": "height_on_interface_levels"},
)
Loading
Loading