From 62d74318a09818522c518d013fcf8a297c62e4a9 Mon Sep 17 00:00:00 2001
From: Disty0
Date: Mon, 14 Oct 2024 22:44:14 +0300
Subject: [PATCH] Use _internal_dict.keys instead of _get_signature_keys

---
 modules/sd_models.py | 41 +++++++++++++----------------------------
 1 file changed, 13 insertions(+), 28 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8a90f3c7e..51a87915c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -740,33 +740,20 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
         set_diffuser_offload(sd_model, op)
 
 
-def set_accelerate_to_module(pipe):
-    module_names, _ = pipe._get_signature_keys(pipe) # pylint: disable=protected-access
-    for module_name in module_names:
-        module = getattr(pipe, module_name)
-        if isinstance(module, torch.nn.Module):
-            module.has_accelerate = True
+def set_accelerate_to_module(model):
+    for k in model._internal_dict.keys(): # pylint: disable=protected-access
+        component = getattr(model, k, None)
+        if isinstance(component, torch.nn.Module):
+            component.has_accelerate = True
 
 
 def set_accelerate(sd_model):
     sd_model.has_accelerate = True
-    if hasattr(sd_model, "_get_signature_keys"):
-        set_accelerate_to_module(sd_model)
-        if hasattr(sd_model, "prior_pipe"):
-            set_accelerate_to_module(sd_model.prior_pipe)
-        if hasattr(sd_model, "decoder_pipe"):
-            set_accelerate_to_module(sd_model.decoder_pipe)
-    else:
-        if getattr(sd_model, 'vae', None) is not None:
-            sd_model.vae.has_accelerate = True
-        if getattr(sd_model, 'unet', None) is not None:
-            sd_model.unet.has_accelerate = True
-        if getattr(sd_model, 'transformer', None) is not None:
-            sd_model.transformer.has_accelerate = True
-        if getattr(sd_model, 'text_encoder', None) is not None:
-            sd_model.text_encoder.has_accelerate = True
-        if getattr(sd_model, 'text_encoder_2', None) is not None:
-            sd_model.text_encoder_2.has_accelerate = True
+    set_accelerate_to_module(sd_model)
+    if hasattr(sd_model, "prior_pipe"):
+        set_accelerate_to_module(sd_model.prior_pipe)
+    if hasattr(sd_model, "decoder_pipe"):
+        set_accelerate_to_module(sd_model.decoder_pipe)
 
 
 def set_diffuser_offload(sd_model, op: str = 'model'):
@@ -857,8 +844,7 @@ def detach_hook(self, module):
         return module
 
     def apply_balanced_offload_to_module(pipe):
-        module_names, _ = pipe._get_signature_keys(pipe) # pylint: disable=protected-access
-        for module_name in module_names:
+        for module_name in pipe._internal_dict.keys(): # pylint: disable=protected-access
             module = getattr(pipe, module_name)
             if isinstance(module, torch.nn.Module):
                 checkpoint_name = pipe.sd_checkpoint_info.name if getattr(pipe, "sd_checkpoint_info", None) is not None else None
@@ -1593,10 +1579,9 @@ def set_diffusers_attention(pipe):
     def set_attn(pipe, attention):
         if attention is None:
             return
-        if not hasattr(pipe, "_get_signature_keys"):
+        if not hasattr(pipe, "_internal_dict"):
             return
-        module_names, _ = pipe._get_signature_keys(pipe) # pylint: disable=protected-access
-        modules = [getattr(pipe, n, None) for n in module_names]
+        modules = [getattr(pipe, n, None) for n in pipe._internal_dict.keys()] # pylint: disable=protected-access
         modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attn_processor")]
         for module in modules:
            if module.__class__.__name__ in ['SD3Transformer2DModel']:
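
For illustration (not part of the patch): a minimal, self-contained sketch of the component-enumeration pattern the patch switches to. diffusers pipelines inherit from ConfigMixin, so every component registered via register_modules appears as a key in the private _internal_dict, which is why iterating those keys can replace _get_signature_keys. The FakePipe class below is a hypothetical stand-in so the snippet runs with only torch installed, without loading a real checkpoint.

import torch

class FakePipe:
    """Hypothetical stand-in for a diffusers DiffusionPipeline: components are
    plain attributes, and _internal_dict maps their names to config entries
    (in real pipelines ConfigMixin/register_modules fills this dict)."""
    def __init__(self):
        self.unet = torch.nn.Linear(4, 4)          # torch module -> should be flagged
        self.text_encoder = torch.nn.Linear(4, 4)  # torch module -> should be flagged
        self.scheduler = object()                  # not a torch module -> skipped
        self._internal_dict = {"unet": None, "text_encoder": None, "scheduler": None}

def set_accelerate_to_module(model):
    # Same pattern as the patch: walk every registered component name and
    # flag the ones that are actual torch modules.
    for k in model._internal_dict.keys():  # pylint: disable=protected-access
        component = getattr(model, k, None)
        if isinstance(component, torch.nn.Module):
            component.has_accelerate = True

pipe = FakePipe()
set_accelerate_to_module(pipe)
print(getattr(pipe.unet, "has_accelerate", False))       # True
print(getattr(pipe.scheduler, "has_accelerate", False))  # False

Unlike _get_signature_keys, which reflects only the pipeline's __init__ signature, _internal_dict covers whatever was actually registered on the instance, and the getattr(..., None) guard tolerates registered names whose component is currently unset.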