From 51bd673da50dab9aaf32822c716a292819752e3b Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 27 Jun 2024 11:25:11 +0200 Subject: [PATCH 01/37] WIP --- .../icon4py/model/common/metrics/factory.py | 119 ++++++++++++++++++ .../common/tests/metric_tests/test_factory.py | 16 +++ 2 files changed, 135 insertions(+) create mode 100644 model/common/src/icon4py/model/common/metrics/factory.py create mode 100644 model/common/tests/metric_tests/test_factory.py diff --git a/model/common/src/icon4py/model/common/metrics/factory.py b/model/common/src/icon4py/model/common/metrics/factory.py new file mode 100644 index 000000000..e200ee9ca --- /dev/null +++ b/model/common/src/icon4py/model/common/metrics/factory.py @@ -0,0 +1,119 @@ +from enum import IntEnum +from typing import Sequence + +import gt4py.next as gtx +import xarray as xa + +import icon4py.model.common.metrics.metric_fields as metrics +import icon4py.model.common.type_alias as ta +from icon4py.model.common.dimension import CellDim, KDim, KHalfDim +from icon4py.model.common.grid import icon +from icon4py.model.common.grid.base import BaseGrid + + +class RetrievalType(IntEnum): + FIELD = 0, + DATA_ARRAY = 1, + METADATA = 2, + +_attrs = {"functional_determinant_of_the_metrics_on_half_levels":dict( + standard_name="functional_determinant_of_the_metrics_on_half_levels", + long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", + units="", + dims=(CellDim, KHalfDim), + icon_var_name="ddqz_z_half", + ), + "height": dict(standard_name="height", long_name="height", units="m", dims=(CellDim, KDim), icon_var_name="z_mc"), + "height_on_interface_levels": dict(standard_name="height_on_interface_levels", long_name="height_on_interface_levels", units="m", dims=(CellDim, KHalfDim), icon_var_name="z_ifc") + } + + +class FieldProviderImpl: + """ + In charge of computing a field and providing metadata about it. + TODO: change for tuples of fields + + """ + + # TODO that should be a sequence or a dict of fields, since func -> tuple[...] + def __init__(self, grid: BaseGrid, deps: Sequence['FieldProvider'], attrs: dict): + self.grid = grid + self.dependencies = deps + self._attrs = attrs + self.func = metrics.compute_z_mc + self.fields:Sequence[gtx.Field|None] = [] + + # TODO (@halungge) handle DType + def _allocate(self, fields:Sequence[gtx.Field], dimensions: Sequence[gtx.Dimension]): + domain = {dim: (0, self.grid.size[dim]) for dim in dimensions} + return [gtx.constructors.zeros(domain, dtype=ta.wpfloat) for _ in fields] + + def __call__(self): + if not self.fields: + self.field = self._allocate(self.fields, self._attrs["dims"]) + domain = (0, self.grid.num_cells, 0, self.grid.num_levels) + args = [dep(RetrievalType.FIELD) for dep in self.dependencies] + self.field = self.func(*args, self.field, *domain, + offset_provider=self.grid.offset_providers) + return self.field + + +class SimpleFieldProvider: + def id(x: gtx.Field) -> gtx.Field: + return x + + def __init__(self, grid: BaseGrid, field, attrs): + super().__init__(grid, deps=(), attrs=attrs) + self.func = self.id + self.field = field + + +# class FieldProvider(Protocol): +# +# func = metrics.compute_ddqz_z_half +# field: gtx.Field[gtx.Dims[CellDim, KDim], ta.wpfloat] = None +# +# def __init__(self, grid:BaseGrid, func, deps: Sequence['FieldProvider''], attrs): +# super().__init__(grid, deps=deps, attrs=attrs) +# self.func = func + +class MetricsFieldsFactory: + """ + Factory for metric fields. + """ + def __init__(self, grid:icon.IconGrid, z_ifc:gtx.Field): + self.grid = grid + self.z_ifc_provider = SimpleFieldProvider(self.grid, z_ifc, _attrs["height_on_interface_levels"]) + self._providers = {"height_on_interface_levels": self.z_ifc_provider} + + z_mc_provider = None + z_ddqz_provider = None + # TODO (@halungge) use TypedDict + self._providers["functional_determinant_of_the_metrics_on_half_levels"]= z_ddqz_provider + self._providers["height"] = z_mc_provider + + + def get(self, field_name: str, type_: RetrievalType): + if field_name not in _attrs: + raise ValueError(f"Field {field_name} not found in metric fields") + if type_ == RetrievalType.METADATA: + return _attrs[field_name] + if type_ == RetrievalType.FIELD: + return self._providers[field_name]() + if type_ == RetrievalType.DATA_ARRAY: + return to_data_array(self._providers[field_name](), _attrs[field_name]) + raise ValueError(f"Invalid retrieval type {type_}") + + +def to_data_array(field, attrs): + return xa.DataArray(field, attrs=attrs) + + + + + + + + + + diff --git a/model/common/tests/metric_tests/test_factory.py b/model/common/tests/metric_tests/test_factory.py new file mode 100644 index 000000000..eaf0a44b3 --- /dev/null +++ b/model/common/tests/metric_tests/test_factory.py @@ -0,0 +1,16 @@ +from icon4py.model.common.metrics import factory +from icon4py.model.common.metrics.factory import RetrievalType + + +def test_field_provider(icon_grid, metrics_savepoint): + z_ifc = factory.SimpleFieldProvider(icon_grid, metrics_savepoint.z_ifc(), factory._attrs["height_on_interface_levels"]) + z_mc = factory.FieldProvider(grid=icon_grid, deps=(z_ifc,), attrs=factory._attrs["height"]) + data_array = z_mc(RetrievalType.FIELD) + + #assert dallclose(metrics_savepoint.z_mc(), data_array.ndarray) + + + #provider = factory.FieldProviderImpl(icon_grid, (z_ifc, z_mc), attrs=factory.attrs["functional_determinant_of_the_metrics_on_half_levels"]) + #provider() + + \ No newline at end of file From 2270522553af23487fabc2f5503a711aaf739301 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 27 Jun 2024 22:49:50 +0200 Subject: [PATCH 02/37] add backend to metric_fields stencils fix vertical dimension in z_mc --- .../src/icon4py/model/common/metrics/metric_fields.py | 7 ++++--- model/common/tests/metric_tests/test_metric_fields.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/metric_fields.py b/model/common/src/icon4py/model/common/metrics/metric_fields.py index 1900843ae..8eb671b50 100644 --- a/model/common/src/icon4py/model/common/metrics/metric_fields.py +++ b/model/common/src/icon4py/model/common/metrics/metric_fields.py @@ -30,6 +30,7 @@ where, ) +from icon4py.model.common import settings from icon4py.model.common.dimension import ( C2E, E2C, @@ -64,7 +65,7 @@ class MetricsConfig: exner_expol: Final[wpfloat] = 0.333 -@program(grid_type=GridType.UNSTRUCTURED) +@program(grid_type=GridType.UNSTRUCTURED, backend=settings.backend) def compute_z_mc( z_ifc: Field[[CellDim, KDim], wpfloat], z_mc: Field[[CellDim, KDim], wpfloat], @@ -82,7 +83,7 @@ def compute_z_mc( Args: z_ifc: Field[[CellDim, KDim], wpfloat] geometric height on half levels z_mc: Field[[CellDim, KDim], wpfloat] output, geometric height defined on full levels - horizontal_start:int32 start index of horizontal domain + horizontal_start: horizontal_end:int32 end index of horizontal domain vertical_start:int32 start index of vertical domain vertical_end:int32 end index of vertical domain @@ -109,7 +110,7 @@ def _compute_ddqz_z_half( return ddqz_z_half -@program(grid_type=GridType.UNSTRUCTURED, backend=None) +@program(grid_type=GridType.UNSTRUCTURED, backend=settings.backend) def compute_ddqz_z_half( z_ifc: Field[[CellDim, KDim], wpfloat], z_mc: Field[[CellDim, KDim], wpfloat], diff --git a/model/common/tests/metric_tests/test_metric_fields.py b/model/common/tests/metric_tests/test_metric_fields.py index ec93c1c29..bc332363c 100644 --- a/model/common/tests/metric_tests/test_metric_fields.py +++ b/model/common/tests/metric_tests/test_metric_fields.py @@ -115,7 +115,7 @@ def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend): pytest.skip("skipping: unsupported backend") ddq_z_half_ref = metrics_savepoint.ddqz_z_half() z_ifc = metrics_savepoint.z_ifc() - z_mc = zero_field(icon_grid, CellDim, KDim, extend={KDim: 1}) + z_mc = zero_field(icon_grid, CellDim, KDim) nlevp1 = icon_grid.num_levels + 1 k_index = as_field((KDim,), np.arange(nlevp1, dtype=int32)) compute_z_mc.with_backend(backend)( From 0bf8d18e2425235ce1893bafe2ca808b817a07e0 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 27 Jun 2024 22:50:15 +0200 Subject: [PATCH 03/37] ugly version that works for gtfn programs --- .../icon4py/model/common/metrics/factory.py | 198 ++++++++++++------ .../common/tests/metric_tests/test_factory.py | 21 +- 2 files changed, 142 insertions(+), 77 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/factory.py b/model/common/src/icon4py/model/common/metrics/factory.py index e200ee9ca..53f18d3ca 100644 --- a/model/common/src/icon4py/model/common/metrics/factory.py +++ b/model/common/src/icon4py/model/common/metrics/factory.py @@ -1,16 +1,24 @@ +import functools from enum import IntEnum -from typing import Sequence +from typing import Optional, Protocol, Sequence, TypeAlias, TypeVar, Union import gt4py.next as gtx import xarray as xa +from gt4py.next.ffront.decorator import Program -import icon4py.model.common.metrics.metric_fields as metrics +import icon4py.model.common.metrics.metric_fields as mf import icon4py.model.common.type_alias as ta -from icon4py.model.common.dimension import CellDim, KDim, KHalfDim +from icon4py.model.common import settings +from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, KHalfDim, VertexDim from icon4py.model.common.grid import icon -from icon4py.model.common.grid.base import BaseGrid +from icon4py.model.common.settings import xp +T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) +DimT = TypeVar("DimT", KDim, KHalfDim, CellDim, EdgeDim, VertexDim) +Scalar: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64] + +FieldType:TypeAlias = gtx.Field[gtx.Dims[DimT], T] class RetrievalType(IntEnum): FIELD = 0, DATA_ARRAY = 1, @@ -21,87 +29,145 @@ class RetrievalType(IntEnum): long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", units="", dims=(CellDim, KHalfDim), + dtype=ta.wpfloat, icon_var_name="ddqz_z_half", ), - "height": dict(standard_name="height", long_name="height", units="m", dims=(CellDim, KDim), icon_var_name="z_mc"), - "height_on_interface_levels": dict(standard_name="height_on_interface_levels", long_name="height_on_interface_levels", units="m", dims=(CellDim, KHalfDim), icon_var_name="z_ifc") + "height": dict(standard_name="height", + long_name="height", + units="m", + dims=(CellDim, KDim), + icon_var_name="z_mc", dtype = ta.wpfloat) , + "height_on_interface_levels": dict(standard_name="height_on_interface_levels", + long_name="height_on_interface_levels", + units="m", + dims=(CellDim, KHalfDim), + icon_var_name="z_ifc", + dtype = ta.wpfloat), + "model_level_number": dict(standard_name="model_level_number", + long_name="model level number", + units="", dims=(KHalfDim,), + icon_var_name="k_index", + dtype = gtx.int32), } +class FieldProvider(Protocol): + def evaluate(self) -> None: + pass + + def get(self, field_name: str) -> FieldType: + pass + + -class FieldProviderImpl: - """ - In charge of computing a field and providing metadata about it. - TODO: change for tuples of fields - - """ - # TODO that should be a sequence or a dict of fields, since func -> tuple[...] - def __init__(self, grid: BaseGrid, deps: Sequence['FieldProvider'], attrs: dict): - self.grid = grid - self.dependencies = deps - self._attrs = attrs - self.func = metrics.compute_z_mc - self.fields:Sequence[gtx.Field|None] = [] - - # TODO (@halungge) handle DType - def _allocate(self, fields:Sequence[gtx.Field], dimensions: Sequence[gtx.Dimension]): - domain = {dim: (0, self.grid.size[dim]) for dim in dimensions} - return [gtx.constructors.zeros(domain, dtype=ta.wpfloat) for _ in fields] - - def __call__(self): - if not self.fields: - self.field = self._allocate(self.fields, self._attrs["dims"]) - domain = (0, self.grid.num_cells, 0, self.grid.num_levels) - args = [dep(RetrievalType.FIELD) for dep in self.dependencies] - self.field = self.func(*args, self.field, *domain, - offset_provider=self.grid.offset_providers) - return self.field - - -class SimpleFieldProvider: - def id(x: gtx.Field) -> gtx.Field: - return x - - def __init__(self, grid: BaseGrid, field, attrs): - super().__init__(grid, deps=(), attrs=attrs) - self.func = self.id - self.field = field - - -# class FieldProvider(Protocol): -# -# func = metrics.compute_ddqz_z_half -# field: gtx.Field[gtx.Dims[CellDim, KDim], ta.wpfloat] = None -# -# def __init__(self, grid:BaseGrid, func, deps: Sequence['FieldProvider''], attrs): -# super().__init__(grid, deps=deps, attrs=attrs) -# self.func = func +class PrecomputedFieldsProvider: + + def __init__(self,fields: dict[str, FieldType]): + self._fields = fields + + def evaluate(self): + pass + def get(self, field_name: str) -> FieldType: + return self._fields[field_name] + + class MetricsFieldsFactory: """ Factory for metric fields. """ - def __init__(self, grid:icon.IconGrid, z_ifc:gtx.Field): - self.grid = grid - self.z_ifc_provider = SimpleFieldProvider(self.grid, z_ifc, _attrs["height_on_interface_levels"]) - self._providers = {"height_on_interface_levels": self.z_ifc_provider} - - z_mc_provider = None - z_ddqz_provider = None - # TODO (@halungge) use TypedDict - self._providers["functional_determinant_of_the_metrics_on_half_levels"]= z_ddqz_provider - self._providers["height"] = z_mc_provider - + + def __init__(self, grid:icon.IconGrid, z_ifc:gtx.Field, backend=settings.backend): + self._grid = grid + self._sizes = grid.size + self._sizes[KHalfDim] = self._sizes[KDim] + 1 + self._providers: dict[str, 'FieldProvider'] = {} + self._params = {"num_lev": grid.num_levels, } + self._allocator = gtx.constructors.zeros.partial(allocator=backend) + + k_index = gtx.as_field((KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) + + pre_computed_fields = PrecomputedFieldsProvider( + {"height_on_interface_levels": z_ifc, "model_level_number": k_index}) + self._providers["height_on_interface_levels"] = pre_computed_fields + self._providers["model_level_number"] = pre_computed_fields + self._providers["height"] = self.ProgramFieldProvider(self, + func = mf.compute_z_mc, + domain = {CellDim: (0, grid.num_cells), KDim: (0, grid.num_levels)}, + fields=["height"], + deps=["height_on_interface_levels"]) + self._providers["functional_determinant_of_the_metrics_on_half_levels"] = self.ProgramFieldProvider(self, + func = mf.compute_ddqz_z_half, + domain = {CellDim: (0, grid.num_cells), KHalfDim: (0, grid.num_levels + 1)}, + fields=["functional_determinant_of_the_metrics_on_half_levels"], + deps=["height_on_interface_levels", "height", "model_level_number"], + params=["num_lev"]) + + class ProgramFieldProvider: + """ + In charge of computing a field and providing metadata about it. + + """ + def __init__(self, + outer: 'MetricsFieldsFactory', # + func: Program, + domain: dict[gtx.Dimension:tuple[int, int]], # the compute domain + fields: Sequence[str], + deps: Sequence[str] = [], # the dependencies of func + params: Sequence[str] = [], # the parameters of func + ): + self._outer = outer + self._compute_domain = domain + self._dims = domain.keys() + self._func = func + self._dependencies = {k: self._outer._providers[k] for k in deps} + self._params = {k: self._outer._params[k] for k in params} + + self._fields: dict[str, Optional[gtx.Field | Scalar]] = {name: None for name in fields} + + def _map_dim(self, dim: gtx.Dimension) -> gtx.Dimension: + if dim == KHalfDim: + return KDim + return dim + + def _allocate(self): + # TODO (@halungge) get dimes from attrs? + field_domain = {self._map_dim(dim): (0, self._outer._sizes[dim]) for dim in self._dims} + return {k: self._outer._allocator(field_domain, dtype=_attrs[k]["dtype"]) for k, v in + self._fields.items()} + + + def _unallocated(self) -> bool: + return not all(self._fields.values()) + + def evaluate(self): + self._fields = self._allocate() + + domain = functools.reduce(lambda x, y: x + y, self._compute_domain.values()) + # args = {k: provider.get(k) for k, provider in self._dependencies.items()} + args = [self._dependencies[k].get(k) for k in self._dependencies.keys()] + params = [p for p in self._params.values()] + output = [f for f in self._fields.values()] + self._func(*args, *output, *params, *domain, + offset_provider=self._outer._grid.offset_providers) + + def get(self, field_name: str): + if field_name not in self._fields.keys(): + raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") + if self._unallocated(): + self.evaluate() + return self._fields[field_name] + def get(self, field_name: str, type_: RetrievalType): if field_name not in _attrs: raise ValueError(f"Field {field_name} not found in metric fields") if type_ == RetrievalType.METADATA: return _attrs[field_name] if type_ == RetrievalType.FIELD: - return self._providers[field_name]() + return self._providers[field_name].get(field_name) if type_ == RetrievalType.DATA_ARRAY: - return to_data_array(self._providers[field_name](), _attrs[field_name]) + return to_data_array(self._providers[field_name].get(field_name), _attrs[field_name]) raise ValueError(f"Invalid retrieval type {type_}") diff --git a/model/common/tests/metric_tests/test_factory.py b/model/common/tests/metric_tests/test_factory.py index eaf0a44b3..3e70388cd 100644 --- a/model/common/tests/metric_tests/test_factory.py +++ b/model/common/tests/metric_tests/test_factory.py @@ -1,16 +1,15 @@ +import pytest + from icon4py.model.common.metrics import factory -from icon4py.model.common.metrics.factory import RetrievalType +from icon4py.model.common.test_utils.helpers import dallclose -def test_field_provider(icon_grid, metrics_savepoint): - z_ifc = factory.SimpleFieldProvider(icon_grid, metrics_savepoint.z_ifc(), factory._attrs["height_on_interface_levels"]) - z_mc = factory.FieldProvider(grid=icon_grid, deps=(z_ifc,), attrs=factory._attrs["height"]) - data_array = z_mc(RetrievalType.FIELD) - - #assert dallclose(metrics_savepoint.z_mc(), data_array.ndarray) - - - #provider = factory.FieldProviderImpl(icon_grid, (z_ifc, z_mc), attrs=factory.attrs["functional_determinant_of_the_metrics_on_half_levels"]) - #provider() +@pytest.mark.datatest +def test_field_provider(icon_grid, metrics_savepoint, backend): + fields_factory = factory.MetricsFieldsFactory(icon_grid, metrics_savepoint.z_ifc(), backend) + + data = fields_factory.get("functional_determinant_of_the_metrics_on_half_levels", type_=factory.RetrievalType.FIELD) + ref = metrics_savepoint.ddqz_z_half().ndarray + assert dallclose(data.ndarray, ref) \ No newline at end of file From b78b24f70f88f1009255634226d30e3000e27af2 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Mon, 1 Jul 2024 14:55:48 +0200 Subject: [PATCH 04/37] use operator.add instead of lambda --- model/common/src/icon4py/model/common/metrics/factory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/metrics/factory.py b/model/common/src/icon4py/model/common/metrics/factory.py index 53f18d3ca..2782ec4e2 100644 --- a/model/common/src/icon4py/model/common/metrics/factory.py +++ b/model/common/src/icon4py/model/common/metrics/factory.py @@ -1,4 +1,5 @@ import functools +import operator from enum import IntEnum from typing import Optional, Protocol, Sequence, TypeAlias, TypeVar, Union @@ -144,7 +145,7 @@ def _unallocated(self) -> bool: def evaluate(self): self._fields = self._allocate() - domain = functools.reduce(lambda x, y: x + y, self._compute_domain.values()) + domain = functools.reduce(operator.add, self._compute_domain.values()) # args = {k: provider.get(k) for k, provider in self._dependencies.items()} args = [self._dependencies[k].get(k) for k in self._dependencies.keys()] params = [p for p in self._params.values()] From 5836a3254873f43c3586e7e228d2b0396059161e Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 8 Aug 2024 16:26:34 +0200 Subject: [PATCH 05/37] reduce dependencies, move ProgramFieldProvider out of Factory --- .../icon4py/model/common/metrics/factory.py | 123 ++++++++++-------- .../common/tests/metric_tests/test_factory.py | 35 ++++- 2 files changed, 94 insertions(+), 64 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/factory.py b/model/common/src/icon4py/model/common/metrics/factory.py index 2782ec4e2..27d33f8ba 100644 --- a/model/common/src/icon4py/model/common/metrics/factory.py +++ b/model/common/src/icon4py/model/common/metrics/factory.py @@ -4,22 +4,20 @@ from typing import Optional, Protocol, Sequence, TypeAlias, TypeVar, Union import gt4py.next as gtx +import gt4py.next.ffront.decorator as gtx_decorator import xarray as xa -from gt4py.next.ffront.decorator import Program -import icon4py.model.common.metrics.metric_fields as mf import icon4py.model.common.type_alias as ta -from icon4py.model.common import settings -from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, KHalfDim, VertexDim +from icon4py.model.common import dimension as dims, settings from icon4py.model.common.grid import icon from icon4py.model.common.settings import xp T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) -DimT = TypeVar("DimT", KDim, KHalfDim, CellDim, EdgeDim, VertexDim) +DimT = TypeVar("DimT", dims.KDim, dims.KHalfDim, dims.CellDim, dims.EdgeDim, dims.VertexDim) Scalar: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64] -FieldType:TypeAlias = gtx.Field[gtx.Dims[DimT], T] +FieldType:TypeAlias = gtx.Field[Sequence[gtx.Dims[DimT]], T] class RetrievalType(IntEnum): FIELD = 0, DATA_ARRAY = 1, @@ -29,116 +27,99 @@ class RetrievalType(IntEnum): standard_name="functional_determinant_of_the_metrics_on_half_levels", long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", units="", - dims=(CellDim, KHalfDim), + dims=(dims.CellDim, dims.KHalfDim), dtype=ta.wpfloat, icon_var_name="ddqz_z_half", ), "height": dict(standard_name="height", long_name="height", units="m", - dims=(CellDim, KDim), + dims=(dims.CellDim, dims.KDim), icon_var_name="z_mc", dtype = ta.wpfloat) , "height_on_interface_levels": dict(standard_name="height_on_interface_levels", long_name="height_on_interface_levels", units="m", - dims=(CellDim, KHalfDim), + dims=(dims.CellDim, dims.KHalfDim), icon_var_name="z_ifc", dtype = ta.wpfloat), "model_level_number": dict(standard_name="model_level_number", long_name="model level number", - units="", dims=(KHalfDim,), + units="", dims=(dims.KHalfDim,), icon_var_name="k_index", dtype = gtx.int32), } class FieldProvider(Protocol): + """ + Protocol for field providers. + + A field provider is responsible for the computation and caching of a set of fields. + The fields can be accessed by their field_name (str). + + A FieldProvider has to methods: + - evaluate: computes the fields based on the instructions of concrete implementation + - get: returns the field with the given field_name. + + """ def evaluate(self) -> None: pass def get(self, field_name: str) -> FieldType: pass + def fields(self) -> Sequence[str]: + pass class PrecomputedFieldsProvider: + """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" - def __init__(self,fields: dict[str, FieldType]): + def __init__(self, fields: dict[str, FieldType]): self._fields = fields def evaluate(self): pass def get(self, field_name: str) -> FieldType: return self._fields[field_name] - - - -class MetricsFieldsFactory: - """ - Factory for metric fields. - """ - - def __init__(self, grid:icon.IconGrid, z_ifc:gtx.Field, backend=settings.backend): - self._grid = grid - self._sizes = grid.size - self._sizes[KHalfDim] = self._sizes[KDim] + 1 - self._providers: dict[str, 'FieldProvider'] = {} - self._params = {"num_lev": grid.num_levels, } - self._allocator = gtx.constructors.zeros.partial(allocator=backend) - - k_index = gtx.as_field((KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) + def fields(self) -> Sequence[str]: + return self._fields.keys() - pre_computed_fields = PrecomputedFieldsProvider( - {"height_on_interface_levels": z_ifc, "model_level_number": k_index}) - self._providers["height_on_interface_levels"] = pre_computed_fields - self._providers["model_level_number"] = pre_computed_fields - self._providers["height"] = self.ProgramFieldProvider(self, - func = mf.compute_z_mc, - domain = {CellDim: (0, grid.num_cells), KDim: (0, grid.num_levels)}, - fields=["height"], - deps=["height_on_interface_levels"]) - self._providers["functional_determinant_of_the_metrics_on_half_levels"] = self.ProgramFieldProvider(self, - func = mf.compute_ddqz_z_half, - domain = {CellDim: (0, grid.num_cells), KHalfDim: (0, grid.num_levels + 1)}, - fields=["functional_determinant_of_the_metrics_on_half_levels"], - deps=["height_on_interface_levels", "height", "model_level_number"], - params=["num_lev"]) - - class ProgramFieldProvider: +class ProgramFieldProvider: """ - In charge of computing a field and providing metadata about it. + Computes a field defined by a GT4Py Program. """ + def __init__(self, outer: 'MetricsFieldsFactory', # - func: Program, + func: gtx_decorator.Program, domain: dict[gtx.Dimension:tuple[int, int]], # the compute domain fields: Sequence[str], deps: Sequence[str] = [], # the dependencies of func params: Sequence[str] = [], # the parameters of func ): - self._outer = outer + self._factory = outer self._compute_domain = domain self._dims = domain.keys() self._func = func - self._dependencies = {k: self._outer._providers[k] for k in deps} - self._params = {k: self._outer._params[k] for k in params} + self._dependencies = {k: self._factory._providers[k] for k in deps} + self._params = {k: self._factory._params[k] for k in params} self._fields: dict[str, Optional[gtx.Field | Scalar]] = {name: None for name in fields} def _map_dim(self, dim: gtx.Dimension) -> gtx.Dimension: - if dim == KHalfDim: - return KDim + if dim == dims.KHalfDim: + return dims.KDim return dim def _allocate(self): - # TODO (@halungge) get dimes from attrs? - field_domain = {self._map_dim(dim): (0, self._outer._sizes[dim]) for dim in self._dims} - return {k: self._outer._allocator(field_domain, dtype=_attrs[k]["dtype"]) for k, v in + field_domain = {self._map_dim(dim): (0, self._factory._sizes[dim]) for dim in + self._dims} + return {k: self._factory._allocator(field_domain, dtype=_attrs[k]["dtype"]) for k, v in self._fields.items()} - def _unallocated(self) -> bool: return not all(self._fields.values()) @@ -151,8 +132,10 @@ def evaluate(self): params = [p for p in self._params.values()] output = [f for f in self._fields.values()] self._func(*args, *output, *params, *domain, - offset_provider=self._outer._grid.offset_providers) + offset_provider=self._factory._grid.offset_providers) + def fields(self): + return self._fields.keys() def get(self, field_name: str): if field_name not in self._fields.keys(): raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") @@ -160,6 +143,32 @@ def get(self, field_name: str): self.evaluate() return self._fields[field_name] + +class MetricsFieldsFactory: + """ + Factory for metric fields. + """ + + + def __init__(self, grid:icon.IconGrid, z_ifc:gtx.Field, backend=settings.backend): + self._grid = grid + self._sizes = grid.size + self._sizes[dims.KHalfDim] = self._sizes[dims.KDim] + 1 + self._providers: dict[str, 'FieldProvider'] = {} + self._params = {"num_lev": grid.num_levels, } + self._allocator = gtx.constructors.zeros.partial(allocator=backend) + + k_index = gtx.as_field((dims.KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) + + pre_computed_fields = PrecomputedFieldsProvider( + {"height_on_interface_levels": z_ifc, "model_level_number": k_index}) + self.register_provider(pre_computed_fields) + + def register_provider(self, provider:FieldProvider): + for field in provider.fields(): + self._providers[field] = provider + + def get(self, field_name: str, type_: RetrievalType): if field_name not in _attrs: raise ValueError(f"Field {field_name} not found in metric fields") diff --git a/model/common/tests/metric_tests/test_factory.py b/model/common/tests/metric_tests/test_factory.py index 3e70388cd..f1a32448c 100644 --- a/model/common/tests/metric_tests/test_factory.py +++ b/model/common/tests/metric_tests/test_factory.py @@ -1,15 +1,36 @@ import pytest -from icon4py.model.common.metrics import factory -from icon4py.model.common.test_utils.helpers import dallclose +import icon4py.model.common.test_utils.helpers as helpers +from icon4py.model.common import dimension as dims +from icon4py.model.common.metrics import factory, metric_fields as mf @pytest.mark.datatest def test_field_provider(icon_grid, metrics_savepoint, backend): fields_factory = factory.MetricsFieldsFactory(icon_grid, metrics_savepoint.z_ifc(), backend) - - data = fields_factory.get("functional_determinant_of_the_metrics_on_half_levels", type_=factory.RetrievalType.FIELD) - ref = metrics_savepoint.ddqz_z_half().ndarray - assert dallclose(data.ndarray, ref) + height_provider = factory.ProgramFieldProvider(func=mf.compute_z_mc, + domain={dims.CellDim: (0, icon_grid.num_cells), + dims.KDim: (0, icon_grid.num_levels)}, + fields=["height"], + deps=["height_on_interface_levels"], + outer=fields_factory) + fields_factory.register_provider(height_provider) + functional_determinant_provider = factory.ProgramFieldProvider(func=mf.compute_ddqz_z_half, + domain={dims.CellDim: (0,icon_grid.num_cells), + dims.KHalfDim: ( + 0, + icon_grid.num_levels + 1)}, + fields=[ + "functional_determinant_of_the_metrics_on_half_levels"], + deps=[ + "height_on_interface_levels", + "height", + "model_level_number"], + params=[ + "num_lev"], outer=fields_factory) + fields_factory.register_provider(functional_determinant_provider) - \ No newline at end of file + data = fields_factory.get("functional_determinant_of_the_metrics_on_half_levels", + type_=factory.RetrievalType.FIELD) + ref = metrics_savepoint.ddqz_z_half().ndarray + assert helpers.dallclose(data.ndarray, ref) From 6f3e6c64860aee6a068d72b6d33839bc9a36ecc9 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 9 Aug 2024 13:17:55 +0200 Subject: [PATCH 06/37] rename fields --- .../icon4py/model/common/metrics/factory.py | 198 ++++++++++-------- .../common/tests/metric_tests/test_factory.py | 38 +++- 2 files changed, 146 insertions(+), 90 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/factory.py b/model/common/src/icon4py/model/common/metrics/factory.py index 27d33f8ba..3cc7cfd9a 100644 --- a/model/common/src/icon4py/model/common/metrics/factory.py +++ b/model/common/src/icon4py/model/common/metrics/factory.py @@ -1,7 +1,8 @@ +import abc import functools import operator from enum import IntEnum -from typing import Optional, Protocol, Sequence, TypeAlias, TypeVar, Union +from typing import Iterable, Optional, Protocol, Sequence, TypeAlias, TypeVar, Union import gt4py.next as gtx import gt4py.next.ffront.decorator as gtx_decorator @@ -9,8 +10,8 @@ import icon4py.model.common.type_alias as ta from icon4py.model.common import dimension as dims, settings -from icon4py.model.common.grid import icon -from icon4py.model.common.settings import xp +from icon4py.model.common.grid import base as base_grid +from icon4py.model.common.io import cf_utils T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) @@ -23,8 +24,8 @@ class RetrievalType(IntEnum): DATA_ARRAY = 1, METADATA = 2, -_attrs = {"functional_determinant_of_the_metrics_on_half_levels":dict( - standard_name="functional_determinant_of_the_metrics_on_half_levels", +_attrs = {"functional_determinant_of_metrics_on_interface_levels":dict( + standard_name="functional_determinant_of_metrics_on_interface_levels", long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", units="", dims=(dims.CellDim, dims.KHalfDim), @@ -44,11 +45,20 @@ class RetrievalType(IntEnum): dtype = ta.wpfloat), "model_level_number": dict(standard_name="model_level_number", long_name="model level number", - units="", dims=(dims.KHalfDim,), + units="", dims=(dims.KDim,), icon_var_name="k_index", dtype = gtx.int32), + cf_utils.INTERFACE_LEVEL_STANDARD_NAME: dict(standard_name=cf_utils.INTERFACE_LEVEL_STANDARD_NAME, + long_name="model interface level number", + units="", dims=(dims.KHalfDim,), + icon_var_name="k_index", + dtype=gtx.int32), } + + + + class FieldProvider(Protocol): """ Protocol for field providers. @@ -56,128 +66,149 @@ class FieldProvider(Protocol): A field provider is responsible for the computation and caching of a set of fields. The fields can be accessed by their field_name (str). - A FieldProvider has to methods: + A FieldProvider has three methods: - evaluate: computes the fields based on the instructions of concrete implementation - get: returns the field with the given field_name. + - fields: returns the list of field names provided by the """ - def evaluate(self) -> None: + @abc.abstractmethod + def _evaluate(self, factory:'FieldsFactory') -> None: pass - - def get(self, field_name: str) -> FieldType: + + @abc.abstractmethod + def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: pass - - def fields(self) -> Sequence[str]: + + @abc.abstractmethod + def dependencies(self) -> Iterable[str]: pass - + @abc.abstractmethod + def fields(self) -> Iterable[str]: + pass + -class PrecomputedFieldsProvider: +class PrecomputedFieldsProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" def __init__(self, fields: dict[str, FieldType]): self._fields = fields - def evaluate(self): + def _evaluate(self, factory: 'FieldsFactory') -> None: pass - def get(self, field_name: str) -> FieldType: + + def dependencies(self) -> Sequence[str]: + return [] + + def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: return self._fields[field_name] - def fields(self) -> Sequence[str]: + def fields(self) -> Iterable[str]: return self._fields.keys() + class ProgramFieldProvider: - """ - Computes a field defined by a GT4Py Program. - - """ - - def __init__(self, - outer: 'MetricsFieldsFactory', # - func: gtx_decorator.Program, - domain: dict[gtx.Dimension:tuple[int, int]], # the compute domain - fields: Sequence[str], - deps: Sequence[str] = [], # the dependencies of func - params: Sequence[str] = [], # the parameters of func - ): - self._factory = outer - self._compute_domain = domain - self._dims = domain.keys() - self._func = func - self._dependencies = {k: self._factory._providers[k] for k in deps} - self._params = {k: self._factory._params[k] for k in params} - - self._fields: dict[str, Optional[gtx.Field | Scalar]] = {name: None for name in fields} - - def _map_dim(self, dim: gtx.Dimension) -> gtx.Dimension: + """ + Computes a field defined by a GT4Py Program. + + """ + + def __init__(self, + func: gtx_decorator.Program, + domain: dict[gtx.Dimension:tuple[gtx.int32, gtx.int32]], # the compute domain + fields: Sequence[str], + deps: Sequence[str] = [], # the dependencies of func + params: dict[str, Scalar] = {}, # the parameters of func + ): + self._compute_domain = domain + self._dims = domain.keys() + self._func = func + self._dependencies = deps + self._params = params + self._fields: dict[str, Optional[gtx.Field | Scalar]] = {name: None for name in fields} + + + + def _allocate(self, allocator, grid:base_grid.BaseGrid) -> dict[str, FieldType]: + def _map_size(dim:gtx.Dimension, grid:base_grid.BaseGrid) -> int: + if dim == dims.KHalfDim: + return grid.num_levels + 1 + return grid.size[dim] + + def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: if dim == dims.KHalfDim: return dims.KDim return dim - def _allocate(self): - field_domain = {self._map_dim(dim): (0, self._factory._sizes[dim]) for dim in - self._dims} - return {k: self._factory._allocator(field_domain, dtype=_attrs[k]["dtype"]) for k, v in - self._fields.items()} - - def _unallocated(self) -> bool: - return not all(self._fields.values()) - - def evaluate(self): - self._fields = self._allocate() - - domain = functools.reduce(operator.add, self._compute_domain.values()) - # args = {k: provider.get(k) for k, provider in self._dependencies.items()} - args = [self._dependencies[k].get(k) for k in self._dependencies.keys()] - params = [p for p in self._params.values()] - output = [f for f in self._fields.values()] - self._func(*args, *output, *params, *domain, - offset_provider=self._factory._grid.offset_providers) - - def fields(self): - return self._fields.keys() - def get(self, field_name: str): - if field_name not in self._fields.keys(): - raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") - if self._unallocated(): - self.evaluate() - return self._fields[field_name] - - -class MetricsFieldsFactory: + field_domain = {_map_dim(dim): (0, _map_size(dim, grid)) for dim in + self._compute_domain.keys()} + return {k: allocator(field_domain, dtype=_attrs[k]["dtype"]) for k in + self._fields.keys()} + + def _unallocated(self) -> bool: + return not all(self._fields.values()) + + def _evaluate(self, factory: 'FieldsFactory'): + self._fields = self._allocate(factory._allocator, factory.grid) + domain = functools.reduce(operator.add, self._compute_domain.values()) + args = [factory.get(k) for k in self.dependencies()] + params = [p for p in self._params.values()] + output = [f for f in self._fields.values()] + self._func(*args, *output, *params, *domain, + offset_provider=factory.grid.offset_providers) + + def fields(self)->Iterable[str]: + return self._fields.keys() + + def dependencies(self)->Iterable[str]: + return self._dependencies + def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: + if field_name not in self._fields.keys(): + raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") + if self._unallocated(): + self._evaluate(factory) + return self._fields[field_name] + + +class FieldsFactory: """ - Factory for metric fields. + Factory for fields. + + Lazily compute fields and cache them. """ - def __init__(self, grid:icon.IconGrid, z_ifc:gtx.Field, backend=settings.backend): + def __init__(self, grid:base_grid.BaseGrid, backend=settings.backend): self._grid = grid - self._sizes = grid.size - self._sizes[dims.KHalfDim] = self._sizes[dims.KDim] + 1 self._providers: dict[str, 'FieldProvider'] = {} - self._params = {"num_lev": grid.num_levels, } self._allocator = gtx.constructors.zeros.partial(allocator=backend) - k_index = gtx.as_field((dims.KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) - pre_computed_fields = PrecomputedFieldsProvider( - {"height_on_interface_levels": z_ifc, "model_level_number": k_index}) - self.register_provider(pre_computed_fields) + @property + def grid(self): + return self._grid def register_provider(self, provider:FieldProvider): + + for dependency in provider.dependencies(): + if dependency not in self._providers.keys(): + raise ValueError(f"Dependency '{dependency}' not found in registered providers") + + for field in provider.fields(): self._providers[field] = provider - def get(self, field_name: str, type_: RetrievalType): + def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD) -> Union[FieldType, xa.DataArray, dict]: if field_name not in _attrs: raise ValueError(f"Field {field_name} not found in metric fields") if type_ == RetrievalType.METADATA: return _attrs[field_name] if type_ == RetrievalType.FIELD: - return self._providers[field_name].get(field_name) + return self._providers[field_name](field_name, self) if type_ == RetrievalType.DATA_ARRAY: - return to_data_array(self._providers[field_name].get(field_name), _attrs[field_name]) + return to_data_array(self._providers[field_name](field_name), _attrs[field_name]) raise ValueError(f"Invalid retrieval type {type_}") @@ -188,6 +219,7 @@ def to_data_array(field, attrs): + diff --git a/model/common/tests/metric_tests/test_factory.py b/model/common/tests/metric_tests/test_factory.py index f1a32448c..29d225827 100644 --- a/model/common/tests/metric_tests/test_factory.py +++ b/model/common/tests/metric_tests/test_factory.py @@ -1,19 +1,43 @@ +import gt4py.next as gtx import pytest import icon4py.model.common.test_utils.helpers as helpers from icon4py.model.common import dimension as dims +from icon4py.model.common.io import cf_utils from icon4py.model.common.metrics import factory, metric_fields as mf +from icon4py.model.common.settings import xp +def test_check_dependencies_on_register(icon_grid, backend): + fields_factory = factory.FieldsFactory(icon_grid, backend) + provider = factory.ProgramFieldProvider(func=mf.compute_z_mc, + domain={dims.CellDim: (0, icon_grid.num_cells), + dims.KDim: (0, icon_grid.num_levels)}, + fields=["height"], + deps=["height_on_interface_levels"], + ) + with pytest.raises(ValueError) as e: + fields_factory.register_provider(provider) + assert e.value.match("'height_on_interface_levels' not found") + + @pytest.mark.datatest def test_field_provider(icon_grid, metrics_savepoint, backend): - fields_factory = factory.MetricsFieldsFactory(icon_grid, metrics_savepoint.z_ifc(), backend) + fields_factory = factory.FieldsFactory(icon_grid, backend) + k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) + z_ifc = metrics_savepoint.z_ifc() + + pre_computed_fields = factory.PrecomputedFieldsProvider( + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + + fields_factory.register_provider(pre_computed_fields) + height_provider = factory.ProgramFieldProvider(func=mf.compute_z_mc, domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, fields=["height"], deps=["height_on_interface_levels"], - outer=fields_factory) + ) fields_factory.register_provider(height_provider) functional_determinant_provider = factory.ProgramFieldProvider(func=mf.compute_ddqz_z_half, domain={dims.CellDim: (0,icon_grid.num_cells), @@ -21,16 +45,16 @@ def test_field_provider(icon_grid, metrics_savepoint, backend): 0, icon_grid.num_levels + 1)}, fields=[ - "functional_determinant_of_the_metrics_on_half_levels"], + "functional_determinant_of_metrics_on_interface_levels"], deps=[ "height_on_interface_levels", "height", - "model_level_number"], - params=[ - "num_lev"], outer=fields_factory) + cf_utils.INTERFACE_LEVEL_STANDARD_NAME], + params={ + "num_lev": icon_grid.num_levels}) fields_factory.register_provider(functional_determinant_provider) - data = fields_factory.get("functional_determinant_of_the_metrics_on_half_levels", + data = fields_factory.get("functional_determinant_of_metrics_on_interface_levels", type_=factory.RetrievalType.FIELD) ref = metrics_savepoint.ddqz_z_half().ndarray assert helpers.dallclose(data.ndarray, ref) From 21c744bb1925e9822379f5d047098f0990c8d652 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 15 Aug 2024 14:30:20 +0200 Subject: [PATCH 07/37] move factory.py to states package allow factory to be instantiated without backend, and grid --- .../src/icon4py/model/common/exceptions.py | 5 +- .../common/{metrics => states}/factory.py | 80 +++++++++---------- .../icon4py/model/common/states/metadata.py | 38 +++++++++ model/common/tests/states_test/conftest.py | 22 +++++ .../test_factory.py | 32 +++++++- 5 files changed, 133 insertions(+), 44 deletions(-) rename model/common/src/icon4py/model/common/{metrics => states}/factory.py (72%) create mode 100644 model/common/src/icon4py/model/common/states/metadata.py create mode 100644 model/common/tests/states_test/conftest.py rename model/common/tests/{metric_tests => states_test}/test_factory.py (69%) diff --git a/model/common/src/icon4py/model/common/exceptions.py b/model/common/src/icon4py/model/common/exceptions.py index 901617e57..c55f668e4 100644 --- a/model/common/src/icon4py/model/common/exceptions.py +++ b/model/common/src/icon4py/model/common/exceptions.py @@ -10,7 +10,10 @@ class InvalidConfigError(Exception): pass +class IncompleteSetupError(Exception): + def __init__(self, msg): + super().__init__(f"{msg}" ) class IncompleteStateError(Exception): def __init__(self, field_name): - super().__init__(f"Field '{field_name}' is missing in state.") + super().__init__(f"Field '{field_name}' is missing.") diff --git a/model/common/src/icon4py/model/common/metrics/factory.py b/model/common/src/icon4py/model/common/states/factory.py similarity index 72% rename from model/common/src/icon4py/model/common/metrics/factory.py rename to model/common/src/icon4py/model/common/states/factory.py index 3cc7cfd9a..0470f5496 100644 --- a/model/common/src/icon4py/model/common/metrics/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -8,10 +8,10 @@ import gt4py.next.ffront.decorator as gtx_decorator import xarray as xa -import icon4py.model.common.type_alias as ta -from icon4py.model.common import dimension as dims, settings +import icon4py.model.common.states.metadata as metadata +from icon4py.model.common import dimension as dims, exceptions, settings, type_alias as ta from icon4py.model.common.grid import base as base_grid -from icon4py.model.common.io import cf_utils +from icon4py.model.common.utils import builder T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) @@ -24,39 +24,16 @@ class RetrievalType(IntEnum): DATA_ARRAY = 1, METADATA = 2, -_attrs = {"functional_determinant_of_metrics_on_interface_levels":dict( - standard_name="functional_determinant_of_metrics_on_interface_levels", - long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", - units="", - dims=(dims.CellDim, dims.KHalfDim), - dtype=ta.wpfloat, - icon_var_name="ddqz_z_half", - ), - "height": dict(standard_name="height", - long_name="height", - units="m", - dims=(dims.CellDim, dims.KDim), - icon_var_name="z_mc", dtype = ta.wpfloat) , - "height_on_interface_levels": dict(standard_name="height_on_interface_levels", - long_name="height_on_interface_levels", - units="m", - dims=(dims.CellDim, dims.KHalfDim), - icon_var_name="z_ifc", - dtype = ta.wpfloat), - "model_level_number": dict(standard_name="model_level_number", - long_name="model level number", - units="", dims=(dims.KDim,), - icon_var_name="k_index", - dtype = gtx.int32), - cf_utils.INTERFACE_LEVEL_STANDARD_NAME: dict(standard_name=cf_utils.INTERFACE_LEVEL_STANDARD_NAME, - long_name="model interface level number", - units="", dims=(dims.KHalfDim,), - icon_var_name="k_index", - dtype=gtx.int32), - } +def valid(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if not self.validate(): + raise exceptions.IncompleteSetupError("Factory not fully instantiated, missing grid or allocator") + return func(self, *args, **kwargs) + return wrapper class FieldProvider(Protocol): @@ -106,7 +83,8 @@ def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: def fields(self) -> Iterable[str]: return self._fields.keys() - + + class ProgramFieldProvider: """ @@ -143,14 +121,14 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: field_domain = {_map_dim(dim): (0, _map_size(dim, grid)) for dim in self._compute_domain.keys()} - return {k: allocator(field_domain, dtype=_attrs[k]["dtype"]) for k in + return {k: allocator(field_domain, dtype=metadata.attrs[k]["dtype"]) for k in self._fields.keys()} def _unallocated(self) -> bool: return not all(self._fields.values()) def _evaluate(self, factory: 'FieldsFactory'): - self._fields = self._allocate(factory._allocator, factory.grid) + self._fields = self._allocate(factory.allocator, factory.grid) domain = functools.reduce(operator.add, self._compute_domain.values()) args = [factory.get(k) for k in self.dependencies()] params = [p for p in self._params.values()] @@ -179,15 +157,32 @@ class FieldsFactory: """ - def __init__(self, grid:base_grid.BaseGrid, backend=settings.backend): + def __init__(self, grid:base_grid.BaseGrid = None, backend=settings.backend): self._grid = grid self._providers: dict[str, 'FieldProvider'] = {} self._allocator = gtx.constructors.zeros.partial(allocator=backend) + def validate(self): + return self._grid is not None and self._allocator is not None + + @builder.builder + def with_grid(self, grid:base_grid.BaseGrid): + self._grid = grid + + @builder.builder + def with_allocator(self, backend = settings.backend): + self._allocator = backend + + + @property def grid(self): return self._grid + + @property + def allocator(self): + return self._allocator def register_provider(self, provider:FieldProvider): @@ -199,19 +194,22 @@ def register_provider(self, provider:FieldProvider): for field in provider.fields(): self._providers[field] = provider - + @valid def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD) -> Union[FieldType, xa.DataArray, dict]: - if field_name not in _attrs: + if field_name not in metadata.attrs: raise ValueError(f"Field {field_name} not found in metric fields") if type_ == RetrievalType.METADATA: - return _attrs[field_name] + return metadata.attrs[field_name] if type_ == RetrievalType.FIELD: return self._providers[field_name](field_name, self) if type_ == RetrievalType.DATA_ARRAY: - return to_data_array(self._providers[field_name](field_name), _attrs[field_name]) + return to_data_array(self._providers[field_name](field_name), metadata.attrs[field_name]) raise ValueError(f"Invalid retrieval type {type_}") + + + def to_data_array(field, attrs): return xa.DataArray(field, attrs=attrs) diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py new file mode 100644 index 000000000..67134322f --- /dev/null +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -0,0 +1,38 @@ + + +import gt4py.next as gtx + +import icon4py.model.common.io.cf_utils as cf_utils +from icon4py.model.common import dimension as dims, type_alias as ta + + +attrs = {"functional_determinant_of_metrics_on_interface_levels":dict( + standard_name="functional_determinant_of_metrics_on_interface_levels", + long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", + units="", + dims=(dims.CellDim, dims.KHalfDim), + dtype=ta.wpfloat, + icon_var_name="ddqz_z_half", + ), + "height": dict(standard_name="height", + long_name="height", + units="m", + dims=(dims.CellDim, dims.KDim), + icon_var_name="z_mc", dtype = ta.wpfloat) , + "height_on_interface_levels": dict(standard_name="height_on_interface_levels", + long_name="height_on_interface_levels", + units="m", + dims=(dims.CellDim, dims.KHalfDim), + icon_var_name="z_ifc", + dtype = ta.wpfloat), + "model_level_number": dict(standard_name="model_level_number", + long_name="model level number", + units="", dims=(dims.KDim,), + icon_var_name="k_index", + dtype = gtx.int32), + cf_utils.INTERFACE_LEVEL_STANDARD_NAME: dict(standard_name=cf_utils.INTERFACE_LEVEL_STANDARD_NAME, + long_name="model interface level number", + units="", dims=(dims.KHalfDim,), + icon_var_name="k_index", + dtype=gtx.int32), + } \ No newline at end of file diff --git a/model/common/tests/states_test/conftest.py b/model/common/tests/states_test/conftest.py new file mode 100644 index 000000000..cb7be87d5 --- /dev/null +++ b/model/common/tests/states_test/conftest.py @@ -0,0 +1,22 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from icon4py.model.common.test_utils.datatest_fixtures import ( # noqa: F401 # import fixtures from test_utils package + data_provider, + download_ser_data, + experiment, + grid_savepoint, + icon_grid, + interpolation_savepoint, + metrics_savepoint, + processor_props, + ranked_data_path, +) +from icon4py.model.common.test_utils.helpers import ( # noqa : F401 # fixtures from test_utils + backend, +) diff --git a/model/common/tests/metric_tests/test_factory.py b/model/common/tests/states_test/test_factory.py similarity index 69% rename from model/common/tests/metric_tests/test_factory.py rename to model/common/tests/states_test/test_factory.py index 29d225827..1d433d126 100644 --- a/model/common/tests/metric_tests/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -2,12 +2,14 @@ import pytest import icon4py.model.common.test_utils.helpers as helpers -from icon4py.model.common import dimension as dims +from icon4py.model.common import dimension as dims, exceptions from icon4py.model.common.io import cf_utils -from icon4py.model.common.metrics import factory, metric_fields as mf +from icon4py.model.common.metrics import metric_fields as mf from icon4py.model.common.settings import xp +from icon4py.model.common.states import factory +@pytest.mark.datatest def test_check_dependencies_on_register(icon_grid, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) provider = factory.ProgramFieldProvider(func=mf.compute_z_mc, @@ -21,6 +23,32 @@ def test_check_dependencies_on_register(icon_grid, backend): assert e.value.match("'height_on_interface_levels' not found") +@pytest.mark.datatest +def test_factory_raise_error_if_no_grid_or_backend_set(metrics_savepoint): + z_ifc = metrics_savepoint.z_ifc() + k_index = gtx.as_field((dims.KDim,), xp.arange( 1, dtype=gtx.int32)) + pre_computed_fields = factory.PrecomputedFieldsProvider( + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + fields_factory = factory.FieldsFactory(None, None) + fields_factory.register_provider(pre_computed_fields) + with pytest.raises(exceptions.IncompleteSetupError) as e: + fields_factory.get("height_on_interface_levels") + assert e.value.match("not fully instantiated") + + +@pytest.mark.datatest +def test_factory_returns_field(metrics_savepoint, icon_grid, backend): + z_ifc = metrics_savepoint.z_ifc() + k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels +1, dtype=gtx.int32)) + pre_computed_fields = factory.PrecomputedFieldsProvider( + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + fields_factory = factory.FieldsFactory(None, None) + fields_factory.register_provider(pre_computed_fields) + fields_factory.with_grid(icon_grid).with_allocator(backend) + field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) + assert field.ndarray.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) + + @pytest.mark.datatest def test_field_provider(icon_grid, metrics_savepoint, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) From bf7dc7e7f47abcc9c889511aa9df45c1707adfc0 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 15 Aug 2024 16:48:26 +0200 Subject: [PATCH 08/37] remove duplicated computation of wgtfacq_c_dsl --- .../model/common/metrics/compute_wgtfacq.py | 19 +++++++++---------- .../metric_tests/test_compute_wgtfacq.py | 3 ++- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py index 2a7b92a8b..1bf535bbd 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py +++ b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py @@ -9,7 +9,7 @@ import numpy as np -def compute_z1_z2_z3(z_ifc, i1, i2, i3, i4): +def _compute_z1_z2_z3(z_ifc, i1, i2, i3, i4): z1 = 0.5 * (z_ifc[:, i2] - z_ifc[:, i1]) z2 = 0.5 * (z_ifc[:, i2] + z_ifc[:, i3]) - z_ifc[:, i1] z3 = 0.5 * (z_ifc[:, i3] + z_ifc[:, i4]) - z_ifc[:, i1] @@ -31,7 +31,7 @@ def compute_wgtfacq_c_dsl( """ wgtfacq_c = np.zeros((z_ifc.shape[0], nlev + 1)) wgtfacq_c_dsl = np.zeros((z_ifc.shape[0], nlev)) - z1, z2, z3 = compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3) + z1, z2, z3 = _compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3) wgtfacq_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3) wgtfacq_c[:, 1] = (z1 - wgtfacq_c[:, 2] * (z1 - z3)) / (z1 - z2) @@ -43,12 +43,11 @@ def compute_wgtfacq_c_dsl( return wgtfacq_c_dsl - def compute_wgtfacq_e_dsl( e2c, - z_ifc: np.array, - z_aux_c: np.array, - c_lin_e: np.array, + z_ifc: np.ndarray, + c_lin_e: np.ndarray, + wgtfacq_c_dsl: np.ndarray, n_edges: int, nlev: int, ): @@ -58,7 +57,7 @@ def compute_wgtfacq_e_dsl( Args: e2c: Edge to Cell offset z_ifc: geometric height at the vertical interface of cells. - z_aux_c: interpolation of weighting coefficients to edges + wgtfacq_c_dsl: weighting factor for quadratic interpolation to surface c_lin_e: interpolation field n_edges: number of edges nlev: int, last k level @@ -66,13 +65,13 @@ def compute_wgtfacq_e_dsl( Field[EdgeDim, KDim] (full levels) """ wgtfacq_e_dsl = np.zeros(shape=(n_edges, nlev + 1)) - z1, z2, z3 = compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3) - wgtfacq_c_dsl = compute_wgtfacq_c_dsl(z_ifc, nlev) + z_aux_c = np.zeros((z_ifc.shape[0], 6)) + z1, z2, z3 = _compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3) z_aux_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3) z_aux_c[:, 1] = (z1 - wgtfacq_c_dsl[:, nlev - 3] * (z1 - z3)) / (z1 - z2) z_aux_c[:, 0] = 1.0 - (wgtfacq_c_dsl[:, nlev - 2] + wgtfacq_c_dsl[:, nlev - 3]) - z1, z2, z3 = compute_z1_z2_z3(z_ifc, 0, 1, 2, 3) + z1, z2, z3 = _compute_z1_z2_z3(z_ifc, 0, 1, 2, 3) z_aux_c[:, 5] = z1 * z2 / (z2 - z3) / (z1 - z3) z_aux_c[:, 4] = (z1 - z_aux_c[:, 5] * (z1 - z3)) / (z1 - z2) z_aux_c[:, 3] = 1.0 - (z_aux_c[:, 4] + z_aux_c[:, 5]) diff --git a/model/common/tests/metric_tests/test_compute_wgtfacq.py b/model/common/tests/metric_tests/test_compute_wgtfacq.py index dda14b19e..9da5ccb32 100644 --- a/model/common/tests/metric_tests/test_compute_wgtfacq.py +++ b/model/common/tests/metric_tests/test_compute_wgtfacq.py @@ -32,11 +32,12 @@ def test_compute_wgtfacq_c_dsl(icon_grid, metrics_savepoint): @pytest.mark.datatest def test_compute_wgtfacq_e_dsl(metrics_savepoint, interpolation_savepoint, icon_grid): wgtfacq_e_dsl_ref = metrics_savepoint.wgtfacq_e_dsl(icon_grid.num_levels + 1) + wgtfacq_c_dsl = metrics_savepoint.wgtfacq_c_dsl() wgtfacq_e_dsl_full = compute_wgtfacq_e_dsl( e2c=icon_grid.connectivities[E2CDim], z_ifc=metrics_savepoint.z_ifc().asnumpy(), - z_aux_c=metrics_savepoint.wgtfac_c().asnumpy(), + wgtfacq_c_dsl=wgtfacq_c_dsl.asnumpy(), c_lin_e=interpolation_savepoint.c_lin_e().asnumpy(), n_edges=icon_grid.num_edges, nlev=icon_grid.num_levels, From d07fef2367dd3eaab7378535f822a41946718bb3 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 15 Aug 2024 17:35:59 +0200 Subject: [PATCH 09/37] fix type annotations for arrays --- .../src/icon4py/model/common/metrics/compute_wgtfacq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py index 1bf535bbd..b87af31b4 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py +++ b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py @@ -17,9 +17,9 @@ def _compute_z1_z2_z3(z_ifc, i1, i2, i3, i4): def compute_wgtfacq_c_dsl( - z_ifc: np.array, + z_ifc: np.ndarray, nlev: int, -) -> np.array: +) -> np.ndarray: """ Compute weighting factor for quadratic interpolation to surface. From 8bb63f6d376ce2499268346f383ab606a65e737a Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 15 Aug 2024 17:51:07 +0200 Subject: [PATCH 10/37] add type annotations to compute_vwind_impl_wgt.py fix type annotations for np.ndarray in compute_zdiff_gradp_dsl.py and compute_diffusion_metrics.py --- .../metrics/compute_diffusion_metrics.py | 58 +++++++++---------- .../common/metrics/compute_vwind_impl_wgt.py | 25 ++++---- .../common/metrics/compute_zdiff_gradp_dsl.py | 12 ++-- 3 files changed, 49 insertions(+), 46 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/compute_diffusion_metrics.py b/model/common/src/icon4py/model/common/metrics/compute_diffusion_metrics.py index 494518274..6f289626f 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_diffusion_metrics.py +++ b/model/common/src/icon4py/model/common/metrics/compute_diffusion_metrics.py @@ -11,12 +11,12 @@ def _compute_nbidx( k_range: range, - z_mc: np.array, - z_mc_off: np.array, - nbidx: np.array, + z_mc: np.ndarray, + z_mc_off: np.ndarray, + nbidx: np.ndarray, jc: int, nlev: int, -) -> np.array: +) -> np.ndarray: for ind in range(3): jk_start = nlev - 1 for jk in reversed(k_range): @@ -34,12 +34,12 @@ def _compute_nbidx( def _compute_z_vintcoeff( k_range: range, - z_mc: np.array, - z_mc_off: np.array, - z_vintcoeff: np.array, + z_mc: np.ndarray, + z_mc_off: np.ndarray, + z_vintcoeff: np.ndarray, jc: int, nlev: int, -) -> np.array: +) -> np.ndarray: for ind in range(3): jk_start = nlev - 1 for jk in reversed(k_range): @@ -60,9 +60,9 @@ def _compute_z_vintcoeff( def _compute_ls_params( k_start: list, k_end: list, - z_maxslp_avg: np.array, - z_maxhgtd_avg: np.array, - c_owner_mask: np.array, + z_maxslp_avg: np.ndarray, + z_maxhgtd_avg: np.ndarray, + c_owner_mask: np.ndarray, thslp_zdiffu: float, thhgtd_zdiffu: float, cell_nudging: int, @@ -92,11 +92,11 @@ def _compute_ls_params( def _compute_k_start_end( - z_mc: np.array, - max_nbhgt: np.array, - z_maxslp_avg: np.array, - z_maxhgtd_avg: np.array, - c_owner_mask: np.array, + z_mc: np.ndarray, + max_nbhgt: np.ndarray, + z_maxslp_avg: np.ndarray, + z_maxhgtd_avg: np.ndarray, + c_owner_mask: np.ndarray, thslp_zdiffu: float, thhgtd_zdiffu: float, cell_nudging: int, @@ -127,24 +127,24 @@ def _compute_k_start_end( def compute_diffusion_metrics( - z_mc: np.array, - z_mc_off: np.array, - max_nbhgt: np.array, - c_owner_mask: np.array, - nbidx: np.array, - z_vintcoeff: np.array, - z_maxslp_avg: np.array, - z_maxhgtd_avg: np.array, - mask_hdiff: np.array, - zd_diffcoef_dsl: np.array, - zd_intcoef_dsl: np.array, - zd_vertoffset_dsl: np.array, + z_mc: np.ndarray, + z_mc_off: np.ndarray, + max_nbhgt: np.ndarray, + c_owner_mask: np.ndarray, + nbidx: np.ndarray, + z_vintcoeff: np.ndarray, + z_maxslp_avg: np.ndarray, + z_maxhgtd_avg: np.ndarray, + mask_hdiff: np.ndarray, + zd_diffcoef_dsl: np.ndarray, + zd_intcoef_dsl: np.ndarray, + zd_vertoffset_dsl: np.ndarray, thslp_zdiffu: float, thhgtd_zdiffu: float, cell_nudging: int, n_cells: int, nlev: int, -) -> tuple[np.array, np.array, np.array, np.array]: +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: k_start, k_end = _compute_k_start_end( z_mc=z_mc, max_nbhgt=max_nbhgt, diff --git a/model/common/src/icon4py/model/common/metrics/compute_vwind_impl_wgt.py b/model/common/src/icon4py/model/common/metrics/compute_vwind_impl_wgt.py index 1b87efeb4..d3a7a96e9 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_vwind_impl_wgt.py +++ b/model/common/src/icon4py/model/common/metrics/compute_vwind_impl_wgt.py @@ -5,27 +5,30 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - import numpy as np +import icon4py.model.common.field_type_aliases as fa +from icon4py.model.common.grid import base as grid from icon4py.model.common.metrics.metric_fields import compute_vwind_impl_wgt_partial +from icon4py.model.common.type_alias import wpfloat def compute_vwind_impl_wgt( backend, - icon_grid, - vct_a, - z_ifc, - z_ddxn_z_half_e, - z_ddxt_z_half_e, - dual_edge_length, - vwind_impl_wgt_full, - vwind_impl_wgt_k, + icon_grid: grid.BaseGrid, + vct_a:fa.KField[wpfloat], + z_ifc:fa.CellKField[wpfloat], + z_ddxn_z_half_e:fa.EdgeField[wpfloat], + z_ddxt_z_half_e:fa.EdgeField[wpfloat], + dual_edge_length:fa.EdgeField[wpfloat], + vwind_impl_wgt_full:fa.CellField[wpfloat], + vwind_impl_wgt_k:fa.CellField[wpfloat], global_exp: str, experiment: str, vwind_offctr: float, horizontal_start_cell: int, -): +)-> np.ndarray: + compute_vwind_impl_wgt_partial.with_backend(backend)( z_ddxn_z_half_e=z_ddxn_z_half_e, z_ddxt_z_half_e=z_ddxt_z_half_e, @@ -37,7 +40,7 @@ def compute_vwind_impl_wgt( vwind_offctr=vwind_offctr, horizontal_start=horizontal_start_cell, horizontal_end=icon_grid.num_cells, - vertical_start=max(10, icon_grid.num_levels - 8), + vertical_start=max(10, icon_grid.num_levels - 8),# TODO check this what are these constants? vertical_end=icon_grid.num_levels, offset_provider={ "C2E": icon_grid.get_offset_provider("C2E"), diff --git a/model/common/src/icon4py/model/common/metrics/compute_zdiff_gradp_dsl.py b/model/common/src/icon4py/model/common/metrics/compute_zdiff_gradp_dsl.py index 85e5d9cc1..4156f8191 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_zdiff_gradp_dsl.py +++ b/model/common/src/icon4py/model/common/metrics/compute_zdiff_gradp_dsl.py @@ -11,16 +11,16 @@ def compute_zdiff_gradp_dsl( e2c, - z_me: np.array, - z_mc: np.array, - z_ifc: np.array, - flat_idx: np.array, - z_aux2: np.array, + z_me: np.ndarray, + z_mc: np.ndarray, + z_ifc: np.ndarray, + flat_idx: np.ndarray, + z_aux2: np.ndarray, nlev: int, horizontal_start: int, horizontal_start_1: int, nedges: int, -) -> np.array: +) -> np.ndarray: zdiff_gradp = np.zeros_like(z_mc[e2c]) zdiff_gradp[horizontal_start:, :, :] = ( np.expand_dims(z_me, axis=1)[horizontal_start:, :, :] - z_mc[e2c][horizontal_start:, :, :] From a9b0b542a675234e52779340e8249e249a8c684b Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 16 Aug 2024 09:49:22 +0200 Subject: [PATCH 11/37] FieldProvider for numpy functions (WIP I) --- .../icon4py/model/common/states/factory.py | 41 +++++++++++++++++-- .../icon4py/model/common/states/metadata.py | 6 +++ .../common/tests/states_test/test_factory.py | 27 ++++++++++++ 3 files changed, 70 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 0470f5496..506d548ab 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -1,8 +1,9 @@ import abc import functools +import inspect import operator from enum import IntEnum -from typing import Iterable, Optional, Protocol, Sequence, TypeAlias, TypeVar, Union +from typing import Callable, Iterable, Optional, Protocol, Sequence, TypeAlias, TypeVar, Union import gt4py.next as gtx import gt4py.next.ffront.decorator as gtx_decorator @@ -93,7 +94,7 @@ class ProgramFieldProvider: """ def __init__(self, - func: gtx_decorator.Program, + func: Union[gtx_decorator.Program, Callable], domain: dict[gtx.Dimension:tuple[gtx.int32, gtx.int32]], # the compute domain fields: Sequence[str], deps: Sequence[str] = [], # the dependencies of func @@ -130,25 +131,57 @@ def _unallocated(self) -> bool: def _evaluate(self, factory: 'FieldsFactory'): self._fields = self._allocate(factory.allocator, factory.grid) domain = functools.reduce(operator.add, self._compute_domain.values()) - args = [factory.get(k) for k in self.dependencies()] + deps = [factory.get(k) for k in self.dependencies()] params = [p for p in self._params.values()] output = [f for f in self._fields.values()] - self._func(*args, *output, *params, *domain, + # it might be safer to call the field_operator here? then we can use the keyword only args for out= and domain= + self._func(*deps, *output, *params, *domain, offset_provider=factory.grid.offset_providers) + def fields(self)->Iterable[str]: return self._fields.keys() def dependencies(self)->Iterable[str]: return self._dependencies + def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: if field_name not in self._fields.keys(): raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") if self._unallocated(): + self._evaluate(factory) return self._fields[field_name] +class NumpyFieldsProvider(ProgramFieldProvider): + def __init__(self, func:Callable, + domain:dict[gtx.Dimension:tuple[gtx.int32, gtx.int32]], + fields:Sequence[str], + deps:Sequence[str] = [], + params:dict[str, Scalar] = {}): + super().__init__(func, domain, fields, deps, params) + def _evaluate(self, factory: 'FieldsFactory') -> None: + domain = {dim: range(*self._compute_domain[dim]) for dim in self._compute_domain.keys()} + deps = [factory.get(k).ndarray for k in self.dependencies()] + params = [p for p in self._params.values()] + + results = self._func(*deps, *params) + self._fields = {k: results[i] for i, k in enumerate(self._fields.keys())} + + +def inspect_func(func:Callable): + signa = inspect.signature(func) + print(f"signature: {signa}") + print(f"parameters: {signa.parameters}") + + print(f"return : {signa.return_annotation}") + return signa + + + + + class FieldsFactory: """ Factory for fields. diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 67134322f..7454b5cc3 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -35,4 +35,10 @@ units="", dims=(dims.KHalfDim,), icon_var_name="k_index", dtype=gtx.int32), + "weight_factor_for_quadratic_interpolation_to_cell_surface": dict(standard_name="weight_factor_for_quadratic_interpolation_to_cell_surface", + units="", + dims=(dims.CellDim, dims.KDim), + dtype=ta.wpfloat, + icon_var_name="wgtfacq_c_dsl", + long_name="weighting factor for quadratic interpolation to cell surface"), } \ No newline at end of file diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 1d433d126..7feefcc41 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -5,6 +5,7 @@ from icon4py.model.common import dimension as dims, exceptions from icon4py.model.common.io import cf_utils from icon4py.model.common.metrics import metric_fields as mf +from icon4py.model.common.metrics.compute_wgtfacq import compute_wgtfacq_c_dsl from icon4py.model.common.settings import xp from icon4py.model.common.states import factory @@ -86,3 +87,29 @@ def test_field_provider(icon_grid, metrics_savepoint, backend): type_=factory.RetrievalType.FIELD) ref = metrics_savepoint.ddqz_z_half().ndarray assert helpers.dallclose(data.ndarray, ref) + + +def test_numpy_func(icon_grid, metrics_savepoint, backend): + fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) + k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) + z_ifc = metrics_savepoint.z_ifc() + + pre_computed_fields = factory.PrecomputedFieldsProvider( + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + fields_factory.register_provider(pre_computed_fields) + func = compute_wgtfacq_c_dsl + signature = factory.inspect_func(compute_wgtfacq_c_dsl) + compute_wgtfacq_c_provider = factory.NumpyFieldsProvider(func=func, + domain={dims.CellDim: (0, icon_grid.num_cells), + dims.KDim: (0, icon_grid.num_levels)}, + fields=[ + "weighting_factor_for_quadratic_interpolation_to_cell_surface"], + deps=[ + "height_on_interface_levels"], + params={ + "num_lev": icon_grid.num_levels}) + fields_factory.register_provider(compute_wgtfacq_c_provider) + + + fields_factory.get("weighting_factor_for_quadratic_interpolation_to_cell_surface", factory.RetrievalType.FIELD) + \ No newline at end of file From ffb46614063039db24f30dd49f9601641b293a70 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 16 Aug 2024 15:58:59 +0200 Subject: [PATCH 12/37] first version for numpy functions --- .../icon4py/model/common/states/factory.py | 64 ++++++++++++++++--- .../icon4py/model/common/states/metadata.py | 2 +- .../common/tests/states_test/test_factory.py | 14 ++-- 3 files changed, 62 insertions(+), 18 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 506d548ab..454cd1a93 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -12,6 +12,7 @@ import icon4py.model.common.states.metadata as metadata from icon4py.model.common import dimension as dims, exceptions, settings, type_alias as ta from icon4py.model.common.grid import base as base_grid +from icon4py.model.common.settings import xp from icon4py.model.common.utils import builder @@ -65,7 +66,9 @@ def dependencies(self) -> Iterable[str]: @abc.abstractmethod def fields(self) -> Iterable[str]: pass - + + def _unallocated(self) -> bool: + return not all(self._fields.values()) class PrecomputedFieldsProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" @@ -125,8 +128,7 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: return {k: allocator(field_domain, dtype=metadata.attrs[k]["dtype"]) for k in self._fields.keys()} - def _unallocated(self) -> bool: - return not all(self._fields.values()) + def _evaluate(self, factory: 'FieldsFactory'): self._fields = self._allocate(factory.allocator, factory.grid) @@ -154,21 +156,63 @@ def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: return self._fields[field_name] -class NumpyFieldsProvider(ProgramFieldProvider): +class NumpyFieldsProvider(FieldProvider): def __init__(self, func:Callable, domain:dict[gtx.Dimension:tuple[gtx.int32, gtx.int32]], fields:Sequence[str], - deps:Sequence[str] = [], + deps:dict[str, str], params:dict[str, Scalar] = {}): - super().__init__(func, domain, fields, deps, params) + self._compute_domain = domain + self._func = func + self._fields:dict[str, Optional[FieldType]] = {name: None for name in fields} + self._dependencies = deps + self._params = params + def _evaluate(self, factory: 'FieldsFactory') -> None: domain = {dim: range(*self._compute_domain[dim]) for dim in self._compute_domain.keys()} - deps = [factory.get(k).ndarray for k in self.dependencies()] - params = [p for p in self._params.values()] + + # validate deps: + self._validate_dependencies(factory) + args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()} + args.update(self._params) + results = self._func(**args) + ## TODO: check order of return values + results = (results,) if isinstance(results, xp.ndarray) else results + + self._fields = {k: gtx.as_field(tuple(self._compute_domain.keys()), results[i]) for i, k in enumerate(self._fields.keys())} + + def _validate_dependencies(self, factory): + func_signature = inspect.signature(self._func) + parameters = func_signature.parameters + for dep_key in self._dependencies.keys(): + try: + parameter_definition = parameters[dep_key] + if parameter_definition.annotation != xp.ndarray: # also allow for gtx.Field ??? + raise ValueError(f"Dependency {dep_key} in function {self._func.__name__} : {func_signature} is not of type xp.ndarray") + except KeyError: + raise ValueError(f"Argument {dep_key} does not exist in {self._func.__name__} : {func_signature}.") - results = self._func(*deps, *params) - self._fields = {k: results[i] for i, k in enumerate(self._fields.keys())} + for param_key, param_value in self._params.items(): + try: + parameter_definition = parameters[param_key] + if parameter_definition.annotation != type(param_value): + raise ValueError(f"parameter {parameter_definition} to function {self._func.__name__} has the wrong type") + except KeyError: + raise ValueError(f"Argument {param_key} does not exist in {self._func.__name__} : {func_signature}.") + + def dependencies(self) -> Iterable[str]: + return self._dependencies.values() + + def fields(self) -> Iterable[str]: + return self._fields.keys() + + def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: + if field_name not in self._fields.keys(): + raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") + if any([f is None for f in self._fields.values()]): + self._evaluate(factory) + return self._fields[field_name] def inspect_func(func:Callable): signa = inspect.signature(func) diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 7454b5cc3..e6f50a088 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -35,7 +35,7 @@ units="", dims=(dims.KHalfDim,), icon_var_name="k_index", dtype=gtx.int32), - "weight_factor_for_quadratic_interpolation_to_cell_surface": dict(standard_name="weight_factor_for_quadratic_interpolation_to_cell_surface", + "weighting_factor_for_quadratic_interpolation_to_cell_surface": dict(standard_name="weighting_factor_for_quadratic_interpolation_to_cell_surface", units="", dims=(dims.CellDim, dims.KDim), dtype=ta.wpfloat, diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 7feefcc41..6d7ce0987 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -93,23 +93,23 @@ def test_numpy_func(icon_grid, metrics_savepoint, backend): fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() + wgtfacq_c_ref = metrics_savepoint.wgtfacq_c_dsl() pre_computed_fields = factory.PrecomputedFieldsProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) fields_factory.register_provider(pre_computed_fields) func = compute_wgtfacq_c_dsl - signature = factory.inspect_func(compute_wgtfacq_c_dsl) + deps = {"z_ifc": "height_on_interface_levels"} + params = {"nlev": icon_grid.num_levels} compute_wgtfacq_c_provider = factory.NumpyFieldsProvider(func=func, domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, fields=[ "weighting_factor_for_quadratic_interpolation_to_cell_surface"], - deps=[ - "height_on_interface_levels"], - params={ - "num_lev": icon_grid.num_levels}) + deps=deps, + params=params) fields_factory.register_provider(compute_wgtfacq_c_provider) - fields_factory.get("weighting_factor_for_quadratic_interpolation_to_cell_surface", factory.RetrievalType.FIELD) - \ No newline at end of file + wgtfacq_c = fields_factory.get("weighting_factor_for_quadratic_interpolation_to_cell_surface", factory.RetrievalType.FIELD) + assert helpers.dallclose(wgtfacq_c.asnumpy(), wgtfacq_c_ref.asnumpy()) \ No newline at end of file From 9f042b11f86d2d4326650c9ae59cdb4fe0d356fe Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Tue, 20 Aug 2024 10:39:31 +0200 Subject: [PATCH 13/37] fix: move _unallocated to ProgramFieldProvider --- model/common/src/icon4py/model/common/states/factory.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 454cd1a93..850a5fa96 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -67,8 +67,7 @@ def dependencies(self) -> Iterable[str]: def fields(self) -> Iterable[str]: pass - def _unallocated(self) -> bool: - return not all(self._fields.values()) + class PrecomputedFieldsProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" @@ -110,7 +109,8 @@ def __init__(self, self._params = params self._fields: dict[str, Optional[gtx.Field | Scalar]] = {name: None for name in fields} - + def _unallocated(self) -> bool: + return not all(self._fields.values()) def _allocate(self, allocator, grid:base_grid.BaseGrid) -> dict[str, FieldType]: def _map_size(dim:gtx.Dimension, grid:base_grid.BaseGrid) -> int: @@ -151,7 +151,6 @@ def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: if field_name not in self._fields.keys(): raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") if self._unallocated(): - self._evaluate(factory) return self._fields[field_name] From 809f06094fb92ec8ac7a510c4e331c42532d93c4 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Tue, 20 Aug 2024 14:05:11 +0200 Subject: [PATCH 14/37] move joint functionality into FieldProvider --- .../src/icon4py/model/common/exceptions.py | 4 +- .../icon4py/model/common/states/factory.py | 322 +++++++++--------- .../icon4py/model/common/states/metadata.py | 94 +++-- .../common/tests/states_test/test_factory.py | 119 ++++--- 4 files changed, 281 insertions(+), 258 deletions(-) diff --git a/model/common/src/icon4py/model/common/exceptions.py b/model/common/src/icon4py/model/common/exceptions.py index c55f668e4..418c1bd9b 100644 --- a/model/common/src/icon4py/model/common/exceptions.py +++ b/model/common/src/icon4py/model/common/exceptions.py @@ -10,9 +10,11 @@ class InvalidConfigError(Exception): pass + class IncompleteSetupError(Exception): def __init__(self, msg): - super().__init__(f"{msg}" ) + super().__init__(f"{msg}") + class IncompleteStateError(Exception): def __init__(self, field_name): diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 850a5fa96..67ec0a348 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -1,7 +1,14 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import abc import functools import inspect -import operator from enum import IntEnum from typing import Callable, Iterable, Optional, Protocol, Sequence, TypeAlias, TypeVar, Union @@ -20,100 +27,110 @@ DimT = TypeVar("DimT", dims.KDim, dims.KHalfDim, dims.CellDim, dims.EdgeDim, dims.VertexDim) Scalar: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64] -FieldType:TypeAlias = gtx.Field[Sequence[gtx.Dims[DimT]], T] -class RetrievalType(IntEnum): - FIELD = 0, - DATA_ARRAY = 1, - METADATA = 2, +FieldType: TypeAlias = gtx.Field[Sequence[gtx.Dims[DimT]], T] +class RetrievalType(IntEnum): + FIELD = (0,) + DATA_ARRAY = (1,) + METADATA = (2,) def valid(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): if not self.validate(): - raise exceptions.IncompleteSetupError("Factory not fully instantiated, missing grid or allocator") + raise exceptions.IncompleteSetupError( + "Factory not fully instantiated, missing grid or allocator" + ) return func(self, *args, **kwargs) + return wrapper class FieldProvider(Protocol): """ Protocol for field providers. - + A field provider is responsible for the computation and caching of a set of fields. The fields can be accessed by their field_name (str). - - A FieldProvider has three methods: - - evaluate: computes the fields based on the instructions of concrete implementation - - get: returns the field with the given field_name. - - fields: returns the list of field names provided by the - + + A FieldProvider is a callable that has three methods (except for __call__): + - evaluate (abstract) : computes the fields based on the instructions of the concrete implementation + - fields(): returns the list of field names provided by the provider + - dependencies(): returns a list of field_names that the fields provided by this provider depend on. + + evaluate must be implemented, for the others default implementations are provided. """ - @abc.abstractmethod - def _evaluate(self, factory:'FieldsFactory') -> None: - pass + + def __init__(self, func: Callable): + self._func = func + self._fields: dict[str, Optional[FieldType]] = {} + self._dependencies: dict[str, str] = {} @abc.abstractmethod - def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: + def evaluate(self, factory: "FieldsFactory") -> None: pass - @abc.abstractmethod + def __call__(self, field_name: str, factory: "FieldsFactory") -> FieldType: + if field_name not in self.fields(): + raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}.") + if any([f is None for f in self._fields.values()]): + self.evaluate(factory) + return self._fields[field_name] + def dependencies(self) -> Iterable[str]: - pass + return self._dependencies.values() - @abc.abstractmethod def fields(self) -> Iterable[str]: - pass - + return self._fields.keys() class PrecomputedFieldsProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" - + def __init__(self, fields: dict[str, FieldType]): self._fields = fields - - def _evaluate(self, factory: 'FieldsFactory') -> None: + + def evaluate(self, factory: "FieldsFactory") -> None: pass - + def dependencies(self) -> Sequence[str]: return [] - - def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: - return self._fields[field_name] - - def fields(self) -> Iterable[str]: - return self._fields.keys() + def __call__(self, field_name: str, factory: "FieldsFactory") -> FieldType: + return self._fields[field_name] -class ProgramFieldProvider: +class ProgramFieldProvider(FieldProvider): """ Computes a field defined by a GT4Py Program. """ - def __init__(self, - func: Union[gtx_decorator.Program, Callable], - domain: dict[gtx.Dimension:tuple[gtx.int32, gtx.int32]], # the compute domain - fields: Sequence[str], - deps: Sequence[str] = [], # the dependencies of func - params: dict[str, Scalar] = {}, # the parameters of func - ): - self._compute_domain = domain - self._dims = domain.keys() + def __init__( + self, + func: gtx_decorator.Program, + domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], + fields: dict[str:str], + deps: dict[str, str], + params: Optional[dict[str, Scalar]] = None, + ): self._func = func + self._compute_domain = domain self._dependencies = deps - self._params = params - self._fields: dict[str, Optional[gtx.Field | Scalar]] = {name: None for name in fields} + self._output = fields + self._params = params if params is not None else {} + self._dims = self._domain_args() + self._fields: dict[str, Optional[gtx.Field | Scalar]] = { + name: None for name in fields.values() + } def _unallocated(self) -> bool: return not all(self._fields.values()) - def _allocate(self, allocator, grid:base_grid.BaseGrid) -> dict[str, FieldType]: - def _map_size(dim:gtx.Dimension, grid:base_grid.BaseGrid) -> int: + def _allocate(self, allocator, grid: base_grid.BaseGrid) -> dict[str, FieldType]: + def _map_size(dim: gtx.Dimension, grid: base_grid.BaseGrid) -> int: if dim == dims.KHalfDim: return grid.num_levels + 1 return grid.size[dim] @@ -123,155 +140,136 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: return dims.KDim return dim - field_domain = {_map_dim(dim): (0, _map_size(dim, grid)) for dim in - self._compute_domain.keys()} - return {k: allocator(field_domain, dtype=metadata.attrs[k]["dtype"]) for k in - self._fields.keys()} - - - - def _evaluate(self, factory: 'FieldsFactory'): + field_domain = { + _map_dim(dim): (0, _map_size(dim, grid)) for dim in self._compute_domain.keys() + } + return { + k: allocator(field_domain, dtype=metadata.attrs[k]["dtype"]) + for k in self._fields.keys() + } + + def _domain_args(self) -> dict[str : gtx.int32]: + domain_args = {} + for dim in self._compute_domain: + if dim.kind == gtx.DimensionKind.HORIZONTAL: + domain_args.update( + { + "horizontal_start": self._compute_domain[dim][0], + "horizontal_end": self._compute_domain[dim][1], + } + ) + elif dim.kind == gtx.DimensionKind.VERTICAL: + domain_args.update( + { + "vertical_start": self._compute_domain[dim][0], + "vertical_end": self._compute_domain[dim][1], + } + ) + else: + raise ValueError(f"DimensionKind '{dim.kind}' not supported in Program Domain") + return domain_args + + def evaluate(self, factory: "FieldsFactory"): self._fields = self._allocate(factory.allocator, factory.grid) - domain = functools.reduce(operator.add, self._compute_domain.values()) - deps = [factory.get(k) for k in self.dependencies()] - params = [p for p in self._params.values()] - output = [f for f in self._fields.values()] - # it might be safer to call the field_operator here? then we can use the keyword only args for out= and domain= - self._func(*deps, *output, *params, *domain, - offset_provider=factory.grid.offset_providers) - - - def fields(self)->Iterable[str]: - return self._fields.keys() - - def dependencies(self)->Iterable[str]: - return self._dependencies - - def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: - if field_name not in self._fields.keys(): - raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") - if self._unallocated(): - self._evaluate(factory) - return self._fields[field_name] + deps = {k: factory.get(v) for k, v in self._dependencies.items()} + deps.update(self._params) + deps.update({k: self._fields[v] for k, v in self._output.items()}) + deps.update(self._dims) + self._func(**deps, offset_provider=factory.grid.offset_providers) + + def fields(self) -> Iterable[str]: + return self._output.values() class NumpyFieldsProvider(FieldProvider): - def __init__(self, func:Callable, - domain:dict[gtx.Dimension:tuple[gtx.int32, gtx.int32]], - fields:Sequence[str], - deps:dict[str, str], - params:dict[str, Scalar] = {}): - self._compute_domain = domain + def __init__( + self, + func: Callable, + domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], + fields: Sequence[str], + deps: dict[str, str], + params: Optional[dict[str, Scalar]] = None, + ): self._func = func - self._fields:dict[str, Optional[FieldType]] = {name: None for name in fields} + self._compute_domain = domain + self._dims = domain.keys() + self._fields: dict[str, Optional[FieldType]] = {name: None for name in fields} self._dependencies = deps - self._params = params - - def _evaluate(self, factory: 'FieldsFactory') -> None: - domain = {dim: range(*self._compute_domain[dim]) for dim in self._compute_domain.keys()} - - # validate deps: - self._validate_dependencies(factory) + self._params = params if params is not None else {} + + def evaluate(self, factory: "FieldsFactory") -> None: + self._validate_dependencies() args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()} args.update(self._params) results = self._func(**args) - ## TODO: check order of return values + ## TODO: can the order of return values be checked? results = (results,) if isinstance(results, xp.ndarray) else results - - self._fields = {k: gtx.as_field(tuple(self._compute_domain.keys()), results[i]) for i, k in enumerate(self._fields.keys())} - def _validate_dependencies(self, factory): + self._fields = { + k: gtx.as_field(tuple(self._dims), results[i]) for i, k in enumerate(self.fields()) + } + + def _validate_dependencies(self): func_signature = inspect.signature(self._func) parameters = func_signature.parameters for dep_key in self._dependencies.keys(): - try: - parameter_definition = parameters[dep_key] - if parameter_definition.annotation != xp.ndarray: # also allow for gtx.Field ??? - raise ValueError(f"Dependency {dep_key} in function {self._func.__name__} : {func_signature} is not of type xp.ndarray") - except KeyError: - raise ValueError(f"Argument {dep_key} does not exist in {self._func.__name__} : {func_signature}.") - - - for param_key, param_value in self._params.items(): - try: - parameter_definition = parameters[param_key] - if parameter_definition.annotation != type(param_value): - raise ValueError(f"parameter {parameter_definition} to function {self._func.__name__} has the wrong type") - except KeyError: - raise ValueError(f"Argument {param_key} does not exist in {self._func.__name__} : {func_signature}.") - - def dependencies(self) -> Iterable[str]: - return self._dependencies.values() - - def fields(self) -> Iterable[str]: - return self._fields.keys() - - def __call__(self, field_name: str, factory:'FieldsFactory') -> FieldType: - if field_name not in self._fields.keys(): - raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}") - if any([f is None for f in self._fields.values()]): - self._evaluate(factory) - return self._fields[field_name] + parameter_definition = parameters.get(dep_key) + if parameter_definition is None or parameter_definition.annotation != xp.ndarray: + raise ValueError( + f"Dependency {dep_key} in function {self._func.__name__} : does not exist in {func_signature} or has wrong type ('expected np.ndarray')" + ) -def inspect_func(func:Callable): - signa = inspect.signature(func) - print(f"signature: {signa}") - print(f"parameters: {signa.parameters}") - - print(f"return : {signa.return_annotation}") - return signa + for param_key, param_value in self._params.items(): + parameter_definition = parameters.get(param_key) + if parameter_definition is None or parameter_definition.annotation != type(param_value): + raise ValueError( + f"parameter {param_key} in function {self._func.__name__} does not exist or has the has the wrong type: {type(param_value)}" + ) - - - class FieldsFactory: """ Factory for fields. - - Lazily compute fields and cache them. + + Lazily compute fields and cache them. """ - - def __init__(self, grid:base_grid.BaseGrid = None, backend=settings.backend): + def __init__(self, grid: base_grid.BaseGrid = None, backend=settings.backend): self._grid = grid - self._providers: dict[str, 'FieldProvider'] = {} + self._providers: dict[str, "FieldProvider"] = {} self._allocator = gtx.constructors.zeros.partial(allocator=backend) - def validate(self): return self._grid is not None and self._allocator is not None - + @builder.builder - def with_grid(self, grid:base_grid.BaseGrid): + def with_grid(self, grid: base_grid.BaseGrid): self._grid = grid - + @builder.builder - def with_allocator(self, backend = settings.backend): + def with_allocator(self, backend=settings.backend): self._allocator = backend - - - + @property def grid(self): return self._grid - + @property def allocator(self): return self._allocator - - def register_provider(self, provider:FieldProvider): - + + def register_provider(self, provider: FieldProvider): for dependency in provider.dependencies(): if dependency not in self._providers.keys(): raise ValueError(f"Dependency '{dependency}' not found in registered providers") - - + for field in provider.fields(): self._providers[field] = provider - + @valid - def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD) -> Union[FieldType, xa.DataArray, dict]: + def get( + self, field_name: str, type_: RetrievalType = RetrievalType.FIELD + ) -> Union[FieldType, xa.DataArray, dict]: if field_name not in metadata.attrs: raise ValueError(f"Field {field_name} not found in metric fields") if type_ == RetrievalType.METADATA: @@ -279,23 +277,11 @@ def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD) -> Un if type_ == RetrievalType.FIELD: return self._providers[field_name](field_name, self) if type_ == RetrievalType.DATA_ARRAY: - return to_data_array(self._providers[field_name](field_name), metadata.attrs[field_name]) + return to_data_array( + self._providers[field_name](field_name), metadata.attrs[field_name] + ) raise ValueError(f"Invalid retrieval type {type_}") - - - def to_data_array(field, attrs): return xa.DataArray(field, attrs=attrs) - - - - - - - - - - - diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index e6f50a088..93462fe3b 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -1,4 +1,10 @@ - +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause import gt4py.next as gtx @@ -6,39 +12,53 @@ from icon4py.model.common import dimension as dims, type_alias as ta -attrs = {"functional_determinant_of_metrics_on_interface_levels":dict( - standard_name="functional_determinant_of_metrics_on_interface_levels", - long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", - units="", - dims=(dims.CellDim, dims.KHalfDim), - dtype=ta.wpfloat, - icon_var_name="ddqz_z_half", - ), - "height": dict(standard_name="height", - long_name="height", - units="m", - dims=(dims.CellDim, dims.KDim), - icon_var_name="z_mc", dtype = ta.wpfloat) , - "height_on_interface_levels": dict(standard_name="height_on_interface_levels", - long_name="height_on_interface_levels", - units="m", - dims=(dims.CellDim, dims.KHalfDim), - icon_var_name="z_ifc", - dtype = ta.wpfloat), - "model_level_number": dict(standard_name="model_level_number", - long_name="model level number", - units="", dims=(dims.KDim,), - icon_var_name="k_index", - dtype = gtx.int32), - cf_utils.INTERFACE_LEVEL_STANDARD_NAME: dict(standard_name=cf_utils.INTERFACE_LEVEL_STANDARD_NAME, - long_name="model interface level number", - units="", dims=(dims.KHalfDim,), - icon_var_name="k_index", - dtype=gtx.int32), - "weighting_factor_for_quadratic_interpolation_to_cell_surface": dict(standard_name="weighting_factor_for_quadratic_interpolation_to_cell_surface", - units="", - dims=(dims.CellDim, dims.KDim), - dtype=ta.wpfloat, - icon_var_name="wgtfacq_c_dsl", - long_name="weighting factor for quadratic interpolation to cell surface"), - } \ No newline at end of file +attrs = { + "functional_determinant_of_metrics_on_interface_levels": dict( + standard_name="functional_determinant_of_metrics_on_interface_levels", + long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", + units="", + dims=(dims.CellDim, dims.KHalfDim), + dtype=ta.wpfloat, + icon_var_name="ddqz_z_half", + ), + "height": dict( + standard_name="height", + long_name="height", + units="m", + dims=(dims.CellDim, dims.KDim), + icon_var_name="z_mc", + dtype=ta.wpfloat, + ), + "height_on_interface_levels": dict( + standard_name="height_on_interface_levels", + long_name="height_on_interface_levels", + units="m", + dims=(dims.CellDim, dims.KHalfDim), + icon_var_name="z_ifc", + dtype=ta.wpfloat, + ), + "model_level_number": dict( + standard_name="model_level_number", + long_name="model level number", + units="", + dims=(dims.KDim,), + icon_var_name="k_index", + dtype=gtx.int32, + ), + cf_utils.INTERFACE_LEVEL_STANDARD_NAME: dict( + standard_name=cf_utils.INTERFACE_LEVEL_STANDARD_NAME, + long_name="model interface level number", + units="", + dims=(dims.KHalfDim,), + icon_var_name="k_index", + dtype=gtx.int32, + ), + "weighting_factor_for_quadratic_interpolation_to_cell_surface": dict( + standard_name="weighting_factor_for_quadratic_interpolation_to_cell_surface", + units="", + dims=(dims.CellDim, dims.KDim), + dtype=ta.wpfloat, + icon_var_name="wgtfacq_c_dsl", + long_name="weighting factor for quadratic interpolation to cell surface", + ), +} diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 6d7ce0987..103a48c1e 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -1,3 +1,11 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import gt4py.next as gtx import pytest @@ -13,25 +21,26 @@ @pytest.mark.datatest def test_check_dependencies_on_register(icon_grid, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) - provider = factory.ProgramFieldProvider(func=mf.compute_z_mc, - domain={dims.CellDim: (0, icon_grid.num_cells), - dims.KDim: (0, icon_grid.num_levels)}, - fields=["height"], - deps=["height_on_interface_levels"], - ) + provider = factory.ProgramFieldProvider( + func=mf.compute_z_mc, + domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + fields={"z_mc": "height"}, + deps={"z_ifc": "height_on_interface_levels"}, + ) with pytest.raises(ValueError) as e: fields_factory.register_provider(provider) assert e.value.match("'height_on_interface_levels' not found") - + @pytest.mark.datatest def test_factory_raise_error_if_no_grid_or_backend_set(metrics_savepoint): z_ifc = metrics_savepoint.z_ifc() - k_index = gtx.as_field((dims.KDim,), xp.arange( 1, dtype=gtx.int32)) + k_index = gtx.as_field((dims.KDim,), xp.arange(1, dtype=gtx.int32)) pre_computed_fields = factory.PrecomputedFieldsProvider( - {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} + ) fields_factory = factory.FieldsFactory(None, None) - fields_factory.register_provider(pre_computed_fields) + fields_factory.register_provider(pre_computed_fields) with pytest.raises(exceptions.IncompleteSetupError) as e: fields_factory.get("height_on_interface_levels") assert e.value.match("not fully instantiated") @@ -40,16 +49,17 @@ def test_factory_raise_error_if_no_grid_or_backend_set(metrics_savepoint): @pytest.mark.datatest def test_factory_returns_field(metrics_savepoint, icon_grid, backend): z_ifc = metrics_savepoint.z_ifc() - k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels +1, dtype=gtx.int32)) + k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) pre_computed_fields = factory.PrecomputedFieldsProvider( - {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} + ) fields_factory = factory.FieldsFactory(None, None) fields_factory.register_provider(pre_computed_fields) fields_factory.with_grid(icon_grid).with_allocator(backend) field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) assert field.ndarray.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) - - + + @pytest.mark.datatest def test_field_provider(icon_grid, metrics_savepoint, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) @@ -57,59 +67,64 @@ def test_field_provider(icon_grid, metrics_savepoint, backend): z_ifc = metrics_savepoint.z_ifc() pre_computed_fields = factory.PrecomputedFieldsProvider( - {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) - + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} + ) + fields_factory.register_provider(pre_computed_fields) - - height_provider = factory.ProgramFieldProvider(func=mf.compute_z_mc, - domain={dims.CellDim: (0, icon_grid.num_cells), - dims.KDim: (0, icon_grid.num_levels)}, - fields=["height"], - deps=["height_on_interface_levels"], - ) + + height_provider = factory.ProgramFieldProvider( + func=mf.compute_z_mc, + domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + fields={"z_mc": "height"}, + deps={"z_ifc": "height_on_interface_levels"}, + ) fields_factory.register_provider(height_provider) - functional_determinant_provider = factory.ProgramFieldProvider(func=mf.compute_ddqz_z_half, - domain={dims.CellDim: (0,icon_grid.num_cells), - dims.KHalfDim: ( - 0, - icon_grid.num_levels + 1)}, - fields=[ - "functional_determinant_of_metrics_on_interface_levels"], - deps=[ - "height_on_interface_levels", - "height", - cf_utils.INTERFACE_LEVEL_STANDARD_NAME], - params={ - "num_lev": icon_grid.num_levels}) + functional_determinant_provider = factory.ProgramFieldProvider( + func=mf.compute_ddqz_z_half, + domain={ + dims.CellDim: (0, icon_grid.num_cells), + dims.KHalfDim: (0, icon_grid.num_levels + 1), + }, + fields={"ddqz_z_half": "functional_determinant_of_metrics_on_interface_levels"}, + deps={ + "z_ifc": "height_on_interface_levels", + "z_mc": "height", + "k": cf_utils.INTERFACE_LEVEL_STANDARD_NAME, + }, + params={"nlev": icon_grid.num_levels}, + ) fields_factory.register_provider(functional_determinant_provider) - - data = fields_factory.get("functional_determinant_of_metrics_on_interface_levels", - type_=factory.RetrievalType.FIELD) + data = fields_factory.get( + "functional_determinant_of_metrics_on_interface_levels", type_=factory.RetrievalType.FIELD + ) ref = metrics_savepoint.ddqz_z_half().ndarray assert helpers.dallclose(data.ndarray, ref) -def test_numpy_func(icon_grid, metrics_savepoint, backend): +def test_numpy_function_evaluation(icon_grid, metrics_savepoint, backend): fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() wgtfacq_c_ref = metrics_savepoint.wgtfacq_c_dsl() pre_computed_fields = factory.PrecomputedFieldsProvider( - {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index}) + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} + ) fields_factory.register_provider(pre_computed_fields) func = compute_wgtfacq_c_dsl deps = {"z_ifc": "height_on_interface_levels"} params = {"nlev": icon_grid.num_levels} - compute_wgtfacq_c_provider = factory.NumpyFieldsProvider(func=func, - domain={dims.CellDim: (0, icon_grid.num_cells), - dims.KDim: (0, icon_grid.num_levels)}, - fields=[ - "weighting_factor_for_quadratic_interpolation_to_cell_surface"], - deps=deps, - params=params) + compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( + func=func, + domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + fields=["weighting_factor_for_quadratic_interpolation_to_cell_surface"], + deps=deps, + params=params, + ) fields_factory.register_provider(compute_wgtfacq_c_provider) - - - wgtfacq_c = fields_factory.get("weighting_factor_for_quadratic_interpolation_to_cell_surface", factory.RetrievalType.FIELD) - assert helpers.dallclose(wgtfacq_c.asnumpy(), wgtfacq_c_ref.asnumpy()) \ No newline at end of file + + wgtfacq_c = fields_factory.get( + "weighting_factor_for_quadratic_interpolation_to_cell_surface", factory.RetrievalType.FIELD + ) + + assert helpers.dallclose(wgtfacq_c.asnumpy(), wgtfacq_c_ref.asnumpy()) From bcd65b57426f0bf7d76498ebfc45e2f5b6252eb0 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Tue, 20 Aug 2024 15:10:48 +0200 Subject: [PATCH 15/37] - switch to device dependent import in compute_wgtfacq.py - cleanup --- .../model/common/metrics/compute_wgtfacq.py | 28 +++++------ .../icon4py/model/common/states/factory.py | 46 ++++++++----------- .../src/icon4py/model/common/states/utils.py | 18 ++++++++ .../common/tests/states_test/test_factory.py | 29 ++++++++---- 4 files changed, 72 insertions(+), 49 deletions(-) create mode 100644 model/common/src/icon4py/model/common/states/utils.py diff --git a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py index cd8874377..ad4cd0148 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py +++ b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py @@ -6,12 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np +from icon4py.model.common.settings import xp def _compute_z1_z2_z3( - z_ifc: np.ndarray, i1: int, i2: int, i3: int, i4: int -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + z_ifc: xp.ndarray, i1: int, i2: int, i3: int, i4: int +) -> tuple[xp.ndarray, xp.ndarray, xp.ndarray]: z1 = 0.5 * (z_ifc[:, i2] - z_ifc[:, i1]) z2 = 0.5 * (z_ifc[:, i2] + z_ifc[:, i3]) - z_ifc[:, i1] z3 = 0.5 * (z_ifc[:, i3] + z_ifc[:, i4]) - z_ifc[:, i1] @@ -19,9 +19,9 @@ def _compute_z1_z2_z3( def compute_wgtfacq_c_dsl( - z_ifc: np.ndarray, + z_ifc: xp.ndarray, nlev: int, -) -> np.ndarray: +) -> xp.ndarray: """ Compute weighting factor for quadratic interpolation to surface. @@ -31,8 +31,8 @@ def compute_wgtfacq_c_dsl( Returns: Field[CellDim, KDim] (full levels) """ - wgtfacq_c = np.zeros((z_ifc.shape[0], nlev + 1)) - wgtfacq_c_dsl = np.zeros((z_ifc.shape[0], nlev)) + wgtfacq_c = xp.zeros((z_ifc.shape[0], nlev + 1)) + wgtfacq_c_dsl = xp.zeros((z_ifc.shape[0], nlev)) z1, z2, z3 = _compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3) wgtfacq_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3) @@ -48,9 +48,9 @@ def compute_wgtfacq_c_dsl( def compute_wgtfacq_e_dsl( e2c, - z_ifc: np.ndarray, - c_lin_e: np.ndarray, - wgtfacq_c_dsl: np.ndarray, + z_ifc: xp.ndarray, + c_lin_e: xp.ndarray, + wgtfacq_c_dsl: xp.ndarray, n_edges: int, nlev: int, ): @@ -67,8 +67,8 @@ def compute_wgtfacq_e_dsl( Returns: Field[EdgeDim, KDim] (full levels) """ - wgtfacq_e_dsl = np.zeros(shape=(n_edges, nlev + 1)) - z_aux_c = np.zeros((z_ifc.shape[0], 6)) + wgtfacq_e_dsl = xp.zeros(shape=(n_edges, nlev + 1)) + z_aux_c = xp.zeros((z_ifc.shape[0], 6)) z1, z2, z3 = _compute_z1_z2_z3(z_ifc, nlev, nlev - 1, nlev - 2, nlev - 3) z_aux_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3) z_aux_c[:, 1] = (z1 - wgtfacq_c_dsl[:, nlev - 3] * (z1 - z3)) / (z1 - z2) @@ -79,8 +79,8 @@ def compute_wgtfacq_e_dsl( z_aux_c[:, 4] = (z1 - z_aux_c[:, 5] * (z1 - z3)) / (z1 - z2) z_aux_c[:, 3] = 1.0 - (z_aux_c[:, 4] + z_aux_c[:, 5]) - c_lin_e = c_lin_e[:, :, np.newaxis] - z_aux_e = np.sum(c_lin_e * z_aux_c[e2c], axis=1) + c_lin_e = c_lin_e[:, :, xp.newaxis] + z_aux_e = xp.sum(c_lin_e * z_aux_c[e2c], axis=1) wgtfacq_e_dsl[:, nlev] = z_aux_e[:, 0] wgtfacq_e_dsl[:, nlev - 1] = z_aux_e[:, 1] diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 67ec0a348..6eac491ed 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -7,30 +7,24 @@ # SPDX-License-Identifier: BSD-3-Clause import abc +import enum import functools import inspect -from enum import IntEnum -from typing import Callable, Iterable, Optional, Protocol, Sequence, TypeAlias, TypeVar, Union +from typing import Callable, Iterable, Optional, Protocol, Sequence, Union import gt4py.next as gtx import gt4py.next.ffront.decorator as gtx_decorator import xarray as xa import icon4py.model.common.states.metadata as metadata -from icon4py.model.common import dimension as dims, exceptions, settings, type_alias as ta +from icon4py.model.common import dimension as dims, exceptions, settings from icon4py.model.common.grid import base as base_grid from icon4py.model.common.settings import xp +from icon4py.model.common.states import utils as state_utils from icon4py.model.common.utils import builder -T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) -DimT = TypeVar("DimT", dims.KDim, dims.KHalfDim, dims.CellDim, dims.EdgeDim, dims.VertexDim) -Scalar: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64] - -FieldType: TypeAlias = gtx.Field[Sequence[gtx.Dims[DimT]], T] - - -class RetrievalType(IntEnum): +class RetrievalType(enum.IntEnum): FIELD = (0,) DATA_ARRAY = (1,) METADATA = (2,) @@ -65,14 +59,14 @@ class FieldProvider(Protocol): def __init__(self, func: Callable): self._func = func - self._fields: dict[str, Optional[FieldType]] = {} + self._fields: dict[str, Optional[state_utils.FieldType]] = {} self._dependencies: dict[str, str] = {} @abc.abstractmethod def evaluate(self, factory: "FieldsFactory") -> None: pass - def __call__(self, field_name: str, factory: "FieldsFactory") -> FieldType: + def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: if field_name not in self.fields(): raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}.") if any([f is None for f in self._fields.values()]): @@ -89,7 +83,7 @@ def fields(self) -> Iterable[str]: class PrecomputedFieldsProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" - def __init__(self, fields: dict[str, FieldType]): + def __init__(self, fields: dict[str, state_utils.FieldType]): self._fields = fields def evaluate(self, factory: "FieldsFactory") -> None: @@ -98,7 +92,7 @@ def evaluate(self, factory: "FieldsFactory") -> None: def dependencies(self) -> Sequence[str]: return [] - def __call__(self, field_name: str, factory: "FieldsFactory") -> FieldType: + def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: return self._fields[field_name] @@ -114,7 +108,7 @@ def __init__( domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], fields: dict[str:str], deps: dict[str, str], - params: Optional[dict[str, Scalar]] = None, + params: Optional[dict[str, state_utils.Scalar]] = None, ): self._func = func self._compute_domain = domain @@ -122,14 +116,14 @@ def __init__( self._output = fields self._params = params if params is not None else {} self._dims = self._domain_args() - self._fields: dict[str, Optional[gtx.Field | Scalar]] = { + self._fields: dict[str, Optional[gtx.Field | state_utils.Scalar]] = { name: None for name in fields.values() } def _unallocated(self) -> bool: return not all(self._fields.values()) - def _allocate(self, allocator, grid: base_grid.BaseGrid) -> dict[str, FieldType]: + def _allocate(self, allocator, grid: base_grid.BaseGrid) -> dict[str, state_utils.FieldType]: def _map_size(dim: gtx.Dimension, grid: base_grid.BaseGrid) -> int: if dim == dims.KHalfDim: return grid.num_levels + 1 @@ -188,12 +182,12 @@ def __init__( domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], fields: Sequence[str], deps: dict[str, str], - params: Optional[dict[str, Scalar]] = None, + params: Optional[dict[str, state_utils.Scalar]] = None, ): self._func = func self._compute_domain = domain self._dims = domain.keys() - self._fields: dict[str, Optional[FieldType]] = {name: None for name in fields} + self._fields: dict[str, Optional[state_utils.FieldType]] = {name: None for name in fields} self._dependencies = deps self._params = params if params is not None else {} @@ -236,11 +230,11 @@ class FieldsFactory: def __init__(self, grid: base_grid.BaseGrid = None, backend=settings.backend): self._grid = grid - self._providers: dict[str, "FieldProvider"] = {} + self._providers: dict[str, 'FieldProvider'] = {} self._allocator = gtx.constructors.zeros.partial(allocator=backend) def validate(self): - return self._grid is not None and self._allocator is not None + return self._grid is not None @builder.builder def with_grid(self, grid: base_grid.BaseGrid): @@ -269,7 +263,7 @@ def register_provider(self, provider: FieldProvider): @valid def get( self, field_name: str, type_: RetrievalType = RetrievalType.FIELD - ) -> Union[FieldType, xa.DataArray, dict]: + ) -> Union[state_utils.FieldType, xa.DataArray, dict]: if field_name not in metadata.attrs: raise ValueError(f"Field {field_name} not found in metric fields") if type_ == RetrievalType.METADATA: @@ -277,11 +271,9 @@ def get( if type_ == RetrievalType.FIELD: return self._providers[field_name](field_name, self) if type_ == RetrievalType.DATA_ARRAY: - return to_data_array( - self._providers[field_name](field_name), metadata.attrs[field_name] + return state_utils.to_data_array( + self._providers[field_name](field_name, self), metadata.attrs[field_name] ) raise ValueError(f"Invalid retrieval type {type_}") -def to_data_array(field, attrs): - return xa.DataArray(field, attrs=attrs) diff --git a/model/common/src/icon4py/model/common/states/utils.py b/model/common/src/icon4py/model/common/states/utils.py new file mode 100644 index 000000000..b8fb58bc5 --- /dev/null +++ b/model/common/src/icon4py/model/common/states/utils.py @@ -0,0 +1,18 @@ +from typing import Sequence, TypeAlias, TypeVar, Union + +import gt4py.next as gtx +import xarray as xa + +from icon4py.model.common import dimension as dims, type_alias as ta +from icon4py.model.common.settings import xp + + +T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) +DimT = TypeVar("DimT", dims.KDim, dims.KHalfDim, dims.CellDim, dims.EdgeDim, dims.VertexDim) +Scalar: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64] + +FieldType: TypeAlias = Union[gtx.Field[Sequence[gtx.Dims[DimT]], T], xp.ndarray] + +def to_data_array(field:FieldType, attrs:dict): + data = field if isinstance(field, xp.ndarray) else field.ndarray + return xa.DataArray(data, attrs=attrs) diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 103a48c1e..3fe120d6a 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -13,13 +13,15 @@ from icon4py.model.common import dimension as dims, exceptions from icon4py.model.common.io import cf_utils from icon4py.model.common.metrics import metric_fields as mf -from icon4py.model.common.metrics.compute_wgtfacq import compute_wgtfacq_c_dsl +from icon4py.model.common.metrics.compute_wgtfacq import ( + compute_wgtfacq_c_dsl, +) from icon4py.model.common.settings import xp from icon4py.model.common.states import factory @pytest.mark.datatest -def test_check_dependencies_on_register(icon_grid, backend): +def test_factory_check_dependencies_on_register(icon_grid, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) provider = factory.ProgramFieldProvider( func=mf.compute_z_mc, @@ -33,13 +35,15 @@ def test_check_dependencies_on_register(icon_grid, backend): @pytest.mark.datatest -def test_factory_raise_error_if_no_grid_or_backend_set(metrics_savepoint): +def test_factory_raise_error_if_no_grid_is_set( + metrics_savepoint +): z_ifc = metrics_savepoint.z_ifc() k_index = gtx.as_field((dims.KDim,), xp.arange(1, dtype=gtx.int32)) pre_computed_fields = factory.PrecomputedFieldsProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) - fields_factory = factory.FieldsFactory(None, None) + fields_factory = factory.FieldsFactory(grid=None) fields_factory.register_provider(pre_computed_fields) with pytest.raises(exceptions.IncompleteSetupError) as e: fields_factory.get("height_on_interface_levels") @@ -53,15 +57,24 @@ def test_factory_returns_field(metrics_savepoint, icon_grid, backend): pre_computed_fields = factory.PrecomputedFieldsProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) - fields_factory = factory.FieldsFactory(None, None) + fields_factory = factory.FieldsFactory() fields_factory.register_provider(pre_computed_fields) fields_factory.with_grid(icon_grid).with_allocator(backend) field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) assert field.ndarray.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) - + meta = fields_factory.get("height_on_interface_levels", factory.RetrievalType.METADATA) + assert meta["standard_name"] == "height_on_interface_levels" + assert meta["dims"] == (dims.CellDim, dims.KHalfDim,) + assert meta["units"] == "m" + data_array = fields_factory.get("height_on_interface_levels", factory.RetrievalType.DATA_ARRAY) + assert data_array.data.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) + assert data_array.data.dtype == xp.float64 + for key in ("dims", "standard_name", "units", "icon_var_name"): + assert key in data_array.attrs.keys() + @pytest.mark.datatest -def test_field_provider(icon_grid, metrics_savepoint, backend): +def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() @@ -101,7 +114,7 @@ def test_field_provider(icon_grid, metrics_savepoint, backend): assert helpers.dallclose(data.ndarray, ref) -def test_numpy_function_evaluation(icon_grid, metrics_savepoint, backend): +def test_field_provider_for_numpy_function(icon_grid, metrics_savepoint, interpolation_savepoint, backend): fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() From 52a837d17b8b3b49ed8b19910b6e3ddcbf15ab9b Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 21 Aug 2024 11:41:12 +0200 Subject: [PATCH 16/37] add type annotation to connectivity --- .../common/src/icon4py/model/common/metrics/compute_wgtfacq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py index ad4cd0148..0a7c0ad53 100644 --- a/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py +++ b/model/common/src/icon4py/model/common/metrics/compute_wgtfacq.py @@ -47,7 +47,7 @@ def compute_wgtfacq_c_dsl( def compute_wgtfacq_e_dsl( - e2c, + e2c: xp.ndarray, z_ifc: xp.ndarray, c_lin_e: xp.ndarray, wgtfacq_c_dsl: xp.ndarray, From 72e742bda7c3fab3db1fbf4b1a8eb1cfd7be74a4 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 21 Aug 2024 11:43:40 +0200 Subject: [PATCH 17/37] handle numpy field with connectivity --- .../icon4py/model/common/states/factory.py | 62 ++++++++----- .../icon4py/model/common/states/metadata.py | 17 ++++ .../src/icon4py/model/common/states/utils.py | 16 +++- .../common/tests/states_test/test_factory.py | 86 +++++++++++++++++-- 4 files changed, 148 insertions(+), 33 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 6eac491ed..3019274b3 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -10,17 +10,16 @@ import enum import functools import inspect -from typing import Callable, Iterable, Optional, Protocol, Sequence, Union +from typing import Callable, Iterable, Optional, Protocol, Sequence, Union, get_args import gt4py.next as gtx import gt4py.next.ffront.decorator as gtx_decorator import xarray as xa -import icon4py.model.common.states.metadata as metadata from icon4py.model.common import dimension as dims, exceptions, settings -from icon4py.model.common.grid import base as base_grid +from icon4py.model.common.grid import base as base_grid, icon as icon_grid from icon4py.model.common.settings import xp -from icon4py.model.common.states import utils as state_utils +from icon4py.model.common.states import metadata as metadata, utils as state_utils from icon4py.model.common.utils import builder @@ -105,7 +104,9 @@ class ProgramFieldProvider(FieldProvider): def __init__( self, func: gtx_decorator.Program, - domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], + domain: dict[ + gtx.Dimension : tuple[Callable[[gtx.Dimension], int], Callable[[gtx.Dimension], int]] + ], fields: dict[str:str], deps: dict[str, str], params: Optional[dict[str, state_utils.Scalar]] = None, @@ -115,7 +116,6 @@ def __init__( self._dependencies = deps self._output = fields self._params = params if params is not None else {} - self._dims = self._domain_args() self._fields: dict[str, Optional[gtx.Field | state_utils.Scalar]] = { name: None for name in fields.values() } @@ -142,14 +142,14 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: for k in self._fields.keys() } - def _domain_args(self) -> dict[str : gtx.int32]: + def _domain_args(self, grid: icon_grid.IconGrid) -> dict[str : gtx.int32]: domain_args = {} for dim in self._compute_domain: if dim.kind == gtx.DimensionKind.HORIZONTAL: domain_args.update( { - "horizontal_start": self._compute_domain[dim][0], - "horizontal_end": self._compute_domain[dim][1], + "horizontal_start": grid.get_start_index(dim, self._compute_domain[dim][0]), + "horizontal_end": grid.get_end_index(dim, self._compute_domain[dim][1]), } ) elif dim.kind == gtx.DimensionKind.VERTICAL: @@ -168,7 +168,8 @@ def evaluate(self, factory: "FieldsFactory"): deps = {k: factory.get(v) for k, v in self._dependencies.items()} deps.update(self._params) deps.update({k: self._fields[v] for k, v in self._output.items()}) - deps.update(self._dims) + dims = self._domain_args(factory.grid) + deps.update(dims) self._func(**deps, offset_provider=factory.grid.offset_providers) def fields(self) -> Iterable[str]: @@ -182,18 +183,23 @@ def __init__( domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], fields: Sequence[str], deps: dict[str, str], + offsets: Optional[dict[str, gtx.Dimension]] = None, params: Optional[dict[str, state_utils.Scalar]] = None, ): self._func = func self._compute_domain = domain + self._offsets = offsets self._dims = domain.keys() self._fields: dict[str, Optional[state_utils.FieldType]] = {name: None for name in fields} self._dependencies = deps + self._offsets = offsets if offsets is not None else {} self._params = params if params is not None else {} def evaluate(self, factory: "FieldsFactory") -> None: self._validate_dependencies() args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()} + offsets = {k: factory.grid.connectivities[v] for k, v in self._offsets.items()} + args.update(offsets) args.update(self._params) results = self._func(**args) ## TODO: can the order of return values be checked? @@ -208,17 +214,31 @@ def _validate_dependencies(self): parameters = func_signature.parameters for dep_key in self._dependencies.keys(): parameter_definition = parameters.get(dep_key) - if parameter_definition is None or parameter_definition.annotation != xp.ndarray: - raise ValueError( - f"Dependency {dep_key} in function {self._func.__name__} : does not exist in {func_signature} or has wrong type ('expected np.ndarray')" - ) + assert ( + parameter_definition.annotation == xp.ndarray + ), (f"Dependency {dep_key} in function {self._func.__name__}: does not exist or has " + f"or has wrong type ('expected np.ndarray') in {func_signature}.") for param_key, param_value in self._params.items(): parameter_definition = parameters.get(param_key) - if parameter_definition is None or parameter_definition.annotation != type(param_value): - raise ValueError( - f"parameter {param_key} in function {self._func.__name__} does not exist or has the has the wrong type: {type(param_value)}" - ) + checked = _check( + parameter_definition, param_value, union=state_utils.IntegerType + ) or _check(parameter_definition, param_value, union=state_utils.FloatType) + assert checked, (f"Parameter {param_key} in function {self._func.__name__} does not " + f"exist or has the wrong type: {type(param_value)}.") + + +def _check( + parameter_definition: inspect.Parameter, + value: Union[state_utils.Scalar, gtx.Field], + union: Union, +) -> bool: + members = get_args(union) + return ( + parameter_definition is not None + and parameter_definition.annotation in members + and type(value) in members + ) class FieldsFactory: @@ -228,9 +248,9 @@ class FieldsFactory: Lazily compute fields and cache them. """ - def __init__(self, grid: base_grid.BaseGrid = None, backend=settings.backend): + def __init__(self, grid: icon_grid.IconGrid = None, backend=settings.backend): self._grid = grid - self._providers: dict[str, 'FieldProvider'] = {} + self._providers: dict[str, "FieldProvider"] = {} self._allocator = gtx.constructors.zeros.partial(allocator=backend) def validate(self): @@ -275,5 +295,3 @@ def get( self._providers[field_name](field_name, self), metadata.attrs[field_name] ) raise ValueError(f"Invalid retrieval type {type_}") - - diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 93462fe3b..7e1f3773f 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -61,4 +61,21 @@ icon_var_name="wgtfacq_c_dsl", long_name="weighting factor for quadratic interpolation to cell surface", ), + "weighting_factor_for_quadratic_interpolation_to_edge_center": dict( + standard_name="weighting_factor_for_quadratic_interpolation_to_edge_center", + units="", + dims=(dims.EdgeDim, dims.KDim), + dtype=ta.wpfloat, + icon_var_name="wgtfacq_e_dsl", + long_name="weighting factor for quadratic interpolation to edge centers", + ), + # TODO : FIX + "c_lin_e": dict( + standard_name="c_lin_e", + units="", + dims=(dims.EdgeDim, dims.E2CDim), + dtype=ta.wpfloat, + icon_var_name="c_lin_e", + long_name="interpolation field", + ), } diff --git a/model/common/src/icon4py/model/common/states/utils.py b/model/common/src/icon4py/model/common/states/utils.py index b8fb58bc5..e8ad795ae 100644 --- a/model/common/src/icon4py/model/common/states/utils.py +++ b/model/common/src/icon4py/model/common/states/utils.py @@ -1,3 +1,11 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + from typing import Sequence, TypeAlias, TypeVar, Union import gt4py.next as gtx @@ -9,10 +17,14 @@ T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) DimT = TypeVar("DimT", dims.KDim, dims.KHalfDim, dims.CellDim, dims.EdgeDim, dims.VertexDim) -Scalar: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64] +FloatType: TypeAlias = Union[ta.wpfloat, ta.vpfloat, float] +IntegerType: TypeAlias = Union[gtx.int32, gtx.int64, int] +Scalar: TypeAlias = Union[FloatType, bool, IntegerType] + FieldType: TypeAlias = Union[gtx.Field[Sequence[gtx.Dims[DimT]], T], xp.ndarray] -def to_data_array(field:FieldType, attrs:dict): + +def to_data_array(field: FieldType, attrs: dict): data = field if isinstance(field, xp.ndarray) else field.ndarray return xa.DataArray(data, attrs=attrs) diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 3fe120d6a..e1c74f126 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -11,10 +11,12 @@ import icon4py.model.common.test_utils.helpers as helpers from icon4py.model.common import dimension as dims, exceptions +from icon4py.model.common.grid.horizontal import HorizontalMarkerIndex from icon4py.model.common.io import cf_utils from icon4py.model.common.metrics import metric_fields as mf from icon4py.model.common.metrics.compute_wgtfacq import ( compute_wgtfacq_c_dsl, + compute_wgtfacq_e_dsl, ) from icon4py.model.common.settings import xp from icon4py.model.common.states import factory @@ -35,9 +37,7 @@ def test_factory_check_dependencies_on_register(icon_grid, backend): @pytest.mark.datatest -def test_factory_raise_error_if_no_grid_is_set( - metrics_savepoint -): +def test_factory_raise_error_if_no_grid_is_set(metrics_savepoint): z_ifc = metrics_savepoint.z_ifc() k_index = gtx.as_field((dims.KDim,), xp.arange(1, dtype=gtx.int32)) pre_computed_fields = factory.PrecomputedFieldsProvider( @@ -64,14 +64,17 @@ def test_factory_returns_field(metrics_savepoint, icon_grid, backend): assert field.ndarray.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) meta = fields_factory.get("height_on_interface_levels", factory.RetrievalType.METADATA) assert meta["standard_name"] == "height_on_interface_levels" - assert meta["dims"] == (dims.CellDim, dims.KHalfDim,) + assert meta["dims"] == ( + dims.CellDim, + dims.KHalfDim, + ) assert meta["units"] == "m" data_array = fields_factory.get("height_on_interface_levels", factory.RetrievalType.DATA_ARRAY) assert data_array.data.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) assert data_array.data.dtype == xp.float64 for key in ("dims", "standard_name", "units", "icon_var_name"): assert key in data_array.attrs.keys() - + @pytest.mark.datatest def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): @@ -87,7 +90,13 @@ def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): height_provider = factory.ProgramFieldProvider( func=mf.compute_z_mc, - domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + domain={ + dims.CellDim: ( + HorizontalMarkerIndex.local(dims.CellDim), + HorizontalMarkerIndex.end(dims.CellDim), + ), + dims.KDim: (0, icon_grid.num_levels), + }, fields={"z_mc": "height"}, deps={"z_ifc": "height_on_interface_levels"}, ) @@ -95,7 +104,10 @@ def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): functional_determinant_provider = factory.ProgramFieldProvider( func=mf.compute_ddqz_z_half, domain={ - dims.CellDim: (0, icon_grid.num_cells), + dims.CellDim: ( + HorizontalMarkerIndex.local(dims.CellDim), + HorizontalMarkerIndex.end(dims.CellDim), + ), dims.KHalfDim: (0, icon_grid.num_levels + 1), }, fields={"ddqz_z_half": "functional_determinant_of_metrics_on_interface_levels"}, @@ -114,7 +126,9 @@ def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): assert helpers.dallclose(data.ndarray, ref) -def test_field_provider_for_numpy_function(icon_grid, metrics_savepoint, interpolation_savepoint, backend): +def test_field_provider_for_numpy_function( + icon_grid, metrics_savepoint, interpolation_savepoint, backend +): fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() @@ -129,7 +143,10 @@ def test_field_provider_for_numpy_function(icon_grid, metrics_savepoint, interpo params = {"nlev": icon_grid.num_levels} compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( func=func, - domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + domain={ + dims.CellDim: (0, HorizontalMarkerIndex.end(dims.CellDim)), + dims.KDim: (0, icon_grid.num_levels), + }, fields=["weighting_factor_for_quadratic_interpolation_to_cell_surface"], deps=deps, params=params, @@ -141,3 +158,54 @@ def test_field_provider_for_numpy_function(icon_grid, metrics_savepoint, interpo ) assert helpers.dallclose(wgtfacq_c.asnumpy(), wgtfacq_c_ref.asnumpy()) + + +def test_field_provider_for_numpy_function_with_offsets( + icon_grid, metrics_savepoint, interpolation_savepoint, backend +): + fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) + k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) + z_ifc = metrics_savepoint.z_ifc() + c_lin_e = interpolation_savepoint.c_lin_e() + wgtfacq_e_ref = metrics_savepoint.wgtfacq_e_dsl(icon_grid.num_levels + 1) + + pre_computed_fields = factory.PrecomputedFieldsProvider( + { + "height_on_interface_levels": z_ifc, + cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index, + "c_lin_e": c_lin_e, + } + ) + fields_factory.register_provider(pre_computed_fields) + func = compute_wgtfacq_c_dsl + params = {"nlev": icon_grid.num_levels} + compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( + func=func, + domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + fields=["weighting_factor_for_quadratic_interpolation_to_cell_surface"], + deps={"z_ifc": "height_on_interface_levels"}, + params=params, + ) + deps = { + "z_ifc": "height_on_interface_levels", + "wgtfacq_c_dsl": "weighting_factor_for_quadratic_interpolation_to_cell_surface", + "c_lin_e": "c_lin_e", + } + fields_factory.register_provider(compute_wgtfacq_c_provider) + wgtfacq_e_provider = factory.NumpyFieldsProvider( + func=compute_wgtfacq_e_dsl, + deps=deps, + offsets={"e2c": dims.E2CDim}, + domain={dims.EdgeDim: (0, icon_grid.num_edges), dims.KDim: (0, icon_grid.num_levels)}, + fields=["weighting_factor_for_quadratic_interpolation_to_edge_center"], + params={"n_edges": icon_grid.num_edges, "nlev": icon_grid.num_levels}, + ) + + fields_factory.register_provider(wgtfacq_e_provider) + wgtfacq_e = fields_factory.get( + "weighting_factor_for_quadratic_interpolation_to_edge_center", factory.RetrievalType.FIELD + ) + + assert helpers.dallclose(wgtfacq_e.asnumpy(), wgtfacq_e_ref.asnumpy()) + + From fba0891bcd01cea68b7645b553da621d738a8738 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 28 Aug 2024 08:42:44 +0200 Subject: [PATCH 18/37] add type to get_processor_properties argument --- .../src/icon4py/model/common/decomposition/definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index 3405a88b0..e190b2648 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -201,7 +201,7 @@ def get_runtype(with_mpi: bool = False) -> RunType: @functools.singledispatch -def get_processor_properties(runtime) -> ProcessProperties: +def get_processor_properties(runtime:RunType) -> ProcessProperties: raise TypeError(f"Cannot define ProcessProperties for ({type(runtime)})") From c2c250a7fe0d563192b1210685fefd318c1286a0 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 28 Aug 2024 08:43:43 +0200 Subject: [PATCH 19/37] add c_lin_e metadata --- model/common/src/icon4py/model/common/states/metadata.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 7e1f3773f..30df9e9b9 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -69,13 +69,12 @@ icon_var_name="wgtfacq_e_dsl", long_name="weighting factor for quadratic interpolation to edge centers", ), - # TODO : FIX - "c_lin_e": dict( - standard_name="c_lin_e", + "cell_to_edge_interpolation_coefficient": dict( + standard_name="cell_to_edge_interpolation_coefficient", units="", dims=(dims.EdgeDim, dims.E2CDim), dtype=ta.wpfloat, icon_var_name="c_lin_e", - long_name="interpolation field", + long_name="coefficients for cell to edge interpolation", ), } From 04645e0c218db2ecd8d5e47f7d570cad3e4fe2f5 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 28 Aug 2024 08:46:52 +0200 Subject: [PATCH 20/37] start_index, end_index abstraction for vertical (WIP) --- .../src/icon4py/model/common/grid/vertical.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/grid/vertical.py b/model/common/src/icon4py/model/common/grid/vertical.py index 87da98fd4..dcce2407b 100644 --- a/model/common/src/icon4py/model/common/grid/vertical.py +++ b/model/common/src/icon4py/model/common/grid/vertical.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses +import enum import logging import math import pathlib @@ -21,6 +22,21 @@ log = logging.getLogger(__name__) +class VerticalZone(enum.IntEnum): + FULL = 0 + DAMPING_HEIGHT = 1 + +@dataclasses.dataclass(frozen=True) +class VerticalDomain: + dim: dims.KDim + zone: VerticalZone + + + + + + # TODO (@halungge) add as needed + @dataclasses.dataclass(frozen=True) class VerticalGridConfig: """ @@ -74,7 +90,7 @@ class VerticalGridParams: _start_index_for_moist_physics: Final[gtx.int32] = dataclasses.field(init=False) _end_index_of_flat_layer: Final[gtx.int32] = dataclasses.field(init=False) _min_index_flat_horizontal_grad_pressure: Final[gtx.int32] = None - + def __post_init__(self, vertical_config, vct_a, vct_b): object.__setattr__( self, @@ -123,6 +139,16 @@ def __str__(self): vertical_params_properties.extend(array_value) return "\n".join(vertical_params_properties) + def start_index(self, domain:VerticalDomain): + return self._end_index_of_damping_layer if domain.zone == VerticalZone.DAMPING_HEIGHT else 0 + + + + def end_index(self, domain:VerticalDomain): + num_levels = self.vertical_config.num_levels if domain.dim == dims.KDim else self.vertical_config.num_levels + 1 + return self._end_index_of_damping_layer if domain.zone == VerticalZone.DAMPING_HEIGHT else gtx.int32(num_levels) + + @property def metadata_interface_physical_height(self): return dict( From 306b761b08eb5faf6f1538c9b8df8307ebb7b7ee Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 28 Aug 2024 11:10:51 +0200 Subject: [PATCH 21/37] basic sample of factory. --- .../model/common/metrics/metrics_factory.py | 60 +++++++++++++++++++ .../metric_tests/test_metrics_factory.py | 10 ++++ 2 files changed, 70 insertions(+) create mode 100644 model/common/src/icon4py/model/common/metrics/metrics_factory.py create mode 100644 model/common/tests/metric_tests/test_metrics_factory.py diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py new file mode 100644 index 000000000..4ad4aabcc --- /dev/null +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -0,0 +1,60 @@ +import pathlib + +import icon4py.model.common.states.factory as factory +from icon4py.model.common import dimension as dims +from icon4py.model.common.decomposition import definitions as decomposition +from icon4py.model.common.grid import horizontal +from icon4py.model.common.metrics import metric_fields as mf +from icon4py.model.common.test_utils import datatest_utils as dt_utils, serialbox_utils as sb + + +# we need to register a couple of fields from the serializer. Those should get replaced one by one. + +dt_utils.TEST_DATA_ROOT = pathlib.Path(__file__).parent / "testdata" +properties = decomposition.get_processor_properties(decomposition.get_runtype(with_mpi=False)) +path = dt_utils.get_ranked_data_path(dt_utils.SERIALIZED_DATA_PATH, properties) + +data_provider = sb.IconSerialDataProvider( + "icon_pydycore", str(path.absolute()), False, mpi_rank=properties.rank + ) + +# z_ifc (computable from vertical grid for model without topography) +metrics_savepoint = data_provider.from_metrics_savepoint() + +#interpolation fields also for now passing as precomputed fields +interpolation_savepoint = data_provider.from_interpolation_savepoint() +#can get geometry fields as pre computed fields from the grid_savepoint +grid_savepoint = data_provider.from_savepoint_grid() +####### + +# start build up factory: + + +interface_model_height = metrics_savepoint.z_ifc() +c_lin_e = interpolation_savepoint.c_lin_e() + +fields_factory = factory.FieldsFactory() + +# used for vertical domain below: should go away once vertical grid provids start_index and end_index like interface +grid = grid_savepoint.global_grid_params + +fields_factory.register_provider( + factory.PrecomputedFieldsProvider( + { + "height_on_interface_levels": interface_model_height, + "cell_to_edge_interpolation_coefficient": c_lin_e, + } + ) +) +height_provider = factory.ProgramFieldProvider( + func=mf.compute_z_mc, + domain={ + dims.CellDim: ( + horizontal.HorizontalMarkerIndex.local(dims.CellDim), + horizontal.HorizontalMarkerIndex.end(dims.CellDim), + ), + dims.KDim: (0, grid.num_levels), + }, + fields={"z_mc": "height"}, + deps={"z_ifc": "height_on_interface_levels"}, + ) diff --git a/model/common/tests/metric_tests/test_metrics_factory.py b/model/common/tests/metric_tests/test_metrics_factory.py new file mode 100644 index 000000000..d731d1aa4 --- /dev/null +++ b/model/common/tests/metric_tests/test_metrics_factory.py @@ -0,0 +1,10 @@ + +import icon4py.model.common.settings as settings +from icon4py.model.common.metrics import metrics_factory + + +def test_factory(icon_grid): + + factory = metrics_factory.fields_factory + factory.with_grid(icon_grid).with_allocator(settings.backend) + factory.get("height_on_interface_levels", metrics_factory.RetrievalType.FIELD) \ No newline at end of file From cec01f9d320b46cacc5c8cacfe291829be677ae2 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 29 Aug 2024 12:06:11 +0200 Subject: [PATCH 22/37] fix with_allocator function --- model/common/src/icon4py/model/common/states/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index c0d8b9a7a..5abb42563 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -264,7 +264,7 @@ def with_grid(self, grid: base_grid.BaseGrid): @builder.builder def with_allocator(self, backend=settings.backend): - self._allocator = backend + self._allocator = gtx.constructors.zeros.partial(allocator=backend) @property def grid(self): From aa2c402faee03f66e179108bfba53c203ae20ada Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:00:29 +0200 Subject: [PATCH 23/37] ran pre-commit and made fixes --- .../model/common/decomposition/definitions.py | 2 +- .../src/icon4py/model/common/grid/vertical.py | 23 +++++++---- .../model/common/metrics/metrics_factory.py | 38 +++++++++++-------- .../metric_tests/test_metrics_factory.py | 10 ++++- 4 files changed, 47 insertions(+), 26 deletions(-) diff --git a/model/common/src/icon4py/model/common/decomposition/definitions.py b/model/common/src/icon4py/model/common/decomposition/definitions.py index e190b2648..5b4a84f82 100644 --- a/model/common/src/icon4py/model/common/decomposition/definitions.py +++ b/model/common/src/icon4py/model/common/decomposition/definitions.py @@ -201,7 +201,7 @@ def get_runtype(with_mpi: bool = False) -> RunType: @functools.singledispatch -def get_processor_properties(runtime:RunType) -> ProcessProperties: +def get_processor_properties(runtime: RunType) -> ProcessProperties: raise TypeError(f"Cannot define ProcessProperties for ({type(runtime)})") diff --git a/model/common/src/icon4py/model/common/grid/vertical.py b/model/common/src/icon4py/model/common/grid/vertical.py index e1c533313..f1feccf6a 100644 --- a/model/common/src/icon4py/model/common/grid/vertical.py +++ b/model/common/src/icon4py/model/common/grid/vertical.py @@ -147,15 +147,22 @@ def __str__(self): vertical_params_properties.extend(array_value) return "\n".join(vertical_params_properties) - def start_index(self, domain:VerticalDomain): - return self._end_index_of_damping_layer if domain.zone == VerticalZone.DAMPING_HEIGHT else 0 - - - - def end_index(self, domain:VerticalDomain): - num_levels = self.vertical_config.num_levels if domain.dim == dims.KDim else self.vertical_config.num_levels + 1 - return self._end_index_of_damping_layer if domain.zone == VerticalZone.DAMPING_HEIGHT else gtx.int32(num_levels) + def start_index(self, domain: Domain): + return ( + self._end_index_of_damping_layer + if domain.zone == self.config.rayleigh_damping_height + else 0 + ) + def end_index(self, domain: Domain): + num_levels = ( + self.config.num_levels if domain.dim == dims.KDim else self.config.num_levels + 1 + ) + return ( + self._end_index_of_damping_layer + if domain.zone == self.config.rayleigh_damping_height + else gtx.int32(num_levels) + ) @property def metadata_interface_physical_height(self): diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py index 4ad4aabcc..58a28a0f7 100644 --- a/model/common/src/icon4py/model/common/metrics/metrics_factory.py +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -1,3 +1,11 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import pathlib import icon4py.model.common.states.factory as factory @@ -15,15 +23,15 @@ path = dt_utils.get_ranked_data_path(dt_utils.SERIALIZED_DATA_PATH, properties) data_provider = sb.IconSerialDataProvider( - "icon_pydycore", str(path.absolute()), False, mpi_rank=properties.rank - ) + "icon_pydycore", str(path.absolute()), False, mpi_rank=properties.rank +) # z_ifc (computable from vertical grid for model without topography) metrics_savepoint = data_provider.from_metrics_savepoint() -#interpolation fields also for now passing as precomputed fields +# interpolation fields also for now passing as precomputed fields interpolation_savepoint = data_provider.from_interpolation_savepoint() -#can get geometry fields as pre computed fields from the grid_savepoint +# can get geometry fields as pre computed fields from the grid_savepoint grid_savepoint = data_provider.from_savepoint_grid() ####### @@ -47,14 +55,14 @@ ) ) height_provider = factory.ProgramFieldProvider( - func=mf.compute_z_mc, - domain={ - dims.CellDim: ( - horizontal.HorizontalMarkerIndex.local(dims.CellDim), - horizontal.HorizontalMarkerIndex.end(dims.CellDim), - ), - dims.KDim: (0, grid.num_levels), - }, - fields={"z_mc": "height"}, - deps={"z_ifc": "height_on_interface_levels"}, - ) + func=mf.compute_z_mc, + domain={ + dims.CellDim: ( + horizontal.HorizontalMarkerIndex.local(dims.CellDim), + horizontal.HorizontalMarkerIndex.end(dims.CellDim), + ), + dims.KDim: (0, grid.num_levels), + }, + fields={"z_mc": "height"}, + deps={"z_ifc": "height_on_interface_levels"}, +) diff --git a/model/common/tests/metric_tests/test_metrics_factory.py b/model/common/tests/metric_tests/test_metrics_factory.py index d731d1aa4..97a3f6f76 100644 --- a/model/common/tests/metric_tests/test_metrics_factory.py +++ b/model/common/tests/metric_tests/test_metrics_factory.py @@ -1,10 +1,16 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause import icon4py.model.common.settings as settings from icon4py.model.common.metrics import metrics_factory def test_factory(icon_grid): - factory = metrics_factory.fields_factory factory.with_grid(icon_grid).with_allocator(settings.backend) - factory.get("height_on_interface_levels", metrics_factory.RetrievalType.FIELD) \ No newline at end of file + factory.get("height_on_interface_levels", metrics_factory.RetrievalType.FIELD) From afe3f47100d89368dcb2267f93a9002aa22abd11 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:33:00 +0200 Subject: [PATCH 24/37] small edit --- model/common/src/icon4py/model/common/grid/vertical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/vertical.py b/model/common/src/icon4py/model/common/grid/vertical.py index f1feccf6a..c9b1ec787 100644 --- a/model/common/src/icon4py/model/common/grid/vertical.py +++ b/model/common/src/icon4py/model/common/grid/vertical.py @@ -150,7 +150,7 @@ def __str__(self): def start_index(self, domain: Domain): return ( self._end_index_of_damping_layer - if domain.zone == self.config.rayleigh_damping_height + if domain.zone.DAMPING == self.config.rayleigh_damping_height else 0 ) @@ -160,7 +160,7 @@ def end_index(self, domain: Domain): ) return ( self._end_index_of_damping_layer - if domain.zone == self.config.rayleigh_damping_height + if domain.zone.DAMPING == self.config.rayleigh_damping_height else gtx.int32(num_levels) ) From 8f8d8de7dbcdc19459e4fbc45997d51b647b87eb Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 5 Sep 2024 21:57:15 +0200 Subject: [PATCH 25/37] using domains for the compute domain in factory --- .../src/icon4py/model/common/grid/vertical.py | 31 ++++---- .../icon4py/model/common/states/factory.py | 57 +++++++++++---- .../common/tests/states_test/test_factory.py | 73 +++++++++++++------ 3 files changed, 108 insertions(+), 53 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/vertical.py b/model/common/src/icon4py/model/common/grid/vertical.py index c9b1ec787..9e4b37662 100644 --- a/model/common/src/icon4py/model/common/grid/vertical.py +++ b/model/common/src/icon4py/model/common/grid/vertical.py @@ -43,10 +43,18 @@ class Domain: Simple data class used to specify a vertical domain such that index lookup and domain specification can be separated. """ - dim: dims.KDim + dim: gtx.Dimension marker: Zone +def domain(dim: gtx.Dimension): + def _domain(marker: Zone): + assert dim.kind == gtx.DimensionKind.VERTICAL, "Only vertical dimensions are supported" + return Domain(dim, marker) + + return _domain + + @dataclasses.dataclass(frozen=True) class VerticalGridConfig: """ @@ -147,23 +155,6 @@ def __str__(self): vertical_params_properties.extend(array_value) return "\n".join(vertical_params_properties) - def start_index(self, domain: Domain): - return ( - self._end_index_of_damping_layer - if domain.zone.DAMPING == self.config.rayleigh_damping_height - else 0 - ) - - def end_index(self, domain: Domain): - num_levels = ( - self.config.num_levels if domain.dim == dims.KDim else self.config.num_levels + 1 - ) - return ( - self._end_index_of_damping_layer - if domain.zone.DAMPING == self.config.rayleigh_damping_height - else gtx.int32(num_levels) - ) - @property def metadata_interface_physical_height(self): return dict( @@ -174,6 +165,10 @@ def metadata_interface_physical_height(self): icon_var_name="vct_a", ) + @property + def num_levels(self): + return self.config.num_levels + def index(self, domain: Domain) -> gtx.int32: match domain.marker: case Zone.TOP: diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 5abb42563..d17735e59 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -10,19 +10,36 @@ import enum import functools import inspect -from typing import Callable, Iterable, Optional, Protocol, Sequence, Union, get_args +from typing import ( + Callable, + Iterable, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + get_args, +) import gt4py.next as gtx import gt4py.next.ffront.decorator as gtx_decorator import xarray as xa from icon4py.model.common import dimension as dims, exceptions, settings -from icon4py.model.common.grid import base as base_grid, icon as icon_grid +from icon4py.model.common.grid import ( + base as base_grid, + horizontal as h_grid, + icon as icon_grid, + vertical as v_grid, +) from icon4py.model.common.settings import xp from icon4py.model.common.states import metadata as metadata, utils as state_utils from icon4py.model.common.utils import builder +DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain) + + class RetrievalType(enum.IntEnum): FIELD = (0,) DATA_ARRAY = (1,) @@ -104,9 +121,7 @@ class ProgramFieldProvider(FieldProvider): def __init__( self, func: gtx_decorator.Program, - domain: dict[ - gtx.Dimension : tuple[Callable[[gtx.Dimension], int], Callable[[gtx.Dimension], int]] - ], + domain: dict[gtx.Dimension : tuple[DomainType, DomainType]], fields: dict[str:str], deps: dict[str, str], params: Optional[dict[str, state_utils.Scalar]] = None, @@ -142,21 +157,24 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: for k in self._fields.keys() } - def _domain_args(self, grid: icon_grid.IconGrid) -> dict[str : gtx.int32]: + def _domain_args( + self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid + ) -> dict[str : gtx.int32]: domain_args = {} + for dim in self._compute_domain: if dim.kind == gtx.DimensionKind.HORIZONTAL: domain_args.update( { - "horizontal_start": grid.get_start_index(dim, self._compute_domain[dim][0]), - "horizontal_end": grid.get_end_index(dim, self._compute_domain[dim][1]), + "horizontal_start": grid.start_index(self._compute_domain[dim][0]), + "horizontal_end": grid.end_index(self._compute_domain[dim][1]), } ) elif dim.kind == gtx.DimensionKind.VERTICAL: domain_args.update( { - "vertical_start": self._compute_domain[dim][0], - "vertical_end": self._compute_domain[dim][1], + "vertical_start": vertical_grid.index(self._compute_domain[dim][0]), + "vertical_end": vertical_grid.index(self._compute_domain[dim][1]), } ) else: @@ -168,7 +186,7 @@ def evaluate(self, factory: "FieldsFactory"): deps = {k: factory.get(v) for k, v in self._dependencies.items()} deps.update(self._params) deps.update({k: self._fields[v] for k, v in self._output.items()}) - dims = self._domain_args(factory.grid) + dims = self._domain_args(factory.grid, factory.vertical_grid) deps.update(dims) self._func(**deps, offset_provider=factory.grid.offset_providers) @@ -180,7 +198,7 @@ class NumpyFieldsProvider(FieldProvider): def __init__( self, func: Callable, - domain: dict[gtx.Dimension : tuple[gtx.int32, gtx.int32]], + domain: dict[gtx.Dimension : tuple[DomainType, DomainType]], fields: Sequence[str], deps: dict[str, str], offsets: Optional[dict[str, gtx.Dimension]] = None, @@ -250,8 +268,14 @@ class FieldsFactory: Lazily compute fields and cache them. """ - def __init__(self, grid: icon_grid.IconGrid = None, backend=settings.backend): + def __init__( + self, + grid: icon_grid.IconGrid = None, + vertical_grid: v_grid.VerticalGrid = None, + backend=settings.backend, + ): self._grid = grid + self._vertical = vertical_grid self._providers: dict[str, "FieldProvider"] = {} self._allocator = gtx.constructors.zeros.partial(allocator=backend) @@ -259,8 +283,9 @@ def validate(self): return self._grid is not None @builder.builder - def with_grid(self, grid: base_grid.BaseGrid): + def with_grid(self, grid: base_grid.BaseGrid, vertical_grid: v_grid.VerticalGrid): self._grid = grid + self._vertical = vertical_grid @builder.builder def with_allocator(self, backend=settings.backend): @@ -270,6 +295,10 @@ def with_allocator(self, backend=settings.backend): def grid(self): return self._grid + @property + def vertical_grid(self): + return self._vertical + @property def allocator(self): return self._allocator diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index de13792a9..8a980c233 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -11,7 +11,7 @@ import icon4py.model.common.test_utils.helpers as helpers from icon4py.model.common import dimension as dims, exceptions -from icon4py.model.common.grid.horizontal import HorizontalMarkerIndex +from icon4py.model.common.grid import horizontal as h_grid, vertical as v_grid from icon4py.model.common.io import cf_utils from icon4py.model.common.metrics import metric_fields as mf from icon4py.model.common.metrics.compute_wgtfacq import ( @@ -22,15 +22,24 @@ from icon4py.model.common.states import factory +cell_domain = h_grid.domain(dims.CellDim) +full_level = v_grid.domain(dims.KDim) +interface_level = v_grid.domain(dims.KHalfDim) + + @pytest.mark.datatest def test_factory_check_dependencies_on_register(icon_grid, backend): fields_factory = factory.FieldsFactory(icon_grid, backend) provider = factory.ProgramFieldProvider( func=mf.compute_z_mc, - domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + domain={ + dims.CellDim: (cell_domain(h_grid.Zone.LOCAL), cell_domain(h_grid.Zone.END)), + dims.KDim: (full_level(v_grid.Zone.TOP), full_level(v_grid.Zone.BOTTOM)), + }, fields={"z_mc": "height"}, deps={"z_ifc": "height_on_interface_levels"}, ) + with pytest.raises(ValueError) as e: fields_factory.register_provider(provider) assert e.value.match("'height_on_interface_levels' not found") @@ -51,17 +60,24 @@ def test_factory_raise_error_if_no_grid_is_set(metrics_savepoint): @pytest.mark.datatest -def test_factory_returns_field(metrics_savepoint, icon_grid, backend): +def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): z_ifc = metrics_savepoint.z_ifc() - k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) + grid = grid_savepoint.construct_icon_grid(on_gpu=False) # TODO: determine from backend + num_levels = grid_savepoint.num(dims.KDim) + vertical = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(num_levels=num_levels), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), + ) + k_index = gtx.as_field((dims.KDim,), xp.arange(num_levels + 1, dtype=gtx.int32)) pre_computed_fields = factory.PrecomputedFieldsProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) fields_factory = factory.FieldsFactory() fields_factory.register_provider(pre_computed_fields) - fields_factory.with_grid(icon_grid).with_allocator(backend) + fields_factory.with_grid(grid, vertical).with_allocator(backend) field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) - assert field.ndarray.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) + assert field.ndarray.shape == (grid.num_cells, num_levels + 1) meta = fields_factory.get("height_on_interface_levels", factory.RetrievalType.METADATA) assert meta["standard_name"] == "height_on_interface_levels" assert meta["dims"] == ( @@ -70,18 +86,31 @@ def test_factory_returns_field(metrics_savepoint, icon_grid, backend): ) assert meta["units"] == "m" data_array = fields_factory.get("height_on_interface_levels", factory.RetrievalType.DATA_ARRAY) - assert data_array.data.shape == (icon_grid.num_cells, icon_grid.num_levels + 1) + assert data_array.data.shape == (grid.num_cells, num_levels + 1) assert data_array.data.dtype == xp.float64 for key in ("dims", "standard_name", "units", "icon_var_name"): assert key in data_array.attrs.keys() @pytest.mark.datatest -def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): - fields_factory = factory.FieldsFactory(icon_grid, backend) - k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) +def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): + horizontal_grid = grid_savepoint.construct_icon_grid( + on_gpu=False + ) # TODO: determine from backend + num_levels = grid_savepoint.num(dims.KDim) + vct_a = grid_savepoint.vct_a() + vct_b = grid_savepoint.vct_b() + vertical_grid = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(num_levels=num_levels), vct_a, vct_b + ) + + fields_factory = factory.FieldsFactory() + k_index = gtx.as_field((dims.KDim,), xp.arange(num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() + local_cell_domain = cell_domain(h_grid.Zone.LOCAL) + end_cell_domain = cell_domain(h_grid.Zone.END) + pre_computed_fields = factory.PrecomputedFieldsProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) @@ -92,10 +121,10 @@ def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): func=mf.compute_z_mc, domain={ dims.CellDim: ( - HorizontalMarkerIndex.local(dims.CellDim), - HorizontalMarkerIndex.end(dims.CellDim), + local_cell_domain, + end_cell_domain, ), - dims.KDim: (0, icon_grid.num_levels), + dims.KDim: (full_level(v_grid.Zone.TOP), full_level(v_grid.Zone.BOTTOM)), }, fields={"z_mc": "height"}, deps={"z_ifc": "height_on_interface_levels"}, @@ -105,10 +134,10 @@ def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): func=mf.compute_ddqz_z_half, domain={ dims.CellDim: ( - HorizontalMarkerIndex.local(dims.CellDim), - HorizontalMarkerIndex.end(dims.CellDim), + local_cell_domain, + end_cell_domain, ), - dims.KHalfDim: (0, icon_grid.num_levels + 1), + dims.KHalfDim: (interface_level(v_grid.Zone.TOP), interface_level(v_grid.Zone.BOTTOM)), }, fields={"ddqz_z_half": "functional_determinant_of_metrics_on_interface_levels"}, deps={ @@ -116,9 +145,10 @@ def test_field_provider_for_program(icon_grid, metrics_savepoint, backend): "z_mc": "height", "k": cf_utils.INTERFACE_LEVEL_STANDARD_NAME, }, - params={"nlev": icon_grid.num_levels}, + params={"nlev": vertical_grid.num_levels}, ) fields_factory.register_provider(functional_determinant_provider) + fields_factory.with_grid(horizontal_grid, vertical_grid).with_allocator(backend) data = fields_factory.get( "functional_determinant_of_metrics_on_interface_levels", type_=factory.RetrievalType.FIELD ) @@ -144,8 +174,8 @@ def test_field_provider_for_numpy_function( compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( func=func, domain={ - dims.CellDim: (0, HorizontalMarkerIndex.end(dims.CellDim)), - dims.KDim: (0, icon_grid.num_levels), + dims.CellDim: (cell_domain(h_grid.Zone.LOCAL), cell_domain(h_grid.Zone.END)), + dims.KDim: (interface_level(v_grid.Zone.TOP), interface_level(v_grid.Zone.BOTTOM)), }, fields=["weighting_factor_for_quadratic_interpolation_to_cell_surface"], deps=deps, @@ -173,11 +203,12 @@ def test_field_provider_for_numpy_function_with_offsets( { "height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index, - "c_lin_e": c_lin_e, + "cell_to_edge_interpolation_coefficient": c_lin_e, } ) fields_factory.register_provider(pre_computed_fields) func = compute_wgtfacq_c_dsl + # TODO (magdalena): need to fix this for parameters params = {"nlev": icon_grid.num_levels} compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( func=func, @@ -189,7 +220,7 @@ def test_field_provider_for_numpy_function_with_offsets( deps = { "z_ifc": "height_on_interface_levels", "wgtfacq_c_dsl": "weighting_factor_for_quadratic_interpolation_to_cell_surface", - "c_lin_e": "c_lin_e", + "c_lin_e": "cell_to_edge_interpolation_coefficient", } fields_factory.register_provider(compute_wgtfacq_c_provider) wgtfacq_e_provider = factory.NumpyFieldsProvider( From e1ec5312f306263a0176e2928be3eef65914f6e7 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Tue, 10 Sep 2024 11:09:50 +0200 Subject: [PATCH 26/37] add docstring to Providers --- .../src/icon4py/model/common/states/factory.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index d17735e59..48428ead2 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -116,6 +116,12 @@ class ProgramFieldProvider(FieldProvider): """ Computes a field defined by a GT4Py Program. + Args: + func: GT4Py Program that computes the fields + domain: the compute domain used for the stencil computation + fields: dict[str, str], fields produced by this stencils: the key is the variable name of the out arguments used in the program and the value the name the field is registered under and declared in the metadata. + deps: dict[str, str], input fields used for computing this stencil: the key is the variable name used in the program and the value the name of the field it depends on. + params: scalar parameters used in the program """ def __init__( @@ -195,6 +201,17 @@ def fields(self) -> Iterable[str]: class NumpyFieldsProvider(FieldProvider): + """ + Computes a field defined by a numpy function. + + Args: + func: numpy function that computes the fields + domain: the compute domain used for the stencil computation + fields: Seq[str] names under which the results fo the function will be registered + deps: dict[str, str] input fields used for computing this stencil: the key is the variable name used in the program and the value the name of the field it depends on. + params: scalar arguments for the function + """ + def __init__( self, func: Callable, From 62c21ae50d050ef6e547343a14d798b4a8204d6d Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 13 Sep 2024 14:22:45 +0200 Subject: [PATCH 27/37] separate vertical and horizontal connectivities --- .../src/icon4py/model/common/grid/vertical.py | 9 +-- .../icon4py/model/common/states/factory.py | 26 +++++++- .../icon4py/model/common/states/metadata.py | 37 +++++++++++ .../common/tests/states_test/test_factory.py | 65 ++++++++++++++++++- 4 files changed, 125 insertions(+), 12 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/vertical.py b/model/common/src/icon4py/model/common/grid/vertical.py index 9e4b37662..a9750306b 100644 --- a/model/common/src/icon4py/model/common/grid/vertical.py +++ b/model/common/src/icon4py/model/common/grid/vertical.py @@ -16,6 +16,7 @@ import gt4py.next as gtx +import icon4py.model.common.states.metadata as data from icon4py.model.common import dimension as dims, field_type_aliases as fa from icon4py.model.common.settings import xp @@ -157,13 +158,7 @@ def __str__(self): @property def metadata_interface_physical_height(self): - return dict( - standard_name="model_interface_height", - long_name="height value of half levels without topography", - units="m", - positive="up", - icon_var_name="vct_a", - ) + return data.attrs["model_interface_height"] @property def num_levels(self): diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 48428ead2..2efc52075 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -163,6 +163,19 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: for k in self._fields.keys() } + # TODO (@halungge) this can be simplified when completely disentangling vertical and horizontal grid. + # the IconGrid should then only contain horizontal connectivities and no longer any Koff which should be moved to the VerticalGrid + def _get_offset_providers(self, grid:icon_grid.IconGrid, vertical_grid:v_grid.VerticalGrid) -> dict[str, gtx.FieldOffset]: + offset_providers = {} + for dim in self._compute_domain.keys(): + if dim.kind == gtx.DimensionKind.HORIZONTAL: + horizontal_offsets = {k:v for k , v in grid.offset_providers.items() if isinstance(v, gtx.NeighborTableOffsetProvider) and v.origin_axis.kind == gtx.DimensionKind.HORIZONTAL} + offset_providers.update(horizontal_offsets) + if dim.kind == gtx.DimensionKind.VERTICAL: + vertical_offsets = {k:v for k , v in grid.offset_providers.items() if isinstance(v, gtx.Dimension) and v.kind == gtx.DimensionKind.VERTICAL} + offset_providers.update(vertical_offsets) + return offset_providers + def _domain_args( self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid ) -> dict[str : gtx.int32]: @@ -193,13 +206,16 @@ def evaluate(self, factory: "FieldsFactory"): deps.update(self._params) deps.update({k: self._fields[v] for k, v in self._output.items()}) dims = self._domain_args(factory.grid, factory.vertical_grid) + offset_providers =self._get_offset_providers(factory.grid, factory.vertical_grid) deps.update(dims) - self._func(**deps, offset_provider=factory.grid.offset_providers) + self._func.with_backend(factory._backend)(**deps, offset_provider=offset_providers) def fields(self) -> Iterable[str]: return self._output.values() + + class NumpyFieldsProvider(FieldProvider): """ Computes a field defined by a numpy function. @@ -294,6 +310,7 @@ def __init__( self._grid = grid self._vertical = vertical_grid self._providers: dict[str, "FieldProvider"] = {} + self._backend = backend self._allocator = gtx.constructors.zeros.partial(allocator=backend) def validate(self): @@ -305,9 +322,14 @@ def with_grid(self, grid: base_grid.BaseGrid, vertical_grid: v_grid.VerticalGrid self._vertical = vertical_grid @builder.builder - def with_allocator(self, backend=settings.backend): + def with_backend(self, backend=settings.backend): + self._backend = backend self._allocator = gtx.constructors.zeros.partial(allocator=backend) + @property + def backend(self): + return self._backend + @property def grid(self): return self._grid diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 30df9e9b9..052b83302 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -77,4 +77,41 @@ icon_var_name="c_lin_e", long_name="coefficients for cell to edge interpolation", ), + "scaling_factor_for_3d_divergence_damping": dict( + standard_name="scaling_factor_for_3d_divergence_damping", + units="", + dims=(dims.KDim), + dtype=ta.wpfloat, + icon_var_name="scalfac_dd3d", + long_name="Scaling factor for 3D divergence damping terms", + ), + "model_interface_height": + dict( + standard_name="model_interface_height", + long_name="height value of half levels without topography", + units="m", + dims = (dims.KHalfDim,), + dtype=ta.wpfloat, + positive="up", + icon_var_name="vct_a", + ), + "nudging_coefficient_on_edges": + dict( + standard_name="nudging_coefficient_on_edges", + long_name="nudging coefficients on edges", + units="", + dtype = ta.wpfloat, + dims = (dims.EdgeDim,), + icon_var_name="nudgecoeff_e", + ), + "refin_e_ctrl": + dict( + standard_name="refin_e_ctrl", + long_name="grid refinement control on edgeds", + units="", + dtype = int, + dims = (dims.EdgeDim,), + icon_var_name="refin_e_ctrl", + ) + } diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 8a980c233..b76cd9126 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -8,12 +8,13 @@ import gt4py.next as gtx import pytest +from common.tests.metric_tests.test_metric_fields import edge_domain import icon4py.model.common.test_utils.helpers as helpers from icon4py.model.common import dimension as dims, exceptions from icon4py.model.common.grid import horizontal as h_grid, vertical as v_grid from icon4py.model.common.io import cf_utils -from icon4py.model.common.metrics import metric_fields as mf +from icon4py.model.common.metrics import compute_nudgecoeffs, metric_fields as mf from icon4py.model.common.metrics.compute_wgtfacq import ( compute_wgtfacq_c_dsl, compute_wgtfacq_e_dsl, @@ -75,7 +76,7 @@ def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): ) fields_factory = factory.FieldsFactory() fields_factory.register_provider(pre_computed_fields) - fields_factory.with_grid(grid, vertical).with_allocator(backend) + fields_factory.with_grid(grid, vertical).with_backend(backend) field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) assert field.ndarray.shape == (grid.num_cells, num_levels + 1) meta = fields_factory.get("height_on_interface_levels", factory.RetrievalType.METADATA) @@ -148,7 +149,7 @@ def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): params={"nlev": vertical_grid.num_levels}, ) fields_factory.register_provider(functional_determinant_provider) - fields_factory.with_grid(horizontal_grid, vertical_grid).with_allocator(backend) + fields_factory.with_grid(horizontal_grid, vertical_grid).with_backend(backend) data = fields_factory.get( "functional_determinant_of_metrics_on_interface_levels", type_=factory.RetrievalType.FIELD ) @@ -238,3 +239,61 @@ def test_field_provider_for_numpy_function_with_offsets( ) assert helpers.dallclose(wgtfacq_e.asnumpy(), wgtfacq_e_ref.asnumpy()) + + +def test_factory_for_k_only_field(icon_grid, metrics_savepoint, grid_savepoint, backend): + fields_factory = factory.FieldsFactory() + vct_a = grid_savepoint.vct_a() + divdamp_trans_start = 12500.0 + divdamp_trans_end = 17500.0 + divdamp_type = 3 + pre_computed_fields = factory.PrecomputedFieldsProvider({"model_interface_height": vct_a}) + fields_factory.register_provider(pre_computed_fields) + vertical_grid = v_grid.VerticalGrid(v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), + grid_savepoint.vct_a(), grid_savepoint.vct_b()) + provider = factory.ProgramFieldProvider( + func=mf.compute_scalfac_dd3d, + domain={ + dims.KDim: (full_level(v_grid.Zone.TOP), full_level(v_grid.Zone.BOTTOM)), + }, + deps={"vct_a": "model_interface_height"}, + fields={"scalfac_dd3d": "scaling_factor_for_3d_divergence_damping"}, + params={ + "divdamp_trans_start": divdamp_trans_start, + "divdamp_trans_end": divdamp_trans_end, + "divdamp_type": divdamp_type, + }, + + ) + fields_factory.register_provider(provider) + fields_factory.with_grid(icon_grid, vertical_grid).with_backend(backend) + helpers.dallclose(fields_factory.get("scaling_factor_for_3d_divergence_damping").asnumpy(), + metrics_savepoint.scalfac_dd3d().asnumpy()) + + +def test_horizontal_only_field(icon_grid, interpolation_savepoint, grid_savepoint, backend): + fields_factory = factory.FieldsFactory() + refin_ctl = grid_savepoint.refin_ctrl(dims.EdgeDim) + pre_computed_fields = factory.PrecomputedFieldsProvider({"refin_e_ctrl": refin_ctl}) + fields_factory.register_provider(pre_computed_fields) + vertical_grid = v_grid.VerticalGrid(v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), + grid_savepoint.vct_a(), grid_savepoint.vct_b()) + provider = factory.ProgramFieldProvider( + func=compute_nudgecoeffs.compute_nudgecoeffs, + domain={ + dims.EdgeDim: (edge_domain(h_grid.Zone.NUDGING_LEVEL_2), edge_domain(h_grid.Zone.LOCAL)), + }, + deps={"refin_ctrl": "refin_e_ctrl"}, + fields={"nudgecoeffs_e": "nudging_coefficient_on_edges"}, + params={ + "grf_nudge_start_e": 10, + "nudge_max_coeffs": 0.375, + "nudge_efold_width": 2.0, + "nudge_zone_width": 10 + }, + + ) + fields_factory.register_provider(provider) + fields_factory.with_grid(icon_grid, vertical_grid).with_backend(backend) + helpers.dallclose(fields_factory.get("nudging_coefficient_on_edges").asnumpy(), + interpolation_savepoint.nudgecoeff_e().asnumpy()) From f98f8dc49e9878f76932bccd3b6ff94908ba3ce0 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 13 Sep 2024 14:28:18 +0200 Subject: [PATCH 28/37] pre-commit --- .../icon4py/model/common/states/factory.py | 27 ++++++---- .../icon4py/model/common/states/metadata.py | 54 +++++++++---------- .../common/tests/states_test/test_factory.py | 35 +++++++----- 3 files changed, 66 insertions(+), 50 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 2efc52075..d50b04a2b 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -163,19 +163,30 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: for k in self._fields.keys() } - # TODO (@halungge) this can be simplified when completely disentangling vertical and horizontal grid. + # TODO (@halungge) this can be simplified when completely disentangling vertical and horizontal grid. # the IconGrid should then only contain horizontal connectivities and no longer any Koff which should be moved to the VerticalGrid - def _get_offset_providers(self, grid:icon_grid.IconGrid, vertical_grid:v_grid.VerticalGrid) -> dict[str, gtx.FieldOffset]: + def _get_offset_providers( + self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid + ) -> dict[str, gtx.FieldOffset]: offset_providers = {} for dim in self._compute_domain.keys(): if dim.kind == gtx.DimensionKind.HORIZONTAL: - horizontal_offsets = {k:v for k , v in grid.offset_providers.items() if isinstance(v, gtx.NeighborTableOffsetProvider) and v.origin_axis.kind == gtx.DimensionKind.HORIZONTAL} + horizontal_offsets = { + k: v + for k, v in grid.offset_providers.items() + if isinstance(v, gtx.NeighborTableOffsetProvider) + and v.origin_axis.kind == gtx.DimensionKind.HORIZONTAL + } offset_providers.update(horizontal_offsets) if dim.kind == gtx.DimensionKind.VERTICAL: - vertical_offsets = {k:v for k , v in grid.offset_providers.items() if isinstance(v, gtx.Dimension) and v.kind == gtx.DimensionKind.VERTICAL} + vertical_offsets = { + k: v + for k, v in grid.offset_providers.items() + if isinstance(v, gtx.Dimension) and v.kind == gtx.DimensionKind.VERTICAL + } offset_providers.update(vertical_offsets) return offset_providers - + def _domain_args( self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid ) -> dict[str : gtx.int32]: @@ -206,7 +217,7 @@ def evaluate(self, factory: "FieldsFactory"): deps.update(self._params) deps.update({k: self._fields[v] for k, v in self._output.items()}) dims = self._domain_args(factory.grid, factory.vertical_grid) - offset_providers =self._get_offset_providers(factory.grid, factory.vertical_grid) + offset_providers = self._get_offset_providers(factory.grid, factory.vertical_grid) deps.update(dims) self._func.with_backend(factory._backend)(**deps, offset_provider=offset_providers) @@ -214,8 +225,6 @@ def fields(self) -> Iterable[str]: return self._output.values() - - class NumpyFieldsProvider(FieldProvider): """ Computes a field defined by a numpy function. @@ -329,7 +338,7 @@ def with_backend(self, backend=settings.backend): @property def backend(self): return self._backend - + @property def grid(self): return self._grid diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 052b83302..ab0fd1726 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -85,33 +85,29 @@ icon_var_name="scalfac_dd3d", long_name="Scaling factor for 3D divergence damping terms", ), - "model_interface_height": - dict( - standard_name="model_interface_height", - long_name="height value of half levels without topography", - units="m", - dims = (dims.KHalfDim,), - dtype=ta.wpfloat, - positive="up", - icon_var_name="vct_a", - ), - "nudging_coefficient_on_edges": - dict( - standard_name="nudging_coefficient_on_edges", - long_name="nudging coefficients on edges", - units="", - dtype = ta.wpfloat, - dims = (dims.EdgeDim,), - icon_var_name="nudgecoeff_e", - ), - "refin_e_ctrl": - dict( - standard_name="refin_e_ctrl", - long_name="grid refinement control on edgeds", - units="", - dtype = int, - dims = (dims.EdgeDim,), - icon_var_name="refin_e_ctrl", - ) - + "model_interface_height": dict( + standard_name="model_interface_height", + long_name="height value of half levels without topography", + units="m", + dims=(dims.KHalfDim,), + dtype=ta.wpfloat, + positive="up", + icon_var_name="vct_a", + ), + "nudging_coefficient_on_edges": dict( + standard_name="nudging_coefficient_on_edges", + long_name="nudging coefficients on edges", + units="", + dtype=ta.wpfloat, + dims=(dims.EdgeDim,), + icon_var_name="nudgecoeff_e", + ), + "refin_e_ctrl": dict( + standard_name="refin_e_ctrl", + long_name="grid refinement control on edgeds", + units="", + dtype=int, + dims=(dims.EdgeDim,), + icon_var_name="refin_e_ctrl", + ), } diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index b76cd9126..72345c602 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -249,8 +249,11 @@ def test_factory_for_k_only_field(icon_grid, metrics_savepoint, grid_savepoint, divdamp_type = 3 pre_computed_fields = factory.PrecomputedFieldsProvider({"model_interface_height": vct_a}) fields_factory.register_provider(pre_computed_fields) - vertical_grid = v_grid.VerticalGrid(v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), - grid_savepoint.vct_a(), grid_savepoint.vct_b()) + vertical_grid = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), + ) provider = factory.ProgramFieldProvider( func=mf.compute_scalfac_dd3d, domain={ @@ -263,12 +266,13 @@ def test_factory_for_k_only_field(icon_grid, metrics_savepoint, grid_savepoint, "divdamp_trans_end": divdamp_trans_end, "divdamp_type": divdamp_type, }, - ) fields_factory.register_provider(provider) fields_factory.with_grid(icon_grid, vertical_grid).with_backend(backend) - helpers.dallclose(fields_factory.get("scaling_factor_for_3d_divergence_damping").asnumpy(), - metrics_savepoint.scalfac_dd3d().asnumpy()) + helpers.dallclose( + fields_factory.get("scaling_factor_for_3d_divergence_damping").asnumpy(), + metrics_savepoint.scalfac_dd3d().asnumpy(), + ) def test_horizontal_only_field(icon_grid, interpolation_savepoint, grid_savepoint, backend): @@ -276,12 +280,18 @@ def test_horizontal_only_field(icon_grid, interpolation_savepoint, grid_savepoin refin_ctl = grid_savepoint.refin_ctrl(dims.EdgeDim) pre_computed_fields = factory.PrecomputedFieldsProvider({"refin_e_ctrl": refin_ctl}) fields_factory.register_provider(pre_computed_fields) - vertical_grid = v_grid.VerticalGrid(v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), - grid_savepoint.vct_a(), grid_savepoint.vct_b()) + vertical_grid = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), + ) provider = factory.ProgramFieldProvider( func=compute_nudgecoeffs.compute_nudgecoeffs, domain={ - dims.EdgeDim: (edge_domain(h_grid.Zone.NUDGING_LEVEL_2), edge_domain(h_grid.Zone.LOCAL)), + dims.EdgeDim: ( + edge_domain(h_grid.Zone.NUDGING_LEVEL_2), + edge_domain(h_grid.Zone.LOCAL), + ), }, deps={"refin_ctrl": "refin_e_ctrl"}, fields={"nudgecoeffs_e": "nudging_coefficient_on_edges"}, @@ -289,11 +299,12 @@ def test_horizontal_only_field(icon_grid, interpolation_savepoint, grid_savepoin "grf_nudge_start_e": 10, "nudge_max_coeffs": 0.375, "nudge_efold_width": 2.0, - "nudge_zone_width": 10 + "nudge_zone_width": 10, }, - ) fields_factory.register_provider(provider) fields_factory.with_grid(icon_grid, vertical_grid).with_backend(backend) - helpers.dallclose(fields_factory.get("nudging_coefficient_on_edges").asnumpy(), - interpolation_savepoint.nudgecoeff_e().asnumpy()) + helpers.dallclose( + fields_factory.get("nudging_coefficient_on_edges").asnumpy(), + interpolation_savepoint.nudgecoeff_e().asnumpy(), + ) From e417fd1d755a7e1809ea1a56e93b241b72646130 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 2 Oct 2024 21:33:38 +0200 Subject: [PATCH 29/37] add types for metadata attributes --- .../src/icon4py/model/common/states/metadata.py | 4 +++- .../src/icon4py/model/common/states/model.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index ab0fd1726..2b03954c4 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -5,14 +5,16 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Final import gt4py.next as gtx import icon4py.model.common.io.cf_utils as cf_utils from icon4py.model.common import dimension as dims, type_alias as ta +from icon4py.model.common.states import model -attrs = { +attrs:Final[dict[str, model.FieldMetaData]] = { "functional_determinant_of_metrics_on_interface_levels": dict( standard_name="functional_determinant_of_metrics_on_interface_levels", long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", diff --git a/model/common/src/icon4py/model/common/states/model.py b/model/common/src/icon4py/model/common/states/model.py index 9905eedfe..2c89d70b0 100644 --- a/model/common/src/icon4py/model/common/states/model.py +++ b/model/common/src/icon4py/model/common/states/model.py @@ -9,18 +9,22 @@ import dataclasses import functools -from typing import Protocol, TypedDict, Union, runtime_checkable +from typing import Literal, Protocol, TypedDict, Union, runtime_checkable import gt4py._core.definitions as gt_coredefs import gt4py.next as gtx import gt4py.next.common as gt_common import numpy.typing as np_t +import icon4py.model.common.type_alias as ta -"""Contains type definitions used for the model`s state representation.""" -DimensionT = Union[gtx.Dimension, str] +"""Contains type definitions used for the model`s state representation.""" +DimensionNames = Literal["cell", "edge", "vertex"] +DimensionT = Union[gtx.Dimension, DimensionNames] #TODO use Literal instead of str BufferT = Union[np_t.ArrayLike, gtx.Field] +DTypeT = Union[ta.wpfloat, ta.vpfloat, gtx.int32, gtx.int64, gtx.float32, gtx.float64] + class OptionalMetaData(TypedDict, total=False): @@ -28,8 +32,10 @@ class OptionalMetaData(TypedDict, total=False): long_name: str #: we might not have this one for all fields. But it is useful to have it for tractability with ICON icon_var_name: str - # TODO (@halungge) dims should probably be required + # TODO (@halungge) dims should probably be required? dims: tuple[DimensionT, ...] + dtype: Union[ta.wpfloat, ta.vpfloat, gtx.int32, gtx.int64, gtx.float32, gtx.float64] + class RequiredMetaData(TypedDict, total=True): From 75bda6d51beb89af92ea571632fd0c64c8afa5ab Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 2 Oct 2024 21:34:12 +0200 Subject: [PATCH 30/37] fix int32 issues (ad hoc fix) --- model/common/src/icon4py/model/common/grid/icon.py | 12 ++++++------ .../common/src/icon4py/model/common/grid/vertical.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/icon.py b/model/common/src/icon4py/model/common/grid/icon.py index 8b1549687..7334c3bf1 100644 --- a/model/common/src/icon4py/model/common/grid/icon.py +++ b/model/common/src/icon4py/model/common/grid/icon.py @@ -168,7 +168,7 @@ def n_shift(self): def lvert_nest(self): return True if self.config.lvertnest else False - def start_index(self, domain: h_grid.Domain): + def start_index(self, domain: h_grid.Domain)->gtx.int32: """ Use to specify lower end of domains of a field for field_operators. @@ -177,10 +177,10 @@ def start_index(self, domain: h_grid.Domain): """ if domain.local: # special treatment because this value is not set properly in the underlying data. - return 0 - return self._start_indices[domain.dim][domain()].item() + return gtx.int32(0) + return gtx.int32(self._start_indices[domain.dim][domain()]) - def end_index(self, domain: h_grid.Domain): + def end_index(self, domain: h_grid.Domain)->gtx.int32: """ Use to specify upper end of domains of a field for field_operators. @@ -189,5 +189,5 @@ def end_index(self, domain: h_grid.Domain): """ if domain.zone == h_grid.Zone.INTERIOR and not self.limited_area: # special treatment because this value is not set properly in the underlying data, for a global grid - return self.size[domain.dim] - return self._end_indices[domain.dim][domain()].item() + return gtx.int32(self.size[domain.dim]) + return gtx.int32(self._end_indices[domain.dim][domain()].item()) diff --git a/model/common/src/icon4py/model/common/grid/vertical.py b/model/common/src/icon4py/model/common/grid/vertical.py index 30ae233e7..d450d019c 100644 --- a/model/common/src/icon4py/model/common/grid/vertical.py +++ b/model/common/src/icon4py/model/common/grid/vertical.py @@ -178,7 +178,7 @@ def num_levels(self): def index(self, domain: Domain) -> gtx.int32: match domain.marker: case Zone.TOP: - index = gtx.int32(0) + index = 0 case Zone.BOTTOM: index = self._bottom_level(domain) case Zone.MOIST: @@ -194,10 +194,10 @@ def index(self, domain: Domain) -> gtx.int32: assert ( 0 <= index <= self._bottom_level(domain) ), f"vertical index {index} outside of grid levels for {domain.dim}" - return index + return gtx.int32(index) - def _bottom_level(self, domain: Domain) -> gtx.int32: - return gtx.int32(self.size(domain.dim)) + def _bottom_level(self, domain: Domain) -> int: + return self.size(domain.dim) @property def interface_physical_height(self) -> fa.KField[float]: From f978d729615a91b3f77c3bda66b11b35bb28448f Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Wed, 2 Oct 2024 21:35:11 +0200 Subject: [PATCH 31/37] rename providers, fixes in FieldProvider Protocol --- .../model/common/metrics/metrics_factory.py | 2 +- .../icon4py/model/common/states/factory.py | 151 +++++++++++------- .../common/tests/states_test/test_factory.py | 117 +++++++++----- 3 files changed, 171 insertions(+), 99 deletions(-) diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py index 58a28a0f7..c7cddd629 100644 --- a/model/common/src/icon4py/model/common/metrics/metrics_factory.py +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -47,7 +47,7 @@ grid = grid_savepoint.global_grid_params fields_factory.register_provider( - factory.PrecomputedFieldsProvider( + factory.PrecomputedFieldProvider( { "height_on_interface_levels": interface_model_height, "cell_to_edge_interpolation_coefficient": c_lin_e, diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index d50b04a2b..23e55545e 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -11,8 +11,9 @@ import functools import inspect from typing import ( + Any, Callable, - Iterable, + Mapping, Optional, Protocol, Sequence, @@ -33,23 +34,23 @@ vertical as v_grid, ) from icon4py.model.common.settings import xp -from icon4py.model.common.states import metadata as metadata, utils as state_utils +from icon4py.model.common.states import metadata as metadata, model, utils as state_utils from icon4py.model.common.utils import builder DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain) -class RetrievalType(enum.IntEnum): - FIELD = (0,) - DATA_ARRAY = (1,) - METADATA = (2,) +class RetrievalType(enum.Enum): + FIELD = 0 + DATA_ARRAY = 1 + METADATA = 2 -def valid(func): +def check_setup(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): - if not self.validate(): + if not self.is_setup(): raise exceptions.IncompleteSetupError( "Factory not fully instantiated, missing grid or allocator" ) @@ -67,36 +68,36 @@ class FieldProvider(Protocol): A FieldProvider is a callable that has three methods (except for __call__): - evaluate (abstract) : computes the fields based on the instructions of the concrete implementation - - fields(): returns the list of field names provided by the provider - - dependencies(): returns a list of field_names that the fields provided by this provider depend on. + - fields: Mapping of a field_name to list of field names provided by the provider + - dependencies: returns a list of field_names that the fields provided by this provider depend on. - evaluate must be implemented, for the others default implementations are provided. """ - - def __init__(self, func: Callable): - self._func = func - self._fields: dict[str, Optional[state_utils.FieldType]] = {} - self._dependencies: dict[str, str] = {} + @abc.abstractmethod def evaluate(self, factory: "FieldsFactory") -> None: - pass + ... def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: - if field_name not in self.fields(): - raise ValueError(f"Field {field_name} not provided by f{self._func.__name__}.") - if any([f is None for f in self._fields.values()]): + if field_name not in self.fields: + raise ValueError(f"Field {field_name} not provided by f{self.func.__name__}.") + if any([f is None for f in self.fields.values()]): self.evaluate(factory) - return self._fields[field_name] + return self.fields[field_name] - def dependencies(self) -> Iterable[str]: - return self._dependencies.values() - - def fields(self) -> Iterable[str]: - return self._fields.keys() + @property + def dependencies(self) -> Sequence[str]: + ... + @property + def fields(self) -> Mapping[str, Any]: + ... + + @property + def func(self)->Callable: + ... -class PrecomputedFieldsProvider(FieldProvider): +class PrecomputedFieldProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" def __init__(self, fields: dict[str, state_utils.FieldType]): @@ -105,11 +106,22 @@ def __init__(self, fields: dict[str, state_utils.FieldType]): def evaluate(self, factory: "FieldsFactory") -> None: pass + @property def dependencies(self) -> Sequence[str]: return [] def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: - return self._fields[field_name] + return self.fields[field_name] + + # TODO signature should this only return the field_names produced by this provider? + @property + def fields(self) -> Mapping[str, Any]: + return self._fields + + + @property + def func(self) -> Callable: + return lambda : self.fields class ProgramFieldProvider(FieldProvider): @@ -119,7 +131,7 @@ class ProgramFieldProvider(FieldProvider): Args: func: GT4Py Program that computes the fields domain: the compute domain used for the stencil computation - fields: dict[str, str], fields produced by this stencils: the key is the variable name of the out arguments used in the program and the value the name the field is registered under and declared in the metadata. + fields: dict[str, str], fields computed by this stencil: the key is the variable name of the out arguments used in the program and the value the name the field is registered under and declared in the metadata. deps: dict[str, str], input fields used for computing this stencil: the key is the variable name used in the program and the value the name of the field it depends on. params: scalar parameters used in the program """ @@ -127,8 +139,8 @@ class ProgramFieldProvider(FieldProvider): def __init__( self, func: gtx_decorator.Program, - domain: dict[gtx.Dimension : tuple[DomainType, DomainType]], - fields: dict[str:str], + domain: dict[gtx.Dimension, tuple[DomainType, DomainType]], + fields: dict[str, str], deps: dict[str, str], params: Optional[dict[str, state_utils.Scalar]] = None, ): @@ -221,11 +233,19 @@ def evaluate(self, factory: "FieldsFactory"): deps.update(dims) self._func.with_backend(factory._backend)(**deps, offset_provider=offset_providers) - def fields(self) -> Iterable[str]: - return self._output.values() - + @property + def fields(self) -> Mapping[str, Any]: + return self._fields + + @property + def func(self) ->Callable: + return self._func + @property + def dependencies(self) -> Sequence[str]: + return list(self._dependencies.values()) + -class NumpyFieldsProvider(FieldProvider): +class NumpyFieldProvider(FieldProvider): """ Computes a field defined by a numpy function. @@ -266,7 +286,7 @@ def evaluate(self, factory: "FieldsFactory") -> None: results = (results,) if isinstance(results, xp.ndarray) else results self._fields = { - k: gtx.as_field(tuple(self._dims), results[i]) for i, k in enumerate(self.fields()) + k: gtx.as_field(tuple(self._dims), results[i]) for i, k in enumerate(self.fields) } def _validate_dependencies(self): @@ -289,6 +309,17 @@ def _validate_dependencies(self): f"exist or has the wrong type: {type(param_value)}." ) + @property + def func(self) ->Callable: + return self._func + + @property + def dependencies(self) -> Sequence[str]: + return list(self._dependencies.values()) + + @property + def fields(self) -> Mapping[str, Any]: + return self._fields def _check( parameter_definition: inspect.Parameter, @@ -304,26 +335,30 @@ def _check( class FieldsFactory: - """ - Factory for fields. - - Lazily compute fields and cache them. - """ - def __init__( self, + metadata: dict[str, model.FieldMetaData], grid: icon_grid.IconGrid = None, vertical_grid: v_grid.VerticalGrid = None, - backend=settings.backend, + backend=None, + + ): + self._metadata = metadata self._grid = grid self._vertical = vertical_grid self._providers: dict[str, "FieldProvider"] = {} self._backend = backend self._allocator = gtx.constructors.zeros.partial(allocator=backend) - def validate(self): - return self._grid is not None + """ + Factory for fields. + + Lazily compute fields and cache them. + """ + + def is_setup(self): + return self._grid is not None and self.backend is not None @builder.builder def with_grid(self, grid: base_grid.BaseGrid, vertical_grid: v_grid.VerticalGrid): @@ -352,25 +387,25 @@ def allocator(self): return self._allocator def register_provider(self, provider: FieldProvider): - for dependency in provider.dependencies(): + for dependency in provider.dependencies: if dependency not in self._providers.keys(): raise ValueError(f"Dependency '{dependency}' not found in registered providers") - for field in provider.fields(): + for field in provider.fields: self._providers[field] = provider - @valid + @check_setup def get( self, field_name: str, type_: RetrievalType = RetrievalType.FIELD - ) -> Union[state_utils.FieldType, xa.DataArray, dict]: - if field_name not in metadata.attrs: - raise ValueError(f"Field {field_name} not found in metric fields") + ) -> Union[state_utils.FieldType, xa.DataArray, model.FieldMetaData]: + if field_name not in self._providers: + raise ValueError(f"Field {field_name} not provided by the factory") if type_ == RetrievalType.METADATA: - return metadata.attrs[field_name] - if type_ == RetrievalType.FIELD: - return self._providers[field_name](field_name, self) - if type_ == RetrievalType.DATA_ARRAY: - return state_utils.to_data_array( - self._providers[field_name](field_name, self), metadata.attrs[field_name] - ) + return self._metadata[field_name] + if type_ in (RetrievalType.FIELD,RetrievalType.DATA_ARRAY): + provider = self._providers[field_name] + buffer = provider(field_name, self) + return buffer if type_ == RetrievalType.FIELD else state_utils.to_data_array(buffer, self._metadata[field_name]) + + raise ValueError(f"Invalid retrieval type {type_}") diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 72345c602..9d88d277c 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -8,7 +8,6 @@ import gt4py.next as gtx import pytest -from common.tests.metric_tests.test_metric_fields import edge_domain import icon4py.model.common.test_utils.helpers as helpers from icon4py.model.common import dimension as dims, exceptions @@ -20,7 +19,7 @@ compute_wgtfacq_e_dsl, ) from icon4py.model.common.settings import xp -from icon4py.model.common.states import factory +from icon4py.model.common.states import factory, metadata cell_domain = h_grid.domain(dims.CellDim) @@ -29,8 +28,16 @@ @pytest.mark.datatest -def test_factory_check_dependencies_on_register(icon_grid, backend): - fields_factory = factory.FieldsFactory(icon_grid, backend) +def test_factory_check_dependencies_on_register(grid_savepoint, backend): + grid = grid_savepoint.construct_icon_grid(False) + vertical = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(num_levels=10), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), + ) + + fields_factory = (factory.FieldsFactory(metadata.attrs).with_grid(grid, vertical) + .with_backend(backend)) provider = factory.ProgramFieldProvider( func=mf.compute_z_mc, domain={ @@ -47,23 +54,42 @@ def test_factory_check_dependencies_on_register(icon_grid, backend): @pytest.mark.datatest -def test_factory_raise_error_if_no_grid_is_set(metrics_savepoint): +def test_factory_raise_error_if_no_grid_is_set(metrics_savepoint, backend): z_ifc = metrics_savepoint.z_ifc() k_index = gtx.as_field((dims.KDim,), xp.arange(1, dtype=gtx.int32)) - pre_computed_fields = factory.PrecomputedFieldsProvider( + pre_computed_fields = factory.PrecomputedFieldProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) - fields_factory = factory.FieldsFactory(grid=None) + fields_factory = factory.FieldsFactory(metadata = metadata.attrs).with_backend(backend) fields_factory.register_provider(pre_computed_fields) with pytest.raises(exceptions.IncompleteSetupError) as e: fields_factory.get("height_on_interface_levels") assert e.value.match("not fully instantiated") +@pytest.mark.datatest +def test_factory_raise_error_if_no_backend_is_set(metrics_savepoint, grid_savepoint): + grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should go away + z_ifc = metrics_savepoint.z_ifc() + k_index = gtx.as_field((dims.KDim,), xp.arange(1, dtype=gtx.int32)) + pre_computed_fields = factory.PrecomputedFieldProvider( + {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} + ) + vertical = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(num_levels=10), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), + ) + fields_factory = factory.FieldsFactory(metadata=metadata.attrs).with_grid(grid, vertical) + fields_factory.register_provider(pre_computed_fields) + with pytest.raises(exceptions.IncompleteSetupError) as e: + fields_factory.get("height_on_interface_levels") + assert e.value.match("not fully instantiated") + @pytest.mark.datatest def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): z_ifc = metrics_savepoint.z_ifc() - grid = grid_savepoint.construct_icon_grid(on_gpu=False) # TODO: determine from backend + grid = grid_savepoint.construct_icon_grid(on_gpu=False) num_levels = grid_savepoint.num(dims.KDim) vertical = v_grid.VerticalGrid( v_grid.VerticalGridConfig(num_levels=num_levels), @@ -71,10 +97,10 @@ def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): grid_savepoint.vct_b(), ) k_index = gtx.as_field((dims.KDim,), xp.arange(num_levels + 1, dtype=gtx.int32)) - pre_computed_fields = factory.PrecomputedFieldsProvider( + pre_computed_fields = factory.PrecomputedFieldProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) - fields_factory = factory.FieldsFactory() + fields_factory = factory.FieldsFactory(metadata=metadata.attrs) fields_factory.register_provider(pre_computed_fields) fields_factory.with_grid(grid, vertical).with_backend(backend) field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) @@ -97,22 +123,20 @@ def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): horizontal_grid = grid_savepoint.construct_icon_grid( on_gpu=False - ) # TODO: determine from backend + ) num_levels = grid_savepoint.num(dims.KDim) - vct_a = grid_savepoint.vct_a() - vct_b = grid_savepoint.vct_b() vertical_grid = v_grid.VerticalGrid( - v_grid.VerticalGridConfig(num_levels=num_levels), vct_a, vct_b + v_grid.VerticalGridConfig(num_levels=num_levels), grid_savepoint.vct_a(), grid_savepoint.vct_b() ) - fields_factory = factory.FieldsFactory() + fields_factory = factory.FieldsFactory(metadata=metadata.attrs) k_index = gtx.as_field((dims.KDim,), xp.arange(num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() local_cell_domain = cell_domain(h_grid.Zone.LOCAL) end_cell_domain = cell_domain(h_grid.Zone.END) - pre_computed_fields = factory.PrecomputedFieldsProvider( + pre_computed_fields = factory.PrecomputedFieldProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) @@ -157,22 +181,29 @@ def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): assert helpers.dallclose(data.ndarray, ref) -def test_field_provider_for_numpy_function( - icon_grid, metrics_savepoint, interpolation_savepoint, backend +def test_field_provider_for_numpy_function(grid_savepoint, + metrics_savepoint, interpolation_savepoint, backend ): - fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) - k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) + grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete + vertical_grid = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(num_levels=grid.num_levels), grid_savepoint.vct_a(), + grid_savepoint.vct_b() + ) + + fields_factory = (factory.FieldsFactory(metadata=metadata.attrs) + .with_grid(grid=grid, vertical_grid=vertical_grid).with_backend(backend)) + k_index = gtx.as_field((dims.KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() wgtfacq_c_ref = metrics_savepoint.wgtfacq_c_dsl() - pre_computed_fields = factory.PrecomputedFieldsProvider( + pre_computed_fields = factory.PrecomputedFieldProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) fields_factory.register_provider(pre_computed_fields) func = compute_wgtfacq_c_dsl deps = {"z_ifc": "height_on_interface_levels"} - params = {"nlev": icon_grid.num_levels} - compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( + params = {"nlev": grid.num_levels} + compute_wgtfacq_c_provider = factory.NumpyFieldProvider( func=func, domain={ dims.CellDim: (cell_domain(h_grid.Zone.LOCAL), cell_domain(h_grid.Zone.END)), @@ -192,15 +223,20 @@ def test_field_provider_for_numpy_function( def test_field_provider_for_numpy_function_with_offsets( - icon_grid, metrics_savepoint, interpolation_savepoint, backend + grid_savepoint, metrics_savepoint, interpolation_savepoint, backend ): - fields_factory = factory.FieldsFactory(grid=icon_grid, backend=backend) - k_index = gtx.as_field((dims.KDim,), xp.arange(icon_grid.num_levels + 1, dtype=gtx.int32)) + grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete + vertical = v_grid.VerticalGrid( + v_grid.VerticalGridConfig(num_levels=grid.num_levels), grid_savepoint.vct_a(), + grid_savepoint.vct_b() + ) + fields_factory = factory.FieldsFactory(metadata=metadata.attrs).with_grid(grid=grid, vertical_grid=vertical).with_backend(backend=backend) + k_index = gtx.as_field((dims.KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() c_lin_e = interpolation_savepoint.c_lin_e() - wgtfacq_e_ref = metrics_savepoint.wgtfacq_e_dsl(icon_grid.num_levels + 1) + wgtfacq_e_ref = metrics_savepoint.wgtfacq_e_dsl(grid.num_levels + 1) - pre_computed_fields = factory.PrecomputedFieldsProvider( + pre_computed_fields = factory.PrecomputedFieldProvider( { "height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index, @@ -210,10 +246,10 @@ def test_field_provider_for_numpy_function_with_offsets( fields_factory.register_provider(pre_computed_fields) func = compute_wgtfacq_c_dsl # TODO (magdalena): need to fix this for parameters - params = {"nlev": icon_grid.num_levels} - compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( + params = {"nlev": grid.num_levels} + compute_wgtfacq_c_provider = factory.NumpyFieldProvider( func=func, - domain={dims.CellDim: (0, icon_grid.num_cells), dims.KDim: (0, icon_grid.num_levels)}, + domain={dims.CellDim: (0, grid.num_cells), dims.KDim: (0, grid.num_levels)}, fields=["weighting_factor_for_quadratic_interpolation_to_cell_surface"], deps={"z_ifc": "height_on_interface_levels"}, params=params, @@ -224,13 +260,13 @@ def test_field_provider_for_numpy_function_with_offsets( "c_lin_e": "cell_to_edge_interpolation_coefficient", } fields_factory.register_provider(compute_wgtfacq_c_provider) - wgtfacq_e_provider = factory.NumpyFieldsProvider( + wgtfacq_e_provider = factory.NumpyFieldProvider( func=compute_wgtfacq_e_dsl, deps=deps, offsets={"e2c": dims.E2CDim}, - domain={dims.EdgeDim: (0, icon_grid.num_edges), dims.KDim: (0, icon_grid.num_levels)}, + domain={dims.EdgeDim: (0, grid.num_edges), dims.KDim: (0, grid.num_levels)}, fields=["weighting_factor_for_quadratic_interpolation_to_edge_center"], - params={"n_edges": icon_grid.num_edges, "nlev": icon_grid.num_levels}, + params={"n_edges": grid.num_edges, "nlev": grid.num_levels}, ) fields_factory.register_provider(wgtfacq_e_provider) @@ -242,12 +278,12 @@ def test_field_provider_for_numpy_function_with_offsets( def test_factory_for_k_only_field(icon_grid, metrics_savepoint, grid_savepoint, backend): - fields_factory = factory.FieldsFactory() + fields_factory = factory.FieldsFactory(metadata=metadata.attrs) vct_a = grid_savepoint.vct_a() divdamp_trans_start = 12500.0 divdamp_trans_end = 17500.0 divdamp_type = 3 - pre_computed_fields = factory.PrecomputedFieldsProvider({"model_interface_height": vct_a}) + pre_computed_fields = factory.PrecomputedFieldProvider({"model_interface_height": vct_a}) fields_factory.register_provider(pre_computed_fields) vertical_grid = v_grid.VerticalGrid( v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), @@ -276,21 +312,22 @@ def test_factory_for_k_only_field(icon_grid, metrics_savepoint, grid_savepoint, def test_horizontal_only_field(icon_grid, interpolation_savepoint, grid_savepoint, backend): - fields_factory = factory.FieldsFactory() + fields_factory = factory.FieldsFactory(metadata=metadata.attrs) refin_ctl = grid_savepoint.refin_ctrl(dims.EdgeDim) - pre_computed_fields = factory.PrecomputedFieldsProvider({"refin_e_ctrl": refin_ctl}) + pre_computed_fields = factory.PrecomputedFieldProvider({"refin_e_ctrl": refin_ctl}) fields_factory.register_provider(pre_computed_fields) vertical_grid = v_grid.VerticalGrid( v_grid.VerticalGridConfig(grid_savepoint.num(dims.KDim)), grid_savepoint.vct_a(), grid_savepoint.vct_b(), ) + domain = h_grid.domain(dims.EdgeDim) provider = factory.ProgramFieldProvider( func=compute_nudgecoeffs.compute_nudgecoeffs, domain={ dims.EdgeDim: ( - edge_domain(h_grid.Zone.NUDGING_LEVEL_2), - edge_domain(h_grid.Zone.LOCAL), + domain(h_grid.Zone.NUDGING_LEVEL_2), + domain(h_grid.Zone.LOCAL), ), }, deps={"refin_ctrl": "refin_e_ctrl"}, From e635e3df58fdea733a4019b1087772f76ba75767 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 3 Oct 2024 18:47:24 +0200 Subject: [PATCH 32/37] add FieldSource Protocol --- .../icon4py/model/common/states/factory.py | 92 +++++++++++-------- .../common/tests/states_test/test_factory.py | 29 ++---- 2 files changed, 59 insertions(+), 62 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 23e55545e..5c7b88d40 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -41,22 +41,13 @@ DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain) + class RetrievalType(enum.Enum): FIELD = 0 DATA_ARRAY = 1 METADATA = 2 -def check_setup(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - if not self.is_setup(): - raise exceptions.IncompleteSetupError( - "Factory not fully instantiated, missing grid or allocator" - ) - return func(self, *args, **kwargs) - - return wrapper class FieldProvider(Protocol): @@ -72,19 +63,10 @@ class FieldProvider(Protocol): - dependencies: returns a list of field_names that the fields provided by this provider depend on. """ - - - @abc.abstractmethod - def evaluate(self, factory: "FieldsFactory") -> None: - ... def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: - if field_name not in self.fields: - raise ValueError(f"Field {field_name} not provided by f{self.func.__name__}.") - if any([f is None for f in self.fields.values()]): - self.evaluate(factory) - return self.fields[field_name] - + ... + @property def dependencies(self) -> Sequence[str]: ... @@ -103,9 +85,6 @@ class PrecomputedFieldProvider(FieldProvider): def __init__(self, fields: dict[str, state_utils.FieldType]): self._fields = fields - def evaluate(self, factory: "FieldsFactory") -> None: - pass - @property def dependencies(self) -> Sequence[str]: return [] @@ -113,9 +92,8 @@ def dependencies(self) -> Sequence[str]: def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: return self.fields[field_name] - # TODO signature should this only return the field_names produced by this provider? @property - def fields(self) -> Mapping[str, Any]: + def fields(self) -> Mapping[str, state_utils.FieldType]: return self._fields @@ -128,6 +106,8 @@ class ProgramFieldProvider(FieldProvider): """ Computes a field defined by a GT4Py Program. + TODO (halungge): use field_operator instead. + Args: func: GT4Py Program that computes the fields domain: the compute domain used for the stencil computation @@ -223,7 +203,12 @@ def _domain_args( raise ValueError(f"DimensionKind '{dim.kind}' not supported in Program Domain") return domain_args - def evaluate(self, factory: "FieldsFactory"): + def __call__(self, field_name: str, factory: "FieldsFactory"): + if any([f is None for f in self.fields.values()]): + self._compute(factory) + return self.fields[field_name] + + def _compute(self, factory)->None: self._fields = self._allocate(factory.allocator, factory.grid) deps = {k: factory.get(v) for k, v in self._dependencies.items()} deps.update(self._params) @@ -234,7 +219,7 @@ def evaluate(self, factory: "FieldsFactory"): self._func.with_backend(factory._backend)(**deps, offset_provider=offset_providers) @property - def fields(self) -> Mapping[str, Any]: + def fields(self) -> Mapping[str, state_utils.FieldType]: return self._fields @property @@ -245,7 +230,7 @@ def dependencies(self) -> Sequence[str]: return list(self._dependencies.values()) -class NumpyFieldProvider(FieldProvider): +class NumpyFieldsProvider(FieldProvider): """ Computes a field defined by a numpy function. @@ -275,7 +260,12 @@ def __init__( self._offsets = offsets if offsets is not None else {} self._params = params if params is not None else {} - def evaluate(self, factory: "FieldsFactory") -> None: + def __call__(self, field_name:str, factory: "FieldsFactory") -> None: + if any([f is None for f in self.fields.values()]): + self._compute(factory) + return self.fields[field_name] + + def _compute(self, factory)->None: self._validate_dependencies() args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()} offsets = {k: factory.grid.connectivities[v] for k, v in self._offsets.items()} @@ -284,7 +274,6 @@ def evaluate(self, factory: "FieldsFactory") -> None: results = self._func(**args) ## TODO: can the order of return values be checked? results = (results,) if isinstance(results, xp.ndarray) else results - self._fields = { k: gtx.as_field(tuple(self._dims), results[i]) for i, k in enumerate(self.fields) } @@ -318,7 +307,7 @@ def dependencies(self) -> Sequence[str]: return list(self._dependencies.values()) @property - def fields(self) -> Mapping[str, Any]: + def fields(self) -> Mapping[str, state_utils.FieldType]: return self._fields def _check( @@ -334,15 +323,32 @@ def _check( ) -class FieldsFactory: +class FieldSource(Protocol): + def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD): + ... + +class PartialConfigurable(Protocol): + def is_fully_configured(self)->bool: + return False + + @staticmethod + def check_setup(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if not self.is_fully_configured(): + raise exceptions.IncompleteSetupError( + "Factory not fully instantiated" + ) + return func(self, *args, **kwargs) + return wrapper + +class FieldsFactory(FieldSource, PartialConfigurable): def __init__( self, metadata: dict[str, model.FieldMetaData], grid: icon_grid.IconGrid = None, vertical_grid: v_grid.VerticalGrid = None, backend=None, - - ): self._metadata = metadata self._grid = grid @@ -357,8 +363,10 @@ def __init__( Lazily compute fields and cache them. """ - def is_setup(self): - return self._grid is not None and self.backend is not None + def is_fully_configured(self): + has_grid = self._grid is not None + has_vertical = self._vertical is not None + return has_grid and has_vertical @builder.builder def with_grid(self, grid: base_grid.BaseGrid, vertical_grid: v_grid.VerticalGrid): @@ -394,7 +402,7 @@ def register_provider(self, provider: FieldProvider): for field in provider.fields: self._providers[field] = provider - @check_setup + @PartialConfigurable.check_setup def get( self, field_name: str, type_: RetrievalType = RetrievalType.FIELD ) -> Union[state_utils.FieldType, xa.DataArray, model.FieldMetaData]: @@ -402,8 +410,14 @@ def get( raise ValueError(f"Field {field_name} not provided by the factory") if type_ == RetrievalType.METADATA: return self._metadata[field_name] - if type_ in (RetrievalType.FIELD,RetrievalType.DATA_ARRAY): + if type_ in (RetrievalType.FIELD, RetrievalType.DATA_ARRAY): provider = self._providers[field_name] + if field_name not in provider.fields: + raise ValueError(f"Field {field_name} not provided by f{provider.func.__name__}.") + if any([f is None for f in provider.fields.values()]): + provider(field_name, self) + + buffer = provider(field_name, self) return buffer if type_ == RetrievalType.FIELD else state_utils.to_data_array(buffer, self._metadata[field_name]) diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 9d88d277c..872389bc4 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -62,29 +62,12 @@ def test_factory_raise_error_if_no_grid_is_set(metrics_savepoint, backend): ) fields_factory = factory.FieldsFactory(metadata = metadata.attrs).with_backend(backend) fields_factory.register_provider(pre_computed_fields) - with pytest.raises(exceptions.IncompleteSetupError) as e: + with pytest.raises(exceptions.IncompleteSetupError) or pytest.raises(AssertionError) as e: fields_factory.get("height_on_interface_levels") - assert e.value.match("not fully instantiated") + assert e.value.match("grid") + -@pytest.mark.datatest -def test_factory_raise_error_if_no_backend_is_set(metrics_savepoint, grid_savepoint): - grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should go away - z_ifc = metrics_savepoint.z_ifc() - k_index = gtx.as_field((dims.KDim,), xp.arange(1, dtype=gtx.int32)) - pre_computed_fields = factory.PrecomputedFieldProvider( - {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} - ) - vertical = v_grid.VerticalGrid( - v_grid.VerticalGridConfig(num_levels=10), - grid_savepoint.vct_a(), - grid_savepoint.vct_b(), - ) - fields_factory = factory.FieldsFactory(metadata=metadata.attrs).with_grid(grid, vertical) - fields_factory.register_provider(pre_computed_fields) - with pytest.raises(exceptions.IncompleteSetupError) as e: - fields_factory.get("height_on_interface_levels") - assert e.value.match("not fully instantiated") @pytest.mark.datatest def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): @@ -203,7 +186,7 @@ def test_field_provider_for_numpy_function(grid_savepoint, func = compute_wgtfacq_c_dsl deps = {"z_ifc": "height_on_interface_levels"} params = {"nlev": grid.num_levels} - compute_wgtfacq_c_provider = factory.NumpyFieldProvider( + compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( func=func, domain={ dims.CellDim: (cell_domain(h_grid.Zone.LOCAL), cell_domain(h_grid.Zone.END)), @@ -247,7 +230,7 @@ def test_field_provider_for_numpy_function_with_offsets( func = compute_wgtfacq_c_dsl # TODO (magdalena): need to fix this for parameters params = {"nlev": grid.num_levels} - compute_wgtfacq_c_provider = factory.NumpyFieldProvider( + compute_wgtfacq_c_provider = factory.NumpyFieldsProvider( func=func, domain={dims.CellDim: (0, grid.num_cells), dims.KDim: (0, grid.num_levels)}, fields=["weighting_factor_for_quadratic_interpolation_to_cell_surface"], @@ -260,7 +243,7 @@ def test_field_provider_for_numpy_function_with_offsets( "c_lin_e": "cell_to_edge_interpolation_coefficient", } fields_factory.register_provider(compute_wgtfacq_c_provider) - wgtfacq_e_provider = factory.NumpyFieldProvider( + wgtfacq_e_provider = factory.NumpyFieldsProvider( func=compute_wgtfacq_e_dsl, deps=deps, offsets={"e2c": dims.E2CDim}, From fef2cedb20f773ad814820b434532574ae619f28 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 3 Oct 2024 19:04:58 +0200 Subject: [PATCH 33/37] fix doc strings, pre-commit --- .../src/icon4py/model/common/grid/icon.py | 6 +- .../icon4py/model/common/states/factory.py | 115 ++++++++++-------- .../icon4py/model/common/states/metadata.py | 2 +- .../src/icon4py/model/common/states/model.py | 6 +- .../common/tests/states_test/test_factory.py | 48 +++++--- 5 files changed, 100 insertions(+), 77 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/icon.py b/model/common/src/icon4py/model/common/grid/icon.py index 7334c3bf1..7f3de94a7 100644 --- a/model/common/src/icon4py/model/common/grid/icon.py +++ b/model/common/src/icon4py/model/common/grid/icon.py @@ -168,7 +168,7 @@ def n_shift(self): def lvert_nest(self): return True if self.config.lvertnest else False - def start_index(self, domain: h_grid.Domain)->gtx.int32: + def start_index(self, domain: h_grid.Domain) -> gtx.int32: """ Use to specify lower end of domains of a field for field_operators. @@ -178,9 +178,10 @@ def start_index(self, domain: h_grid.Domain)->gtx.int32: if domain.local: # special treatment because this value is not set properly in the underlying data. return gtx.int32(0) + # ndarray.item() does not respect the dtype of the array, returns a copy of the value _as the default python type_ return gtx.int32(self._start_indices[domain.dim][domain()]) - def end_index(self, domain: h_grid.Domain)->gtx.int32: + def end_index(self, domain: h_grid.Domain) -> gtx.int32: """ Use to specify upper end of domains of a field for field_operators. @@ -190,4 +191,5 @@ def end_index(self, domain: h_grid.Domain)->gtx.int32: if domain.zone == h_grid.Zone.INTERIOR and not self.limited_area: # special treatment because this value is not set properly in the underlying data, for a global grid return gtx.int32(self.size[domain.dim]) + # ndarray.item() does not respect the dtype of the array, returns a copy of the value _as the default python builtin type_ return gtx.int32(self._end_indices[domain.dim][domain()].item()) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 5c7b88d40..36af86562 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import abc + import enum import functools import inspect @@ -41,15 +41,12 @@ DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain) - class RetrievalType(enum.Enum): FIELD = 0 DATA_ARRAY = 1 METADATA = 2 - - class FieldProvider(Protocol): """ Protocol for field providers. @@ -57,16 +54,16 @@ class FieldProvider(Protocol): A field provider is responsible for the computation and caching of a set of fields. The fields can be accessed by their field_name (str). - A FieldProvider is a callable that has three methods (except for __call__): - - evaluate (abstract) : computes the fields based on the instructions of the concrete implementation - - fields: Mapping of a field_name to list of field names provided by the provider + A FieldProvider is a callable and additionally has three properties (except for __call__): + - func: the function used to compute the fields + - fields: Mapping of a field_name to the data buffer holding the computed values - dependencies: returns a list of field_names that the fields provided by this provider depend on. """ def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: ... - + @property def dependencies(self) -> Sequence[str]: ... @@ -74,45 +71,52 @@ def dependencies(self) -> Sequence[str]: @property def fields(self) -> Mapping[str, Any]: ... - + @property - def func(self)->Callable: + def func(self) -> Callable: ... + class PrecomputedFieldProvider(FieldProvider): - """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" + """Simple FieldProvider that does not do any computation but gets its fields at construction + and returns it upon provider.get(field_name).""" def __init__(self, fields: dict[str, state_utils.FieldType]): self._fields = fields @property def dependencies(self) -> Sequence[str]: - return [] + return () def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: return self.fields[field_name] - + @property def fields(self) -> Mapping[str, state_utils.FieldType]: return self._fields - - + @property def func(self) -> Callable: - return lambda : self.fields + return lambda: self.fields class ProgramFieldProvider(FieldProvider): """ Computes a field defined by a GT4Py Program. - TODO (halungge): use field_operator instead. + TODO (halungge): use field_operator instead? + TODO (halungge): need a way to specify where the dependencies and params can be retrieved. + As not all parameters can be resolved at the definition time Args: func: GT4Py Program that computes the fields domain: the compute domain used for the stencil computation - fields: dict[str, str], fields computed by this stencil: the key is the variable name of the out arguments used in the program and the value the name the field is registered under and declared in the metadata. - deps: dict[str, str], input fields used for computing this stencil: the key is the variable name used in the program and the value the name of the field it depends on. + fields: dict[str, str], fields computed by this stencil: the key is the variable name of + the out arguments used in the program and the value the name the field is registered + under and declared in the metadata. + deps: dict[str, str], input fields used for computing this stencil: + the key is the variable name used in the program and the value the name + of the field it depends on. params: scalar parameters used in the program """ @@ -208,7 +212,7 @@ def __call__(self, field_name: str, factory: "FieldsFactory"): self._compute(factory) return self.fields[field_name] - def _compute(self, factory)->None: + def _compute(self, factory) -> None: self._fields = self._allocate(factory.allocator, factory.grid) deps = {k: factory.get(v) for k, v in self._dependencies.items()} deps.update(self._params) @@ -221,24 +225,29 @@ def _compute(self, factory)->None: @property def fields(self) -> Mapping[str, state_utils.FieldType]: return self._fields - + @property - def func(self) ->Callable: + def func(self) -> Callable: return self._func + @property def dependencies(self) -> Sequence[str]: return list(self._dependencies.values()) - + class NumpyFieldsProvider(FieldProvider): """ Computes a field defined by a numpy function. + TODO (halungge): need to specify a parameter source to be able to postpone evaluation + + Args: func: numpy function that computes the fields domain: the compute domain used for the stencil computation fields: Seq[str] names under which the results fo the function will be registered - deps: dict[str, str] input fields used for computing this stencil: the key is the variable name used in the program and the value the name of the field it depends on. + deps: dict[str, str] input fields used for computing this stencil: the key is the variable name + used in the program and the value the name of the field it depends on. params: scalar arguments for the function """ @@ -260,12 +269,12 @@ def __init__( self._offsets = offsets if offsets is not None else {} self._params = params if params is not None else {} - def __call__(self, field_name:str, factory: "FieldsFactory") -> None: + def __call__(self, field_name: str, factory: "FieldsFactory") -> None: if any([f is None for f in self.fields.values()]): self._compute(factory) return self.fields[field_name] - def _compute(self, factory)->None: + def _compute(self, factory) -> None: self._validate_dependencies() args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()} offsets = {k: factory.grid.connectivities[v] for k, v in self._offsets.items()} @@ -299,17 +308,18 @@ def _validate_dependencies(self): ) @property - def func(self) ->Callable: + def func(self) -> Callable: return self._func - + @property def dependencies(self) -> Sequence[str]: return list(self._dependencies.values()) - + @property def fields(self) -> Mapping[str, state_utils.FieldType]: return self._fields + def _check( parameter_definition: inspect.Parameter, value: Union[state_utils.Scalar, gtx.Field], @@ -327,8 +337,9 @@ class FieldSource(Protocol): def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD): ... + class PartialConfigurable(Protocol): - def is_fully_configured(self)->bool: + def is_fully_configured(self) -> bool: return False @staticmethod @@ -336,12 +347,12 @@ def check_setup(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): if not self.is_fully_configured(): - raise exceptions.IncompleteSetupError( - "Factory not fully instantiated" - ) + raise exceptions.IncompleteSetupError("Factory not fully instantiated") return func(self, *args, **kwargs) + return wrapper + class FieldsFactory(FieldSource, PartialConfigurable): def __init__( self, @@ -359,8 +370,9 @@ def __init__( """ Factory for fields. - - Lazily compute fields and cache them. + + It can be queried at runtime for fields. Fields will be computed upon first request. + Uses FieldProvider to delegate the computation of the fields """ def is_fully_configured(self): @@ -408,18 +420,21 @@ def get( ) -> Union[state_utils.FieldType, xa.DataArray, model.FieldMetaData]: if field_name not in self._providers: raise ValueError(f"Field {field_name} not provided by the factory") - if type_ == RetrievalType.METADATA: - return self._metadata[field_name] - if type_ in (RetrievalType.FIELD, RetrievalType.DATA_ARRAY): - provider = self._providers[field_name] - if field_name not in provider.fields: - raise ValueError(f"Field {field_name} not provided by f{provider.func.__name__}.") - if any([f is None for f in provider.fields.values()]): - provider(field_name, self) - - - buffer = provider(field_name, self) - return buffer if type_ == RetrievalType.FIELD else state_utils.to_data_array(buffer, self._metadata[field_name]) - - - raise ValueError(f"Invalid retrieval type {type_}") + match type_: + case RetrievalType.METADATA: + return self._metadata[field_name] + case RetrievalType.FIELD | RetrievalType.DATA_ARRAY: + provider = self._providers[field_name] + if field_name not in provider.fields: + raise ValueError( + f"Field {field_name} not provided by f{provider.func.__name__}." + ) + + buffer = provider(field_name, self) + return ( + buffer + if type_ == RetrievalType.FIELD + else state_utils.to_data_array(buffer, self._metadata[field_name]) + ) + case _: + raise ValueError(f"Invalid retrieval type {type_}") diff --git a/model/common/src/icon4py/model/common/states/metadata.py b/model/common/src/icon4py/model/common/states/metadata.py index 2b03954c4..2bbe2854e 100644 --- a/model/common/src/icon4py/model/common/states/metadata.py +++ b/model/common/src/icon4py/model/common/states/metadata.py @@ -14,7 +14,7 @@ from icon4py.model.common.states import model -attrs:Final[dict[str, model.FieldMetaData]] = { +attrs: Final[dict[str, model.FieldMetaData]] = { "functional_determinant_of_metrics_on_interface_levels": dict( standard_name="functional_determinant_of_metrics_on_interface_levels", long_name="functional determinant of the metrics [sqrt(gamma)] on half levels", diff --git a/model/common/src/icon4py/model/common/states/model.py b/model/common/src/icon4py/model/common/states/model.py index 2c89d70b0..dff293a2a 100644 --- a/model/common/src/icon4py/model/common/states/model.py +++ b/model/common/src/icon4py/model/common/states/model.py @@ -20,11 +20,10 @@ """Contains type definitions used for the model`s state representation.""" -DimensionNames = Literal["cell", "edge", "vertex"] -DimensionT = Union[gtx.Dimension, DimensionNames] #TODO use Literal instead of str +DimensionNames = Literal["cell", "edge", "vertex"] +DimensionT = Union[gtx.Dimension, DimensionNames] # TODO use Literal instead of str BufferT = Union[np_t.ArrayLike, gtx.Field] DTypeT = Union[ta.wpfloat, ta.vpfloat, gtx.int32, gtx.int64, gtx.float32, gtx.float64] - class OptionalMetaData(TypedDict, total=False): @@ -35,7 +34,6 @@ class OptionalMetaData(TypedDict, total=False): # TODO (@halungge) dims should probably be required? dims: tuple[DimensionT, ...] dtype: Union[ta.wpfloat, ta.vpfloat, gtx.int32, gtx.int64, gtx.float32, gtx.float64] - class RequiredMetaData(TypedDict, total=True): diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 872389bc4..742eaf774 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -36,8 +36,9 @@ def test_factory_check_dependencies_on_register(grid_savepoint, backend): grid_savepoint.vct_b(), ) - fields_factory = (factory.FieldsFactory(metadata.attrs).with_grid(grid, vertical) - .with_backend(backend)) + fields_factory = ( + factory.FieldsFactory(metadata.attrs).with_grid(grid, vertical).with_backend(backend) + ) provider = factory.ProgramFieldProvider( func=mf.compute_z_mc, domain={ @@ -60,19 +61,17 @@ def test_factory_raise_error_if_no_grid_is_set(metrics_savepoint, backend): pre_computed_fields = factory.PrecomputedFieldProvider( {"height_on_interface_levels": z_ifc, cf_utils.INTERFACE_LEVEL_STANDARD_NAME: k_index} ) - fields_factory = factory.FieldsFactory(metadata = metadata.attrs).with_backend(backend) + fields_factory = factory.FieldsFactory(metadata=metadata.attrs).with_backend(backend) fields_factory.register_provider(pre_computed_fields) with pytest.raises(exceptions.IncompleteSetupError) or pytest.raises(AssertionError) as e: fields_factory.get("height_on_interface_levels") assert e.value.match("grid") - - @pytest.mark.datatest def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): z_ifc = metrics_savepoint.z_ifc() - grid = grid_savepoint.construct_icon_grid(on_gpu=False) + grid = grid_savepoint.construct_icon_grid(on_gpu=False) num_levels = grid_savepoint.num(dims.KDim) vertical = v_grid.VerticalGrid( v_grid.VerticalGridConfig(num_levels=num_levels), @@ -104,12 +103,12 @@ def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): @pytest.mark.datatest def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): - horizontal_grid = grid_savepoint.construct_icon_grid( - on_gpu=False - ) + horizontal_grid = grid_savepoint.construct_icon_grid(on_gpu=False) num_levels = grid_savepoint.num(dims.KDim) vertical_grid = v_grid.VerticalGrid( - v_grid.VerticalGridConfig(num_levels=num_levels), grid_savepoint.vct_a(), grid_savepoint.vct_b() + v_grid.VerticalGridConfig(num_levels=num_levels), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), ) fields_factory = factory.FieldsFactory(metadata=metadata.attrs) @@ -164,17 +163,21 @@ def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): assert helpers.dallclose(data.ndarray, ref) -def test_field_provider_for_numpy_function(grid_savepoint, - metrics_savepoint, interpolation_savepoint, backend +def test_field_provider_for_numpy_function( + grid_savepoint, metrics_savepoint, interpolation_savepoint, backend ): - grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete + grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete vertical_grid = v_grid.VerticalGrid( - v_grid.VerticalGridConfig(num_levels=grid.num_levels), grid_savepoint.vct_a(), - grid_savepoint.vct_b() + v_grid.VerticalGridConfig(num_levels=grid.num_levels), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), ) - fields_factory = (factory.FieldsFactory(metadata=metadata.attrs) - .with_grid(grid=grid, vertical_grid=vertical_grid).with_backend(backend)) + fields_factory = ( + factory.FieldsFactory(metadata=metadata.attrs) + .with_grid(grid=grid, vertical_grid=vertical_grid) + .with_backend(backend) + ) k_index = gtx.as_field((dims.KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() wgtfacq_c_ref = metrics_savepoint.wgtfacq_c_dsl() @@ -210,10 +213,15 @@ def test_field_provider_for_numpy_function_with_offsets( ): grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete vertical = v_grid.VerticalGrid( - v_grid.VerticalGridConfig(num_levels=grid.num_levels), grid_savepoint.vct_a(), - grid_savepoint.vct_b() + v_grid.VerticalGridConfig(num_levels=grid.num_levels), + grid_savepoint.vct_a(), + grid_savepoint.vct_b(), + ) + fields_factory = ( + factory.FieldsFactory(metadata=metadata.attrs) + .with_grid(grid=grid, vertical_grid=vertical) + .with_backend(backend=backend) ) - fields_factory = factory.FieldsFactory(metadata=metadata.attrs).with_grid(grid=grid, vertical_grid=vertical).with_backend(backend=backend) k_index = gtx.as_field((dims.KDim,), xp.arange(grid.num_levels + 1, dtype=gtx.int32)) z_ifc = metrics_savepoint.z_ifc() c_lin_e = interpolation_savepoint.c_lin_e() From 02cce48d5f1b116cd42ee1a6035d64576ed98d65 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 4 Oct 2024 15:46:53 +0200 Subject: [PATCH 34/37] add documentation --- .../icon4py/model/common/states/factory.py | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 36af86562..9ca427e18 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -6,6 +6,42 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +""" +Provide a FieldFactory that can serve as a simple in memory database for Fields. + +Once setup, the factory can be queried for fields using a string name for the field. Three query modes are available: +_ `FIELD`: return the buffer containing the computed values as a GT4Py `Field` +- `METADATA`: return metadata such as units, CF standard_name or similar, dimensions... +- `DATA_ARRAY`: combination of the two above in the form of `xarray.dataarray` + +The factory can be used to "store" already computed fields or register functions and call arguments +and only compute the fields lazily upon request. In order to do so the user registers the fields computation with factory. + +It should be possible to setup the factory and computations and the factory independent of concrete runtime parameters that define +the computation, passing those only once they are defined at runtime, for example +--- +factory = Factory(metadata) +foo_provider = FieldProvider("foo", func = f1, dependencies = []) +bar_provider = FieldProvider("bar", func = f2, dependencies = ["foo"]) + +factory.register_provider(foo_provider) +factory.register_provider(bar_provider) +(...) + +--- +def main(backend, grid) +factory.with_backend(backend).with_grid(grid) + +val = factory.get("foo", RetrievalType.DATA_ARRAY) + +TODO (halungge): except for domain parameters and other fields managed by the same factory we currently lack the ability to specify + other input sources in the factory for lazy evaluation. + factory.with_sources({"geometry": x}, where x:FieldSourceN + + +TODO: for the numpy functions we might have to work on the func interfaces to make them a bit more uniform. + +""" import enum import functools @@ -51,7 +87,7 @@ class FieldProvider(Protocol): """ Protocol for field providers. - A field provider is responsible for the computation and caching of a set of fields. + A field provider is responsible for the computation (and caching) of a set of fields. The fields can be accessed by their field_name (str). A FieldProvider is a callable and additionally has three properties (except for __call__): @@ -334,11 +370,18 @@ def _check( class FieldSource(Protocol): + """Protocol for object that can be queried for fields.""" def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD): ... class PartialConfigurable(Protocol): + """ + Protocol to mark classes that are not yet fully configured upon instaniation. + + Additionally provides a decorator that makes use of the Protocol an can be used in + concrete examples to trigger a check whether the setup is complete. + """ def is_fully_configured(self) -> bool: return False From 70063a75eaa7b207e5e38324f994fdf80d7fb5f6 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Thu, 10 Oct 2024 11:01:06 +0200 Subject: [PATCH 35/37] move FieldSource protocol --- .../icon4py/model/common/states/factory.py | 24 +++++-------------- .../src/icon4py/model/common/states/utils.py | 17 +++++++++++-- .../common/tests/states_test/test_factory.py | 18 +++++++++----- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 9ca427e18..67be65cc9 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -43,7 +43,6 @@ def main(backend, grid) """ -import enum import functools import inspect from typing import ( @@ -77,12 +76,6 @@ def main(backend, grid) DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain) -class RetrievalType(enum.Enum): - FIELD = 0 - DATA_ARRAY = 1 - METADATA = 2 - - class FieldProvider(Protocol): """ Protocol for field providers. @@ -369,12 +362,6 @@ def _check( ) -class FieldSource(Protocol): - """Protocol for object that can be queried for fields.""" - def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD): - ... - - class PartialConfigurable(Protocol): """ Protocol to mark classes that are not yet fully configured upon instaniation. @@ -382,6 +369,7 @@ class PartialConfigurable(Protocol): Additionally provides a decorator that makes use of the Protocol an can be used in concrete examples to trigger a check whether the setup is complete. """ + def is_fully_configured(self) -> bool: return False @@ -396,7 +384,7 @@ def wrapper(self, *args, **kwargs): return wrapper -class FieldsFactory(FieldSource, PartialConfigurable): +class FieldsFactory(state_utils.FieldSource, PartialConfigurable): def __init__( self, metadata: dict[str, model.FieldMetaData], @@ -459,14 +447,14 @@ def register_provider(self, provider: FieldProvider): @PartialConfigurable.check_setup def get( - self, field_name: str, type_: RetrievalType = RetrievalType.FIELD + self, field_name: str, type_: state_utils.RetrievalType = state_utils.RetrievalType.FIELD ) -> Union[state_utils.FieldType, xa.DataArray, model.FieldMetaData]: if field_name not in self._providers: raise ValueError(f"Field {field_name} not provided by the factory") match type_: - case RetrievalType.METADATA: + case state_utils.RetrievalType.METADATA: return self._metadata[field_name] - case RetrievalType.FIELD | RetrievalType.DATA_ARRAY: + case state_utils.RetrievalType.FIELD | state_utils.RetrievalType.DATA_ARRAY: provider = self._providers[field_name] if field_name not in provider.fields: raise ValueError( @@ -476,7 +464,7 @@ def get( buffer = provider(field_name, self) return ( buffer - if type_ == RetrievalType.FIELD + if type_ == state_utils.RetrievalType.FIELD else state_utils.to_data_array(buffer, self._metadata[field_name]) ) case _: diff --git a/model/common/src/icon4py/model/common/states/utils.py b/model/common/src/icon4py/model/common/states/utils.py index e8ad795ae..29035f0f6 100644 --- a/model/common/src/icon4py/model/common/states/utils.py +++ b/model/common/src/icon4py/model/common/states/utils.py @@ -5,8 +5,8 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - -from typing import Sequence, TypeAlias, TypeVar, Union +import enum +from typing import Protocol, Sequence, TypeAlias, TypeVar, Union import gt4py.next as gtx import xarray as xa @@ -28,3 +28,16 @@ def to_data_array(field: FieldType, attrs: dict): data = field if isinstance(field, xp.ndarray) else field.ndarray return xa.DataArray(data, attrs=attrs) + + +class RetrievalType(enum.Enum): + FIELD = 0 + DATA_ARRAY = 1 + METADATA = 2 + + +class FieldSource(Protocol): + """Protocol for object that can be queried for fields.""" + + def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD): + ... diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index 742eaf774..f8f98d7fb 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -9,6 +9,7 @@ import gt4py.next as gtx import pytest +import icon4py.model.common.states.utils as state_utils import icon4py.model.common.test_utils.helpers as helpers from icon4py.model.common import dimension as dims, exceptions from icon4py.model.common.grid import horizontal as h_grid, vertical as v_grid @@ -85,16 +86,18 @@ def test_factory_returns_field(grid_savepoint, metrics_savepoint, backend): fields_factory = factory.FieldsFactory(metadata=metadata.attrs) fields_factory.register_provider(pre_computed_fields) fields_factory.with_grid(grid, vertical).with_backend(backend) - field = fields_factory.get("height_on_interface_levels", factory.RetrievalType.FIELD) + field = fields_factory.get("height_on_interface_levels", state_utils.RetrievalType.FIELD) assert field.ndarray.shape == (grid.num_cells, num_levels + 1) - meta = fields_factory.get("height_on_interface_levels", factory.RetrievalType.METADATA) + meta = fields_factory.get("height_on_interface_levels", state_utils.RetrievalType.METADATA) assert meta["standard_name"] == "height_on_interface_levels" assert meta["dims"] == ( dims.CellDim, dims.KHalfDim, ) assert meta["units"] == "m" - data_array = fields_factory.get("height_on_interface_levels", factory.RetrievalType.DATA_ARRAY) + data_array = fields_factory.get( + "height_on_interface_levels", state_utils.RetrievalType.DATA_ARRAY + ) assert data_array.data.shape == (grid.num_cells, num_levels + 1) assert data_array.data.dtype == xp.float64 for key in ("dims", "standard_name", "units", "icon_var_name"): @@ -157,7 +160,8 @@ def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): fields_factory.register_provider(functional_determinant_provider) fields_factory.with_grid(horizontal_grid, vertical_grid).with_backend(backend) data = fields_factory.get( - "functional_determinant_of_metrics_on_interface_levels", type_=factory.RetrievalType.FIELD + "functional_determinant_of_metrics_on_interface_levels", + type_=state_utils.RetrievalType.FIELD, ) ref = metrics_savepoint.ddqz_z_half().ndarray assert helpers.dallclose(data.ndarray, ref) @@ -202,7 +206,8 @@ def test_field_provider_for_numpy_function( fields_factory.register_provider(compute_wgtfacq_c_provider) wgtfacq_c = fields_factory.get( - "weighting_factor_for_quadratic_interpolation_to_cell_surface", factory.RetrievalType.FIELD + "weighting_factor_for_quadratic_interpolation_to_cell_surface", + state_utils.RetrievalType.FIELD, ) assert helpers.dallclose(wgtfacq_c.asnumpy(), wgtfacq_c_ref.asnumpy()) @@ -262,7 +267,8 @@ def test_field_provider_for_numpy_function_with_offsets( fields_factory.register_provider(wgtfacq_e_provider) wgtfacq_e = fields_factory.get( - "weighting_factor_for_quadratic_interpolation_to_edge_center", factory.RetrievalType.FIELD + "weighting_factor_for_quadratic_interpolation_to_edge_center", + state_utils.RetrievalType.FIELD, ) assert helpers.dallclose(wgtfacq_e.asnumpy(), wgtfacq_e_ref.asnumpy()) From 19ceae88040240823a9cdb6f031fec0dc3faa0a8 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 11 Oct 2024 12:59:12 +0200 Subject: [PATCH 36/37] add return type to FieldSource.get(...) --- model/common/src/icon4py/model/common/states/factory.py | 1 - model/common/src/icon4py/model/common/states/utils.py | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 67be65cc9..33d140f4c 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -133,7 +133,6 @@ class ProgramFieldProvider(FieldProvider): """ Computes a field defined by a GT4Py Program. - TODO (halungge): use field_operator instead? TODO (halungge): need a way to specify where the dependencies and params can be retrieved. As not all parameters can be resolved at the definition time diff --git a/model/common/src/icon4py/model/common/states/utils.py b/model/common/src/icon4py/model/common/states/utils.py index 29035f0f6..3eb9a88d7 100644 --- a/model/common/src/icon4py/model/common/states/utils.py +++ b/model/common/src/icon4py/model/common/states/utils.py @@ -13,6 +13,7 @@ from icon4py.model.common import dimension as dims, type_alias as ta from icon4py.model.common.settings import xp +from icon4py.model.common.states import model T = TypeVar("T", ta.wpfloat, ta.vpfloat, float, bool, gtx.int32, gtx.int64) @@ -39,5 +40,7 @@ class RetrievalType(enum.Enum): class FieldSource(Protocol): """Protocol for object that can be queried for fields.""" - def get(self, field_name: str, type_: RetrievalType = RetrievalType.FIELD): + def get( + self, field_name: str, type_: RetrievalType = RetrievalType.FIELD + ) -> Union[FieldType, xa.DataArray, model.FieldMetaData]: ... From ae937d46ec979e0d8d8fbe0253aa48a3c76cc0e5 Mon Sep 17 00:00:00 2001 From: Magdalena Luz Date: Fri, 11 Oct 2024 14:20:32 +0200 Subject: [PATCH 37/37] Split factory argument in FieldProvider to several protocols --- .../icon4py/model/common/states/factory.py | 69 ++++++++++--------- .../common/tests/states_test/test_factory.py | 2 +- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 33d140f4c..9f1860d62 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -58,10 +58,11 @@ def main(backend, grid) ) import gt4py.next as gtx +import gt4py.next.backend as gtx_backend import gt4py.next.ffront.decorator as gtx_decorator import xarray as xa -from icon4py.model.common import dimension as dims, exceptions, settings +from icon4py.model.common import dimension as dims, exceptions from icon4py.model.common.grid import ( base as base_grid, horizontal as h_grid, @@ -69,12 +70,21 @@ def main(backend, grid) vertical as v_grid, ) from icon4py.model.common.settings import xp -from icon4py.model.common.states import metadata as metadata, model, utils as state_utils +from icon4py.model.common.states import model, utils as state_utils from icon4py.model.common.utils import builder DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain) +class GridProvider(Protocol): + @property + def grid(self)-> Optional[icon_grid.IconGrid]: + ... + + @property + def vertical_grid(self) -> Optional[v_grid.VerticalGrid]: + ... + class FieldProvider(Protocol): """ @@ -90,7 +100,7 @@ class FieldProvider(Protocol): """ - def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: + def __call__(self, field_name: str, field_src: Optional[state_utils.FieldSource], backend:Optional[gtx_backend.Backend], grid: Optional[GridProvider]) -> state_utils.FieldType: ... @property @@ -117,7 +127,7 @@ def __init__(self, fields: dict[str, state_utils.FieldType]): def dependencies(self) -> Sequence[str]: return () - def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType: + def __call__(self, field_name: str, field_src = None, backend = None, grid = None) -> state_utils.FieldType: return self.fields[field_name] @property @@ -168,7 +178,7 @@ def __init__( def _unallocated(self) -> bool: return not all(self._fields.values()) - def _allocate(self, allocator, grid: base_grid.BaseGrid) -> dict[str, state_utils.FieldType]: + def _allocate(self, backend: gtx_backend.Backend, grid: base_grid.BaseGrid, metadata: dict[str, model.FieldMetaData]) -> dict[str, state_utils.FieldType]: def _map_size(dim: gtx.Dimension, grid: base_grid.BaseGrid) -> int: if dim == dims.KHalfDim: return grid.num_levels + 1 @@ -179,18 +189,19 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: return dims.KDim return dim + allocate = gtx.constructors.zeros.partial(allocator=backend) field_domain = { _map_dim(dim): (0, _map_size(dim, grid)) for dim in self._compute_domain.keys() } return { - k: allocator(field_domain, dtype=metadata.attrs[k]["dtype"]) + k: allocate(field_domain, dtype=metadata[k]["dtype"]) for k in self._fields.keys() } # TODO (@halungge) this can be simplified when completely disentangling vertical and horizontal grid. # the IconGrid should then only contain horizontal connectivities and no longer any Koff which should be moved to the VerticalGrid def _get_offset_providers( - self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid + self, grid: icon_grid.IconGrid ) -> dict[str, gtx.FieldOffset]: offset_providers = {} for dim in self._compute_domain.keys(): @@ -235,20 +246,21 @@ def _domain_args( raise ValueError(f"DimensionKind '{dim.kind}' not supported in Program Domain") return domain_args - def __call__(self, field_name: str, factory: "FieldsFactory"): + def __call__(self, field_name: str, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid_provider:GridProvider ): if any([f is None for f in self.fields.values()]): - self._compute(factory) + self._compute(factory, backend, grid_provider) return self.fields[field_name] - def _compute(self, factory) -> None: - self._fields = self._allocate(factory.allocator, factory.grid) + def _compute(self, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid_provider:GridProvider) -> None: + metadata = {v: factory.get(v, state_utils.RetrievalType.METADATA) for k, v in self._output.items()} + self._fields = self._allocate(backend, grid_provider.grid, metadata) deps = {k: factory.get(v) for k, v in self._dependencies.items()} deps.update(self._params) deps.update({k: self._fields[v] for k, v in self._output.items()}) - dims = self._domain_args(factory.grid, factory.vertical_grid) - offset_providers = self._get_offset_providers(factory.grid, factory.vertical_grid) + dims = self._domain_args(grid_provider.grid, grid_provider.vertical_grid) + offset_providers = self._get_offset_providers(grid_provider.grid) deps.update(dims) - self._func.with_backend(factory._backend)(**deps, offset_provider=offset_providers) + self._func.with_backend(backend)(**deps, offset_provider=offset_providers) @property def fields(self) -> Mapping[str, state_utils.FieldType]: @@ -297,22 +309,22 @@ def __init__( self._offsets = offsets if offsets is not None else {} self._params = params if params is not None else {} - def __call__(self, field_name: str, factory: "FieldsFactory") -> None: + def __call__(self, field_name: str, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid: GridProvider) -> state_utils.FieldType: if any([f is None for f in self.fields.values()]): - self._compute(factory) + self._compute(factory, backend, grid) return self.fields[field_name] - def _compute(self, factory) -> None: + def _compute(self, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid_provider:GridProvider) -> None: self._validate_dependencies() args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()} - offsets = {k: factory.grid.connectivities[v] for k, v in self._offsets.items()} + offsets = {k: grid_provider.grid.connectivities[v] for k, v in self._offsets.items()} args.update(offsets) args.update(self._params) results = self._func(**args) ## TODO: can the order of return values be checked? results = (results,) if isinstance(results, xp.ndarray) else results self._fields = { - k: gtx.as_field(tuple(self._dims), results[i]) for i, k in enumerate(self.fields) + k: gtx.as_field(tuple(self._dims), results[i], allocator = backend) for i, k in enumerate(self.fields) } def _validate_dependencies(self): @@ -383,20 +395,19 @@ def wrapper(self, *args, **kwargs): return wrapper -class FieldsFactory(state_utils.FieldSource, PartialConfigurable): +class FieldsFactory(state_utils.FieldSource, PartialConfigurable, GridProvider): def __init__( self, metadata: dict[str, model.FieldMetaData], - grid: icon_grid.IconGrid = None, - vertical_grid: v_grid.VerticalGrid = None, - backend=None, + grid: Optional[icon_grid.IconGrid] = None, + vertical_grid: Optional[v_grid.VerticalGrid] = None, + backend:Optional[gtx_backend.Backend]=None, ): self._metadata = metadata self._grid = grid self._vertical = vertical_grid self._providers: dict[str, "FieldProvider"] = {} self._backend = backend - self._allocator = gtx.constructors.zeros.partial(allocator=backend) """ Factory for fields. @@ -411,14 +422,13 @@ def is_fully_configured(self): return has_grid and has_vertical @builder.builder - def with_grid(self, grid: base_grid.BaseGrid, vertical_grid: v_grid.VerticalGrid): + def with_grid(self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid): self._grid = grid self._vertical = vertical_grid @builder.builder - def with_backend(self, backend=settings.backend): + def with_backend(self, backend): self._backend = backend - self._allocator = gtx.constructors.zeros.partial(allocator=backend) @property def backend(self): @@ -432,9 +442,6 @@ def grid(self): def vertical_grid(self): return self._vertical - @property - def allocator(self): - return self._allocator def register_provider(self, provider: FieldProvider): for dependency in provider.dependencies: @@ -460,7 +467,7 @@ def get( f"Field {field_name} not provided by f{provider.func.__name__}." ) - buffer = provider(field_name, self) + buffer = provider(field_name, self, self.backend, self) return ( buffer if type_ == state_utils.RetrievalType.FIELD diff --git a/model/common/tests/states_test/test_factory.py b/model/common/tests/states_test/test_factory.py index f8f98d7fb..901ab5268 100644 --- a/model/common/tests/states_test/test_factory.py +++ b/model/common/tests/states_test/test_factory.py @@ -170,7 +170,7 @@ def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend): def test_field_provider_for_numpy_function( grid_savepoint, metrics_savepoint, interpolation_savepoint, backend ): - grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete + grid = grid_savepoint.construct_icon_grid(False) vertical_grid = v_grid.VerticalGrid( v_grid.VerticalGridConfig(num_levels=grid.num_levels), grid_savepoint.vct_a(),