
DaCe Orchestration for the Diffusion Granule #514

Open · wants to merge 185 commits into base: main
Conversation

@kotsaloscv (Collaborator) commented Jul 30, 2024

This PR concerns only the DaCe backend.

Currently, the Diffusion granule calls the various stencils one after the other (in _do_diffusion_step). This means that every stencil gets a separate SDFG, and consequently one stencil is not aware of the others, which limits the analyzability of the full granule. This PR introduces a decorator that fuses all these SDFGs into one compilation unit, allowing DaCe to perform further analysis and optimization. Placing a GT4Py program inside a DaCe program region, and extracting the underlying SDFG, is possible thanks to this GT4Py PR.

The halo exchanges are also handled by the DaCe orchestrator. A follow-up PR will introduce automated halo exchanges. Currently, the halo exchange class implements the SDFGConvertible interface, just like GT4Py's Program.

The orchestrator is activated either through the environment variable ICON4PY_DACE_ORCHESTRATION=AnyValue or through the pytest option --dace-orchestration=AnyValue.

The orchestrator supports ahead-of-time compilation; however, since DaCe does not support nested structures (like self in the Diffusion class), some of the arguments need to be provided at compile time through dace.compiletime. This annotation means that the corresponding argument will be taken from the function's closure.

The orchestrator provides full caching support and detects when a dace.compiletime argument changes, triggering a re-compilation of the fused SDFG.
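As a rough illustration of the activation mechanism described above, here is a toy stand-in (not the actual ICON4PY decorator; all names below are hypothetical) showing how a decorator can fall back to a plain call when the environment variable is unset:

```python
import os

def orchestrate_if_enabled(func):
    """Toy stand-in for the PR's decorator: take the 'orchestrated' path
    only when the activation env var is set, otherwise call through."""
    def wrapper(*args, **kwargs):
        if os.environ.get("ICON4PY_DACE_ORCHESTRATION"):
            # The real decorator would parse, fuse, and compile the SDFGs
            # of all stencil calls inside `func` here.
            return ("orchestrated", func(*args, **kwargs))
        return func(*args, **kwargs)
    return wrapper

@orchestrate_if_enabled
def step(x):
    return x + 1

os.environ.pop("ICON4PY_DACE_ORCHESTRATION", None)
print(step(1))  # → 2 (plain call-through)
os.environ["ICON4PY_DACE_ORCHESTRATION"] = "AnyValue"
print(step(1))  # → ('orchestrated', 2)
```

Any non-empty value activates the orchestrated path, matching the "AnyValue" wording above.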

@kotsaloscv (Collaborator Author): cscs-ci run default
@kotsaloscv (Collaborator Author): cscs-ci run dace
@kotsaloscv (Collaborator Author): cscs-ci run default
@kotsaloscv (Collaborator Author): launch jenkins spack
@kotsaloscv (Collaborator Author): cscs-ci run default
@kotsaloscv (Collaborator Author): launch jenkins spack
@kotsaloscv (Collaborator Author): cscs-ci run dace
@kotsaloscv (Collaborator Author): cscs-ci run default
@kotsaloscv (Collaborator Author): launch jenkins spack
@kotsaloscv (Collaborator Author): cscs-ci run dace
@kotsaloscv (Collaborator Author): cscs-ci run default
@kotsaloscv (Collaborator Author): launch jenkins spack
@kotsaloscv (Collaborator Author): cscs-ci run dace
@kotsaloscv (Collaborator Author): cscs-ci run default
@kotsaloscv (Collaborator Author): launch jenkins spack
@kotsaloscv (Collaborator Author): cscs-ci run dace
@kotsaloscv (Collaborator Author): cscs-ci run default

@egparedes (Contributor) left a comment
I only looked at the decorator.py module and tried to provide as much information and suggestions as possible to make your life easier and try to get this done and merged ASAP.

Comment on lines 62 to 159
def orchestrate(func: Callable | None = None, *, method: bool | None = None):
    def _decorator(fuse_func: Callable):
        compiled_sdfgs = {}  # Caching

        def wrapper(*args, **kwargs):
            if settings.dace_orchestration is not None:
                if "dace" not in settings.backend.name.lower():
                    raise ValueError(
                        "DaCe Orchestration works only with DaCe backends. Change the backend to a DaCe supported one."
                    )

                if method:
                    # self is used to retrieve the _exchange object -on the fly halo exchanges- and the grid object -offset providers-
                    self = args[0]
                    self_name = next(iter(inspect.signature(fuse_func).parameters))
                else:
                    raise ValueError(
                        "The orchestration decorator is only for methods -at least for now-."
                    )

                fuse_func_orig_annotations = copy.deepcopy(fuse_func.__annotations__)
                fuse_func.__annotations__ = to_dace_annotations(
                    fuse_func
                )  # every arg/kwarg is annotated with DaCe data types

                exchange_obj = None
                grid = None
                for attr_name, attr_value in self.__dict__.items():
                    if isinstance(attr_value, decomposition.ExchangeRuntime):
                        exchange_obj = getattr(self, attr_name)
                    if isinstance(attr_value, icon_grid.IconGrid):
                        grid = getattr(self, attr_name)

                if not grid:
                    raise ValueError("No grid object found.")

                order_kwargs_by_annotations(fuse_func, kwargs)

                compile_time_args_kwargs = {}
                all_args_kwargs = [*args, *kwargs.values()]
                for i, (k, v) in enumerate(fuse_func.__annotations__.items()):
                    if v is dace.compiletime:
                        compile_time_args_kwargs[k] = all_args_kwargs[i]

                unique_id = make_uid(fuse_func, compile_time_args_kwargs, exchange_obj)

                default_build_folder = Path(".dacecache") / f"uid_{unique_id}"

                parse_compile_cache_sdfg(
                    unique_id,
                    compiled_sdfgs,
                    default_build_folder,
                    exchange_obj,
                    fuse_func,
                    compile_time_args_kwargs,
                    self_name,
                    simplify_fused_sdfg=True,
                )
                dace_program = compiled_sdfgs[unique_id]["dace_program"]
                sdfg = compiled_sdfgs[unique_id]["sdfg"]
                compiled_sdfg = compiled_sdfgs[unique_id]["compiled_sdfg"]

                # update the args/kwargs with runtime related values, such as
                # concretized symbols, runtime connectivity tables, GHEX C++ pointers, and DaCe structures pointers
                updated_args, updated_kwargs = mod_xargs_for_dace_structures(
                    fuse_func, fuse_func_orig_annotations, args, kwargs
                )
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_specific_kwargs(exchange_obj, grid.offset_providers),
                }
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_symbols_concretization(
                        grid, fuse_func, fuse_func_orig_annotations, args, kwargs
                    ),
                }

                sdfg_args = dace_program._create_sdfg_args(sdfg, updated_args, updated_kwargs)
                if method:
                    del sdfg_args[self_name]

                fuse_func.__annotations__ = (
                    fuse_func_orig_annotations  # restore the original annotations
                )

                with dace.config.temporary_config():
                    dace.config.Config.set(
                        "compiler", "allow_view_arguments", value=True
                    )  # Allow numpy views as arguments: If true, allows users to call DaCe programs with NumPy views (for example, “A[:,1]” or “w.T”)
                    return compiled_sdfg(**sdfg_args)
            else:
                return fuse_func(*args, **kwargs)

        return wrapper

    return _decorator(func) if func else _decorator
@egparedes (Contributor)
I would push as much work as possible from the run-time function wrapper into the decorator, since the decorator is only executed once, at decoration time. I think there were still some errors in the logic to sort the kwargs when some of them are missing at runtime, and I would move the cache lookup out of the parse_compile function, since I don't think it belongs there and it forces an extra function call at runtime for no reason.

Here is a possible rewrite of the function addressing some of the issues. I think there is still room for further improvements and cleanups at the bottom of the wrapper, where the DaCe compilation and some other processing seem to happen, but I don't know enough about DaCe to judge whether all the operations there need to be repeated at every function call.
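The general point here — do signature inspection once at decoration time and keep only per-call work in the wrapper — can be sketched generically (hypothetical names, independent of the PR's actual code):

```python
import inspect

def precompute_decorator(func):
    # Decoration-time work: runs exactly once, when `func` is decorated.
    param_names = tuple(inspect.signature(func).parameters)

    def wrapper(*args, **kwargs):
        # Call-time work only: order the values against the pre-computed
        # parameter list, filling missing keyword arguments with None.
        ordered = list(args) + [kwargs.get(k) for k in param_names[len(args):]]
        return func(*ordered)

    return wrapper

@precompute_decorator
def f(a, b, c=None):
    return (a, b, c)

print(f(1, b=2))  # → (1, 2, None)
```

Had the `inspect.signature` call lived inside `wrapper`, it would be repeated on every invocation for no benefit.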

Suggested change
def orchestrate(func: Callable | None = None, *, method: bool | None = None):
    def _decorator(fuse_func: Callable):
        if settings.dace_orchestration is not None:
            if "dace" not in settings.backend.name.lower():
                raise ValueError(
                    "DaCe Orchestration works only with DaCe backends. Change the backend to a DaCe supported one."
                )

            self_name = next(iter(inspect.signature(fuse_func).parameters))
            if method is None:
                # Assume the provided callable is a method if its first argument is called 'self'
                method = self_name == "self"
            if not method:
                raise ValueError(
                    "The orchestration decorator is only for methods -at least for now-."
                )

            local_cache = {}  # Caching compiled func versions

            def wrapper(*args, **kwargs):
                # self is used to retrieve the _exchange object -on the fly halo exchanges- and the grid object -offset providers-
                self = args[0]
                exchange_obj = None
                grid = None
                for attr_name, attr_value in self.__dict__.items():
                    if isinstance(attr_value, decomposition.ExchangeRuntime):
                        exchange_obj = getattr(self, attr_name)
                    elif isinstance(attr_value, icon_grid.IconGrid):
                        grid = getattr(self, attr_name)
                # Use assert here to allow disabling the check when running in production
                assert grid is not None, "No grid object found in the call arguments."

                # Add DaCe data types annotations for all args and kwargs
                dace_annotations = to_dace_annotations(fuse_func)

                # To extract the actual values from the function parameters defined as compile-time,
                # we first need to sort the run-time arguments according to their definition
                # order, adding `None`s for the missing ones to make sure we don't use
                # the wrong one by mistake. Parameters already covered positionally are skipped.
                ordered_kwargs = [
                    kwargs.get(key, None) for key in list(dace_annotations)[len(args):]
                ]
                all_args = [*args, *ordered_kwargs]
                compile_time_args_kwargs = {
                    k: arg
                    for arg, (k, v) in zip(all_args, dace_annotations.items(), strict=True)
                    if v is dace.compiletime
                }

                unique_id = make_uid(fuse_func, compile_time_args_kwargs, exchange_obj)
                fuse_func_orig_annotations = fuse_func.__annotations__
                if (cache_item := local_cache.get(unique_id, None)) is None:
                    fuse_func.__annotations__ = dace_annotations
                    default_build_folder = Path(".dacecache") / f"uid_{unique_id}"
                    cache_item = local_cache[unique_id] = parse_compile_cache_sdfg(
                        default_build_folder,
                        exchange_obj,
                        fuse_func,
                        compile_time_args_kwargs,
                        self_name,
                        simplify_fused_sdfg=True,
                    )

                dace_program = cache_item["dace_program"]
                sdfg = cache_item["sdfg"]
                compiled_sdfg = cache_item["compiled_sdfg"]

                # update the args/kwargs with runtime related values, such as
                # concretized symbols, runtime connectivity tables, GHEX C++ pointers, and DaCe structures pointers
                updated_args, updated_kwargs = mod_xargs_for_dace_structures(
                    fuse_func, fuse_func_orig_annotations, args, kwargs
                )
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_specific_kwargs(exchange_obj, grid.offset_providers),
                }
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_symbols_concretization(
                        grid, fuse_func, fuse_func_orig_annotations, args, kwargs
                    ),
                }

                sdfg_args = dace_program._create_sdfg_args(sdfg, updated_args, updated_kwargs)
                if method:
                    del sdfg_args[self_name]

                fuse_func.__annotations__ = fuse_func_orig_annotations

                with dace.config.temporary_config():
                    dace.config.Config.set(
                        "compiler", "allow_view_arguments", value=True
                    )  # Allow numpy views as arguments: If true, allows users to call DaCe programs with NumPy views (for example, “A[:,1]” or “w.T”)
                    return compiled_sdfg(**sdfg_args)

            return wrapper
        else:
            return fuse_func

    return _decorator(func) if func else _decorator

Comment on lines +97 to +98
# Add DaCe data types annotations for all args and kwargs
dace_annotations = to_dace_annotations(fuse_func)
@egparedes (Contributor)

This was my mistake, but this should be moved out of the wrapper to the decorator, since it only needs to happen once at decoration time.

Comment on lines +206 to +264
def generate_orchestration_uid(
    obj: Any, obj_name: str = "", members_to_disregard: tuple[str] = ()
) -> str:
    """Generate a unique id for a runtime object.

    The unique id is generated by creating a dictionary that describes the runtime state of the object.
    For primitive types, the dictionary contains the type and the value.
    For arrays, the dictionary contains the shape and the dtype -not their content-.

    Keep in mind that this function is not supposed to be generic, but it is used only for the DaCe orchestration purposes.
    """
    primitive_dtypes = (*orchestration_dtypes.ICON4PY_PRIMITIVE_DTYPES, str, uuid.UUID, np.dtype)

    unique_dict = {}

    def _populate_entry(key: str, value: Any, parent_key: str = ""):
        full_key = f"{parent_key}.{key}" if parent_key else key

        if full_key in members_to_disregard:
            return

        if isinstance(value, primitive_dtypes):
            unique_dict[full_key] = {"type": "primitive_dtypes", "value": str(value)}
        elif isinstance(value, (np.ndarray, gtx.Field)):
            unique_dict[full_key] = {
                "type": "array/field",
                "shape": str(value.shape),
                "dtype": str(value.dtype),
            }
        elif isinstance(value, (list, tuple)):
            if all(isinstance(i, primitive_dtypes) for i in value):
                unique_dict[full_key] = {
                    "type": f"array-like[{'empty' if len(value) == 0 else type(value[0])}]",
                    "length": str(len(value)),
                }
            else:
                for i, v in enumerate(value):
                    _populate_entry(str(i), v, full_key)
        elif value is None:
            unique_dict[full_key] = {"type": "None", "value": "None"}
        elif hasattr(value, "__dict__") or isinstance(value, dict):
            _populate_unique_dict(value, full_key)
        else:
            raise ValueError(f"Type {type(value)} is not supported.")

    def _populate_unique_dict(obj: Any, parent_key: str = ""):
        if (hasattr(obj, "__dict__") or isinstance(obj, dict)) and not isinstance(
            obj, decomposition.ExchangeRuntime
        ):
            obj_to_traverse = obj.__dict__ if hasattr(obj, "__dict__") else obj
            for key, value in obj_to_traverse.items():
                _populate_entry(key, value, parent_key)

    if hasattr(obj, "__dict__") or isinstance(obj, dict):
        _populate_unique_dict(obj)
    else:
        _populate_entry(obj_name, obj)

    return uid_from_hashlib(str(unique_dict))
@egparedes (Contributor) commented Oct 7, 2024

The names of this function and the previous one do not follow a similar naming scheme despite being closely related, so I would try to rename them. Additionally, both functions lack argument descriptions in their docstrings, and this one seems more convoluted than needed. I would suggest refactoring it to something with a single internal helper function, like:

Suggested change
def generate_orchestration_uid(
    obj: Any, obj_name: str = "", members_to_disregard: tuple[str] = ()
) -> str:
    """Generate a unique id for a runtime object.

    The unique id is generated by creating a dictionary that describes the runtime state of the object.
    For primitive types, the dictionary contains the type and the value.
    For arrays, the dictionary contains the shape and the dtype -not their content-.

    Keep in mind that this function is not supposed to be generic, and should only be used for
    DaCe orchestration purposes.

    Args:
        obj:
        obj_name:
        members_to_disregard:
    """
    primitive_dtypes = (*orchestration_dtypes.ICON4PY_PRIMITIVE_DTYPES, str, uuid.UUID, np.dtype)

    static_data = {}

    def _populate_entry(key: str, value: Any) -> None:
        if key in members_to_disregard:
            return
        if isinstance(value, primitive_dtypes):
            static_data[key] = {"type": "primitive_dtypes", "value": str(value)}
        elif isinstance(value, (np.ndarray, gtx.Field)):
            static_data[key] = {
                "type": "array/field",
                "shape": str(value.shape),
                "dtype": str(value.dtype),
            }
        elif isinstance(value, (list, tuple)):
            item_types = set(type(i) for i in value) or {type(None)}
            if len(item_types) == 1 and issubclass(
                prim_type := item_types.pop(), (*primitive_dtypes, type(None))
            ):
                static_data[key] = {
                    "type": f"array-like[{prim_type!s}]",
                    "length": str(len(value)),
                }
            else:
                for child_key, child_value in enumerate(value):
                    _populate_entry(f"{key}.{child_key!s}", child_value)
        elif isinstance(value, decomposition.ExchangeRuntime):
            pass
        elif value is None:
            static_data[key] = {"type": "None", "value": "None"}
        elif isinstance(value, dict) or getattr(value, "__dict__", None):
            for child_key, child_value in getattr(value, "__dict__", value).items():
                _populate_entry(f"{key}.{child_key!s}", child_value)
        else:
            raise ValueError(f"Type {type(value)} is not supported.")

    _populate_entry(obj_name, obj)
    return uid_from_hashlib(str(sorted(static_data.items(), key=operator.itemgetter(0))))

Also note that the final dict with the static information should be sorted in a consistent way, to avoid spurious cache misses when the items are created or traversed in a different order.
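This ordering concern can be demonstrated with a small stdlib-only sketch (assuming, as a hypothetical, that `uid_from_hashlib` is essentially a SHA-256 digest of its string argument):

```python
import hashlib

def uid_from_hashlib(payload: str) -> str:
    # Assumed helper: a deterministic digest of the string payload.
    return hashlib.sha256(payload.encode()).hexdigest()

# Same content, different insertion order: str() preserves insertion
# order, so the naive uid differs and a cache lookup would miss.
a = {"shape": "(10,)", "dtype": "float64"}
b = {"dtype": "float64", "shape": "(10,)"}
assert uid_from_hashlib(str(a)) != uid_from_hashlib(str(b))

# Sorting the items first makes the uid independent of traversal order.
assert uid_from_hashlib(str(sorted(a.items()))) == uid_from_hashlib(str(sorted(b.items())))
```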

Comment on lines +83 to +84
    def wrapper(*args, **kwargs):
        self = args[0]
@egparedes (Contributor)

Why not simply

Suggested change

    def wrapper(*args, **kwargs):
        self = args[0]

    def wrapper(self, *args, **kwargs):

@kotsaloscv (Collaborator Author)

The orchestrator could be used for non-member functions as well (not yet implemented, but possibly in the future). That is the reason for not having self in the argument list.
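The trade-off being discussed can be illustrated with a minimal sketch (toy code, not the PR's decorator): a `*args`-based wrapper works unchanged for both bound methods and free functions, whereas an explicit `self` parameter would tie the wrapper to methods.

```python
def orchestrate(func):
    # A `*args` wrapper works unchanged for both bound methods (where
    # args[0] is `self`) and free functions (no `self` slot at all).
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

class Granule:
    @orchestrate
    def run(self, x):
        return x * 2

@orchestrate
def free_run(x):
    return x * 3

print(Granule().run(4))  # → 8
print(free_run(4))       # → 12
```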


Mandatory Tests

Please make sure you run these tests via comment before you merge!

  • cscs-ci run default
  • launch jenkins spack

Optional Tests

To run benchmarks you can use:

  • cscs-ci run benchmark

To run tests and benchmarks with the DaCe backend you can use:

  • cscs-ci run dace

In case your change might affect downstream icon-exclaim, please consider running

  • launch jenkins icon

For more detailed information please look at CI in the EXCLAIM universe.

@kotsaloscv (Collaborator Author)

cscs-ci run dace
