Skip to content

Commit

Permalink
Split factory argument in FieldProvider to several protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
halungge committed Oct 11, 2024
1 parent 19ceae8 commit ae937d4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
69 changes: 38 additions & 31 deletions model/common/src/icon4py/model/common/states/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,33 @@ def main(backend, grid)
)

import gt4py.next as gtx
import gt4py.next.backend as gtx_backend
import gt4py.next.ffront.decorator as gtx_decorator
import xarray as xa

from icon4py.model.common import dimension as dims, exceptions, settings
from icon4py.model.common import dimension as dims, exceptions
from icon4py.model.common.grid import (
base as base_grid,
horizontal as h_grid,
icon as icon_grid,
vertical as v_grid,
)
from icon4py.model.common.settings import xp
from icon4py.model.common.states import metadata as metadata, model, utils as state_utils
from icon4py.model.common.states import model, utils as state_utils
from icon4py.model.common.utils import builder


DomainType = TypeVar("DomainType", h_grid.Domain, v_grid.Domain)

class GridProvider(Protocol):
@property
def grid(self)-> Optional[icon_grid.IconGrid]:
...

@property
def vertical_grid(self) -> Optional[v_grid.VerticalGrid]:
...


class FieldProvider(Protocol):
"""
Expand All @@ -90,7 +100,7 @@ class FieldProvider(Protocol):
"""

def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType:
def __call__(self, field_name: str, field_src: Optional[state_utils.FieldSource], backend:Optional[gtx_backend.Backend], grid: Optional[GridProvider]) -> state_utils.FieldType:
...

@property
Expand All @@ -117,7 +127,7 @@ def __init__(self, fields: dict[str, state_utils.FieldType]):
def dependencies(self) -> Sequence[str]:
return ()

def __call__(self, field_name: str, factory: "FieldsFactory") -> state_utils.FieldType:
def __call__(self, field_name: str, field_src = None, backend = None, grid = None) -> state_utils.FieldType:
return self.fields[field_name]

