
Commit

resolved comments
moonbucks committed Jul 5, 2023
1 parent 38a461d commit d77473b
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions examples/selective2d/2d_train.py
@@ -151,7 +151,7 @@ def pp_and_tp(model, mesh, args):
     from pippy.microbatch import TensorChunkSpec, sum_reducer

     pp_dim, tp_dim = 0, 1
-    pp_rank, tp_rank = args.local_rank // args.tp_size, args.local_rank % args.tp_size
+    pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size
     pp_groups = mesh.get_dim_groups()[pp_dim]

     # TP
@@ -191,7 +191,7 @@ def pp_and_tp_fg(model, mesh, args, tp_attn_layers=None, tp_mlp_layers=None, cut
     from pippy.microbatch import TensorChunkSpec, sum_reducer

     pp_dim, tp_dim = 0, 1
-    pp_rank, tp_rank = args.local_rank // args.tp_size, args.local_rank % args.tp_size
+    pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size
     pp_groups = mesh.get_dim_groups()[pp_dim]

     # TP
@@ -215,7 +215,7 @@ def pp_and_tp_fg(model, mesh, args, tp_attn_layers=None, tp_mlp_layers=None, cut

 def pp_tp_train(stage, mesh, args):
     pp_dim, tp_dim = 0, 1
-    pp_rank, tp_rank = args.local_rank // args.tp_size, args.local_rank % args.tp_size
+    pp_rank, tp_rank = args.rank // args.tp_size, args.rank % args.tp_size
     pp_groups = mesh.get_dim_groups()[pp_dim]

     train_iters = 10 if args.debug else args.train_iters
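The changed lines above split a single rank into a pipeline-parallel coordinate and a tensor-parallel coordinate on the 2D mesh. Below is a minimal sketch of that decomposition, assuming the mesh is laid out as [pp_size, tp_size] with the tensor-parallel dimension varying fastest; the names are illustrative and not from the repository. Switching from args.local_rank to args.rank presumably keeps the mapping correct when training spans more than one node, since local_rank only indexes GPUs within a single host.

# Minimal sketch of the pp/tp rank decomposition, assuming a [pp_size, tp_size]
# mesh with tp as the fastest-varying dimension. Names are illustrative.
def decompose_rank(rank: int, tp_size: int) -> tuple[int, int]:
    pp_rank = rank // tp_size  # which pipeline stage the global rank falls in
    tp_rank = rank % tp_size   # position inside that stage's tensor-parallel group
    return pp_rank, tp_rank

if __name__ == "__main__":
    # Example: 2 pipeline stages x 4 tensor-parallel ranks = 8 global ranks.
    for rank in range(8):
        print(rank, decompose_rank(rank, tp_size=4))
    # Global ranks 0-3 map to pipeline stage 0, ranks 4-7 to stage 1.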
4 changes: 3 additions & 1 deletion examples/selective2d/model.py
@@ -42,8 +42,11 @@ class CausalSelfAttention(nn.Module):

     def __init__(self, mesh, config):
         super().__init__()
+        tp_size = mesh.mesh.size(0)
+        assert config.n_head % tp_size == 0
         assert config.n_embd % config.n_head == 0
         self.mesh = mesh
+        self.tp_size = tp_size
         # key, query, value projections for all heads, but in a batch
         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
         self.q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
@@ -60,7 +63,6 @@ def __init__(self, mesh, config):
         # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0

-        self.tp_size = self.mesh.mesh.size(0)

         if not self.flash:
             print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
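The assertions added above pin down the divisibility the sharding relies on: with tensor parallelism, the attention heads are split across the tp_size ranks of the mesh, so each rank must own a whole number of heads, and each head must get an integer slice of the embedding. A minimal sketch with assumed example values (n_head=12, n_embd=768, tp_size=4; these numbers are illustrative, not from the repository):

# Minimal sketch with assumed values; not taken from the repository.
n_head, n_embd, tp_size = 12, 768, 4

assert n_head % tp_size == 0  # each tensor-parallel rank gets a whole number of heads
assert n_embd % n_head == 0   # each head gets an integer embedding slice

heads_per_rank = n_head // tp_size  # -> 3 heads per TP rank
head_dim = n_embd // n_head         # -> 64 dimensions per head
print(heads_per_rank, head_dim)     # 3 64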
