
Flux model is being fully created before bfloat16 cast #67

Open
PRPA1984 opened this issue Sep 13, 2024 · 1 comment

When loading the Flux model, the entire model is created before the cast:

model = Flux(configs[name].params).to(torch.bfloat16)

The problem is that a lot of RAM is consumed during model creation, because all submodules are first initialized with random full-precision (fp32) parameters and only then cast to bfloat16. As a workaround, I fixed this by casting every submodule during its creation.
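For reference, a minimal sketch of one way to avoid the post-hoc cast entirely (an illustration, not necessarily the reporter's fix; it assumes Flux and configs are importable from flux.model and flux.util as in this repo, and name is the model key from the snippet above):

import torch
from flux.model import Flux
from flux.util import configs  # assumption: configs lives here, as elsewhere in this repo

# Make bfloat16 the default dtype so every parameter is allocated in bf16
# from the start, instead of being created in fp32 and cast afterwards.
# This roughly halves the peak RAM of model construction.
prev_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = Flux(configs[name].params)
finally:
    torch.set_default_dtype(prev_dtype)  # restore the default for the rest of the program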


denred0 commented Sep 15, 2024


For everyone with the same problem: you need to add .to(torch.bfloat16) in flux/model.py here:

self.double_blocks = nn.ModuleList(
    [
        DoubleStreamBlock(
            self.hidden_size,
            self.num_heads,
            mlp_ratio=params.mlp_ratio,
            qkv_bias=params.qkv_bias,
        ).to(torch.bfloat16)  # cast each block right after it is created
        for _ in range(params.depth)
    ]
)

and here:

self.single_blocks = nn.ModuleList(
    [
        SingleStreamBlock(
            self.hidden_size,
            self.num_heads,
            mlp_ratio=params.mlp_ratio,
        ).to(torch.bfloat16)  # same per-block cast for the single-stream blocks
        for _ in range(params.depth_single_blocks)
    ]
)
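With this change, each block's random fp32 weights are freed as soon as the block is cast, so peak memory during construction is roughly one block in fp32 plus the already-cast bf16 blocks, rather than the whole model in fp32.

For completeness, a sketch of another way to sidestep the random initialization entirely when the weights are loaded from a checkpoint afterwards (ckpt_path is a hypothetical placeholder; assign=True needs PyTorch >= 2.1):

import torch
from safetensors.torch import load_file
from flux.model import Flux
from flux.util import configs  # assumption, as above

# Build the module on the "meta" device: no parameter memory is
# allocated and no random initialization work is done.
with torch.device("meta"):
    model = Flux(configs[name].params)

# Materialize the model directly from the checkpoint tensors (already
# bf16 on disk). assign=True swaps the meta tensors for the loaded ones
# instead of copying into preallocated storage.
state_dict = load_file(ckpt_path, device="cpu")
model.load_state_dict(state_dict, assign=True, strict=False)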