@property
Expand Down Expand Up @@ -168,7 +178,7 @@ def __init__(
def _unallocated(self) -> bool:
return not all(self._fields.values())

def _allocate(self, allocator, grid: base_grid.BaseGrid) -> dict[str, state_utils.FieldType]:
def _allocate(self, backend: gtx_backend.Backend, grid: base_grid.BaseGrid, metadata: dict[str, model.FieldMetaData]) -> dict[str, state_utils.FieldType]:
def _map_size(dim: gtx.Dimension, grid: base_grid.BaseGrid) -> int:
if dim == dims.KHalfDim:
return grid.num_levels + 1
Expand All @@ -179,18 +189,19 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension:
return dims.KDim
return dim

allocate = gtx.constructors.zeros.partial(allocator=backend)
field_domain = {
_map_dim(dim): (0, _map_size(dim, grid)) for dim in self._compute_domain.keys()
}
return {
k: allocator(field_domain, dtype=metadata.attrs[k]["dtype"])
k: allocate(field_domain, dtype=metadata[k]["dtype"])
for k in self._fields.keys()
}

# TODO (@halungge) this can be simplified when completely disentangling vertical and horizontal grid.
# the IconGrid should then only contain horizontal connectivities and no longer any Koff which should be moved to the VerticalGrid
def _get_offset_providers(
self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid
self, grid: icon_grid.IconGrid
) -> dict[str, gtx.FieldOffset]:
offset_providers = {}
for dim in self._compute_domain.keys():
Expand Down Expand Up @@ -235,20 +246,21 @@ def _domain_args(
raise ValueError(f"DimensionKind '{dim.kind}' not supported in Program Domain")
return domain_args

def __call__(self, field_name: str, factory: "FieldsFactory"):
def __call__(self, field_name: str, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid_provider:GridProvider ):
if any([f is None for f in self.fields.values()]):
self._compute(factory)
self._compute(factory, backend, grid_provider)
return self.fields[field_name]

def _compute(self, factory) -> None:
self._fields = self._allocate(factory.allocator, factory.grid)
def _compute(self, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid_provider:GridProvider) -> None:
metadata = {v: factory.get(v, state_utils.RetrievalType.METADATA) for k, v in self._output.items()}
self._fields = self._allocate(backend, grid_provider.grid, metadata)
deps = {k: factory.get(v) for k, v in self._dependencies.items()}
deps.update(self._params)
deps.update({k: self._fields[v] for k, v in self._output.items()})
dims = self._domain_args(factory.grid, factory.vertical_grid)
offset_providers = self._get_offset_providers(factory.grid, factory.vertical_grid)
dims = self._domain_args(grid_provider.grid, grid_provider.vertical_grid)
offset_providers = self._get_offset_providers(grid_provider.grid)
deps.update(dims)
self._func.with_backend(factory._backend)(**deps, offset_provider=offset_providers)
self._func.with_backend(backend)(**deps, offset_provider=offset_providers)

@property
def fields(self) -> Mapping[str, state_utils.FieldType]:
Expand Down Expand Up @@ -297,22 +309,22 @@ def __init__(
self._offsets = offsets if offsets is not None else {}
self._params = params if params is not None else {}

def __call__(self, field_name: str, factory: "FieldsFactory") -> None:
def __call__(self, field_name: str, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid: GridProvider) -> state_utils.FieldType:
if any([f is None for f in self.fields.values()]):
self._compute(factory)
self._compute(factory, backend, grid)
return self.fields[field_name]

def _compute(self, factory) -> None:
def _compute(self, factory:state_utils.FieldSource, backend:gtx_backend.Backend, grid_provider:GridProvider) -> None:
self._validate_dependencies()
args = {k: factory.get(v).ndarray for k, v in self._dependencies.items()}
offsets = {k: factory.grid.connectivities[v] for k, v in self._offsets.items()}
offsets = {k: grid_provider.grid.connectivities[v] for k, v in self._offsets.items()}
args.update(offsets)
args.update(self._params)
results = self._func(**args)
## TODO: can the order of return values be checked?
results = (results,) if isinstance(results, xp.ndarray) else results
self._fields = {
k: gtx.as_field(tuple(self._dims), results[i]) for i, k in enumerate(self.fields)
k: gtx.as_field(tuple(self._dims), results[i], allocator = backend) for i, k in enumerate(self.fields)
}

def _validate_dependencies(self):
Expand Down Expand Up @@ -383,20 +395,19 @@ def wrapper(self, *args, **kwargs):
return wrapper


class FieldsFactory(state_utils.FieldSource, PartialConfigurable):
class FieldsFactory(state_utils.FieldSource, PartialConfigurable, GridProvider):
def __init__(
self,
metadata: dict[str, model.FieldMetaData],
grid: icon_grid.IconGrid = None,
vertical_grid: v_grid.VerticalGrid = None,
backend=None,
grid: Optional[icon_grid.IconGrid] = None,
vertical_grid: Optional[v_grid.VerticalGrid] = None,
backend:Optional[gtx_backend.Backend]=None,
):
self._metadata = metadata
self._grid = grid
self._vertical = vertical_grid
self._providers: dict[str, "FieldProvider"] = {}
self._backend = backend
self._allocator = gtx.constructors.zeros.partial(allocator=backend)

"""
Factory for fields.
Expand All @@ -411,14 +422,13 @@ def is_fully_configured(self):
return has_grid and has_vertical

@builder.builder
def with_grid(self, grid: base_grid.BaseGrid, vertical_grid: v_grid.VerticalGrid):
def with_grid(self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid):
self._grid = grid
self._vertical = vertical_grid

@builder.builder
def with_backend(self, backend=settings.backend):
def with_backend(self, backend):
self._backend = backend
self._allocator = gtx.constructors.zeros.partial(allocator=backend)

@property
def backend(self):
Expand All @@ -432,9 +442,6 @@ def grid(self):
def vertical_grid(self):
return self._vertical

@property
def allocator(self):
return self._allocator

def register_provider(self, provider: FieldProvider):
for dependency in provider.dependencies:
Expand All @@ -460,7 +467,7 @@ def get(
f"Field {field_name} not provided by f{provider.func.__name__}."
)

buffer = provider(field_name, self)
buffer = provider(field_name, self, self.backend, self)
return (
buffer
if type_ == state_utils.RetrievalType.FIELD
Expand Down
2 changes: 1 addition & 1 deletion model/common/tests/states_test/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_field_provider_for_program(grid_savepoint, metrics_savepoint, backend):
def test_field_provider_for_numpy_function(
grid_savepoint, metrics_savepoint, interpolation_savepoint, backend
):
grid = grid_savepoint.construct_icon_grid(False) # TODO fix this should be come obsolete
grid = grid_savepoint.construct_icon_grid(False)
vertical_grid = v_grid.VerticalGrid(
v_grid.VerticalGridConfig(num_levels=grid.num_levels),
grid_savepoint.vct_a(),
Expand Down

0 comments on commit ae937d4

Please sign in to comment.