Fix ONNX/Olive generation.
lshqqytiger committed May 5, 2024
1 parent 216340d · commit 7fd77f2
Showing 3 changed files with 10 additions and 17 deletions.
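Taken together, the hunks below reorder the generation setup: self.scheduler.set_timesteps(num_inference_steps) (and, in the second pipeline, the timesteps = self.scheduler.timesteps lookup) now runs before the latents are prepared rather than after. A plausible reading of the fix is that prepare_latents scales the initial noise by the scheduler's init_noise_sigma, which is only final once the timestep schedule has been set, so preparing the latents first scaled them by a stale value. The toy sketch below (a hypothetical TinySigmaScheduler and a simplified prepare_latents, not the repository's code) illustrates that dependency:

# Toy illustration only: TinySigmaScheduler and this prepare_latents are
# hypothetical stand-ins, not the pipeline's real classes.
import numpy as np

class TinySigmaScheduler:
    """Toy scheduler whose init_noise_sigma is only meaningful after set_timesteps()."""
    def __init__(self):
        self.sigmas = None

    def set_timesteps(self, num_inference_steps: int):
        # build a sigma schedule for the requested number of steps (made-up values)
        self.sigmas = np.linspace(14.6, 0.0, num_inference_steps + 1)

    @property
    def init_noise_sigma(self) -> float:
        # before set_timesteps() this falls back to a stale default of 1.0
        return 1.0 if self.sigmas is None else float(self.sigmas.max())

def prepare_latents(init_noise_sigma, shape, rng):
    latents = rng.standard_normal(shape).astype(np.float32)
    # scale the initial noise by the standard deviation required by the scheduler
    return latents * np.float64(init_noise_sigma)

rng = np.random.default_rng(0)
scheduler = TinySigmaScheduler()

# broken order: latents are scaled by the stale default sigma of 1.0
bad = prepare_latents(scheduler.init_noise_sigma, (1, 4, 8, 8), rng)

# fixed order (what this commit does): set the timesteps first
scheduler.set_timesteps(20)
good = prepare_latents(scheduler.init_noise_sigma, (1, 4, 8, 8), rng)

print(round(float(bad.std()), 2), round(float(good.std()), 2))  # roughly 1.0 vs 14.x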
modules/onnx_impl/pipelines/onnx_stable_diffusion_pipeline.py (8 changes: 4 additions & 4 deletions)
@@ -62,6 +62,9 @@ def __call__(
if generator is None:
    generator = torch.Generator("cpu")

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -84,12 +87,9 @@ def __call__(
    width,
    prompt_embeds.dtype,
    generator,
    latents
    latents,
)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
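For context on the classifier-free guidance comment that appears in both pipelines: the unconditional and text-conditioned noise predictions are blended with the guidance weight, and a weight of 1 reduces to the text-conditioned prediction alone. A minimal sketch of that standard step, with illustrative names rather than the pipeline's exact code:

import numpy as np

def apply_cfg(noise_pred_uncond: np.ndarray, noise_pred_text: np.ndarray, guidance_scale: float) -> np.ndarray:
    # guidance_scale == 1 returns noise_pred_text unchanged, i.e. no classifier-free guidance
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)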
Second changed file (6 changes: 3 additions & 3 deletions)
@@ -70,6 +70,9 @@ def __call__(
if generator is None:
    generator = torch.Generator("cpu")

self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -97,9 +100,6 @@ def __call__(
    generator,
)

self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps

# 5. Add noise to image
noise_level = np.array([noise_level]).astype(np.int64)
noise = randn_tensor(
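Both pipeline hunks and the prepare_latents helper below call a randn_tensor(shape, dtype, generator) utility whose implementation is not shown in this view. A sketch of what such a helper plausibly looks like in these ONNX pipelines (torch generators for reproducible noise, per-sample draws when a list of generators is passed, NumPy output); the repository's actual version may differ:

from typing import List, Union
import numpy as np
import torch

def randn_tensor(
    shape,
    dtype: np.dtype,
    generator: Union[torch.Generator, List[torch.Generator], None] = None,
) -> np.ndarray:
    if isinstance(generator, list):
        # one generator per batch item, stacked along the batch dimension
        noise = torch.stack([torch.randn(shape[1:], generator=g) for g in generator])
    else:
        noise = torch.randn(shape, generator=generator)
    return noise.numpy().astype(dtype)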
modules/onnx_impl/pipelines/utils.py (13 changes: 3 additions & 10 deletions)
@@ -26,21 +26,14 @@ def prepare_latents(
    width: int,
    dtype: np.dtype,
    generator: Union[torch.Generator, List[torch.Generator]],
    latents: Union[np.ndarray, None]=None,
    num_channels_latents=4,
    vae_scale_factor=8,
    latents: Union[np.ndarray, None] = None,
    num_channels_latents = 4,
    vae_scale_factor = 8,
):
    shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, dtype, generator)
    elif latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * np.float64(init_noise_sigma)
