The issue here is that a lot of RAM is consumed during model creation, because the submodules are initialized with random parameters in full precision. I worked around this in the meantime by casting each submodule to bfloat16 as it is created.
For anyone hitting the same problem: you need to add .to(torch.bfloat16) in flux/model.py here:
self.double_blocks = nn.ModuleList(
    [
        DoubleStreamBlock(
            self.hidden_size,
            self.num_heads,
            mlp_ratio=params.mlp_ratio,
            qkv_bias=params.qkv_bias,
        ).to(torch.bfloat16)
        for _ in range(params.depth)
    ]
)
and here:
self.single_blocks = nn.ModuleList(
    [
        SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio).to(torch.bfloat16)
        for _ in range(params.depth_single_blocks)
    ]
)
When loading the Flux model, the entire model is created in full precision before the cast happens, which is what causes the RAM spike.
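A minimal, self-contained sketch of the same idea (the module and width used here are hypothetical stand-ins, not the actual Flux blocks): casting each block right after it is constructed means at most one block ever exists in float32 at a time, so peak RAM during creation stays close to the bfloat16 model size rather than the float32 size.

```python
import torch
import torch.nn as nn

hidden = 1024  # hypothetical width, not Flux's actual hidden_size

# Naive approach: build every block in float32, then cast the whole
# model at the end. Peak RAM during creation is roughly the full
# float32 model plus the bfloat16 copy being produced.
naive = nn.ModuleList(
    [nn.Linear(hidden, hidden) for _ in range(4)]
).to(torch.bfloat16)

# Per-submodule cast: each block is converted as soon as it is built,
# so only one block at a time is ever held in float32.
lean = nn.ModuleList(
    [nn.Linear(hidden, hidden).to(torch.bfloat16) for _ in range(4)]
)

# Both variants end up with identical dtypes; only the peak memory
# during construction differs.
assert all(p.dtype == torch.bfloat16 for p in naive.parameters())
assert all(p.dtype == torch.bfloat16 for p in lean.parameters())
```

The final models are equivalent; the per-submodule cast only changes how much transient float32 memory is alive at once during construction.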