Skip to content

Commit

Permalink
add graph cudnn conv alg config (#6799)
Browse files Browse the repository at this point in the history
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
strint and oneflow-ci-bot committed Nov 23, 2021
1 parent e94e55a commit 6ee48a1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/oneflow/nn/graph/graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def set_zero_redundancy_optimizer_mode(self, mode: str = "distributed_split"):
assert mode in ("distributed_split", "non_distributed")
self.proto.set_optimizer_placement_optimization_mode(mode)

def enable_cudnn_conv_heuristic_search_algo(self, mode: bool = True):
""" Whether enable cudnn conv operatioin to use heuristic search algorithm.
Args:
mode (bool, optional): Whether enable cudnn conv operatioin to use heuristic
search algorithm. Default is True.
"""
self.proto.set_cudnn_conv_heuristic_search_algo(mode)

def _generate_optimizer_and_variable_configs(
self, opt_dict: OptDict = None, variables_conf: OrderedDict = None,
):
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/test/graph/test_optimization_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self):
self.config.allow_fuse_cast_scale(True)
self.config.set_gradient_accumulation_steps(100)
self.config.set_zero_redundancy_optimizer_mode("distributed_split")
self.config.enable_cudnn_conv_heuristic_search_algo(False)

def build(self, x):
x = self.m(x)
Expand Down

0 comments on commit 6ee48a1

Please sign in to comment.