Use _internal_dict.keys instead of _get_signature_keys
Disty0 committed Oct 14, 2024
1 parent bbedeaf commit 62d7431
Showing 1 changed file with 13 additions and 28 deletions.
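Note: in diffusers, _get_signature_keys inspects a pipeline class's __init__ signature, while _internal_dict is the registry behind pipe.config that register_modules fills with one entry per registered component, so iterating its keys reaches every module the pipeline actually holds. A minimal sketch of the pattern this commit adopts, assuming the diffusers package (the model id is illustrative and is downloaded at run time):

    import torch
    from diffusers import StableDiffusionPipeline

    # Illustrative checkpoint; any DiffusionPipeline works the same way.
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # _internal_dict backs pipe.config: one key per registered component, plus
    # private metadata keys (_class_name, ...) that the isinstance check skips.
    for name in pipe._internal_dict.keys():  # pylint: disable=protected-access
        component = getattr(pipe, name, None)
        if isinstance(component, torch.nn.Module):
            print(name, type(component).__name__)  # e.g. unet UNet2DConditionModel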
modules/sd_models.py
@@ -740,33 +740,20 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
     set_diffuser_offload(sd_model, op)
 
 
-def set_accelerate_to_module(pipe):
-    module_names, _ = pipe._get_signature_keys(pipe) # pylint: disable=protected-access
-    for module_name in module_names:
-        module = getattr(pipe, module_name)
-        if isinstance(module, torch.nn.Module):
-            module.has_accelerate = True
+def set_accelerate_to_module(model):
+    for k in model._internal_dict.keys(): # pylint: disable=protected-access
+        component = getattr(model, k, None)
+        if isinstance(component, torch.nn.Module):
+            component.has_accelerate = True
 
 
 def set_accelerate(sd_model):
     sd_model.has_accelerate = True
-    if hasattr(sd_model, "_get_signature_keys"):
-        set_accelerate_to_module(sd_model)
-        if hasattr(sd_model, "prior_pipe"):
-            set_accelerate_to_module(sd_model.prior_pipe)
-        if hasattr(sd_model, "decoder_pipe"):
-            set_accelerate_to_module(sd_model.decoder_pipe)
-    else:
-        if getattr(sd_model, 'vae', None) is not None:
-            sd_model.vae.has_accelerate = True
-        if getattr(sd_model, 'unet', None) is not None:
-            sd_model.unet.has_accelerate = True
-        if getattr(sd_model, 'transformer', None) is not None:
-            sd_model.transformer.has_accelerate = True
-        if getattr(sd_model, 'text_encoder', None) is not None:
-            sd_model.text_encoder.has_accelerate = True
-        if getattr(sd_model, 'text_encoder_2', None) is not None:
-            sd_model.text_encoder_2.has_accelerate = True
+    set_accelerate_to_module(sd_model)
+    if hasattr(sd_model, "prior_pipe"):
+        set_accelerate_to_module(sd_model.prior_pipe)
+    if hasattr(sd_model, "decoder_pipe"):
+        set_accelerate_to_module(sd_model.decoder_pipe)
 
 
 def set_diffuser_offload(sd_model, op: str = 'model'):
@@ -857,8 +844,7 @@ def detach_hook(self, module):
         return module
 
     def apply_balanced_offload_to_module(pipe):
-        module_names, _ = pipe._get_signature_keys(pipe) # pylint: disable=protected-access
-        for module_name in module_names:
+        for module_name in pipe._internal_dict.keys(): # pylint: disable=protected-access
             module = getattr(pipe, module_name)
             if isinstance(module, torch.nn.Module):
                 checkpoint_name = pipe.sd_checkpoint_info.name if getattr(pipe, "sd_checkpoint_info", None) is not None else None
@@ -1593,10 +1579,9 @@ def set_diffusers_attention(pipe):
     def set_attn(pipe, attention):
         if attention is None:
             return
-        if not hasattr(pipe, "_get_signature_keys"):
+        if not hasattr(pipe, "_internal_dict"):
             return
-        module_names, _ = pipe._get_signature_keys(pipe) # pylint: disable=protected-access
-        modules = [getattr(pipe, n, None) for n in module_names]
+        modules = [getattr(pipe, n, None) for n in pipe._internal_dict.keys()] # pylint: disable=protected-access
         modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attn_processor")]
         for module in modules:
             if module.__class__.__name__ in ['SD3Transformer2DModel']:
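For reference, the rewritten helper can be exercised without loading a real checkpoint. DummyPipe below is a hypothetical stand-in that only mimics the _internal_dict attribute; it is not part of the repository:

    import torch

    class DummyPipe:
        # Hypothetical stand-in for a diffusers pipeline.
        def __init__(self):
            self._internal_dict = {"unet": None, "text_encoder": None, "feature_extractor": None}
            self.unet = torch.nn.Linear(4, 4)          # stands in for the UNet
            self.text_encoder = torch.nn.Linear(4, 4)  # stands in for the text encoder
            self.feature_extractor = object()          # registered, but not an nn.Module

    def set_accelerate_to_module(model):
        # Same body the commit introduces: flag every registered torch module.
        for k in model._internal_dict.keys():  # pylint: disable=protected-access
            component = getattr(model, k, None)
            if isinstance(component, torch.nn.Module):
                component.has_accelerate = True

    pipe = DummyPipe()
    set_accelerate_to_module(pipe)
    print(pipe.unet.has_accelerate)                           # True
    print(hasattr(pipe.feature_extractor, "has_accelerate"))  # False: not a torch module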
