[e_branchformer] simplified e_branchformer (#2484)
* [e_branchformer] simplified ctl

* try to fix unit tests

* try to fix unit tests

* fix activation

* fix attention args

* e-branchformer works
Mddct committed Apr 17, 2024
1 parent 2b67e6c commit 4e9da62
Showing 5 changed files with 138 additions and 290 deletions (three of the five files are reproduced below).
2 changes: 1 addition & 1 deletion examples/aishell/s0/conf/train_ebranchformer.yaml
@@ -18,7 +18,7 @@ encoder_conf:
    activation_type: 'swish'
    causal: false
    pos_enc_layer_type: 'rel_pos'
-   attention_layer_type: 'rel_selfattn'
+   selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
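The rename matters because the encoder_conf mapping is consumed as keyword arguments when the encoder is built, and the diff below looks the attention class up via selfattention_layer_type. A minimal loading sketch, assuming the usual wenet config layout (the commented-out constructor call and its signature are assumptions, not taken from this commit):

import yaml

# The encoder_conf section must use the key name the encoder expects,
# i.e. `selfattention_layer_type` after this commit.
with open("examples/aishell/s0/conf/train_ebranchformer.yaml") as f:
    configs = yaml.safe_load(f)

encoder_conf = configs["encoder_conf"]
assert "selfattention_layer_type" in encoder_conf
# encoder = EBranchformerEncoder(input_size, **encoder_conf)  # name/signature assumed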
9 changes: 6 additions & 3 deletions wenet/branchformer/encoder.py
@@ -126,9 +126,12 @@ def __init__(
                WENET_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args) if use_attn else None,
                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
-                dropout_rate, merge_method, cgmlp_weight[lnum],
-                attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum],
-                gradient_checkpointing) for lnum in range(num_blocks)
+                dropout_rate,
+                merge_method,
+                cgmlp_weight[lnum],
+                attn_branch_drop_rate[lnum],
+                stochastic_depth_rate[lnum],
+            ) for lnum in range(num_blocks)
        ])

    @torch.jit.ignore(drop=True)
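In the construction above, cgmlp_weight, attn_branch_drop_rate, and stochastic_depth_rate are indexed per block (e.g. cgmlp_weight[lnum]), so each must provide one value per layer. A small sketch of the usual broadcast pattern for such hyperparameters; the helper name per_block is mine, not from this diff:

from typing import List, Union

def per_block(value: Union[float, List[float]], num_blocks: int) -> List[float]:
    """Broadcast a scalar to one value per block, or validate an explicit list."""
    if isinstance(value, (int, float)):
        return [float(value)] * num_blocks
    if len(value) != num_blocks:
        raise ValueError(f"expected {num_blocks} per-block values, got {len(value)}")
    return [float(v) for v in value]

# Example: a single cgmlp_weight of 0.5 becomes [0.5, 0.5, ...] before indexing with lnum.
cgmlp_weight = per_block(0.5, num_blocks=12)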
2 changes: 0 additions & 2 deletions wenet/branchformer/encoder_layer.py
@@ -46,7 +46,6 @@ def __init__(
        cgmlp_weight: float = 0.5,
        attn_branch_drop_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
-        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        assert (attn is not None) or (
@@ -106,7 +105,6 @@ def __init__(
            raise ValueError(f"unknown merge method: {merge_method}")
        else:
            self.merge_proj = torch.nn.Identity()
-        self.gradient_checkpointing = gradient_checkpointing

    def _forward(
        self,
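With the per-layer flag removed here, activation checkpointing of this kind is typically applied where the layers are iterated, i.e. in the encoder's forward pass, rather than stored on each layer. A self-contained sketch of that pattern with a toy module; none of these names are taken from the WeNet code:

import torch
import torch.utils.checkpoint as ckpt

class TinyEncoder(torch.nn.Module):
    """Toy module illustrating encoder-level activation checkpointing."""

    def __init__(self, num_blocks: int = 4, dim: int = 16):
        super().__init__()
        self.encoders = torch.nn.ModuleList(
            torch.nn.Linear(dim, dim) for _ in range(num_blocks))
        self.gradient_checkpointing = True  # kept on the encoder, not on each layer

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            if self.gradient_checkpointing and self.training:
                # Recompute this layer's activations during backward to save memory.
                xs = ckpt.checkpoint(layer, xs, use_reentrant=False)
            else:
                xs = layer(xs)
        return xs

# Usage: TinyEncoder()(torch.randn(2, 16)) gives the same result either way;
# only the memory/compute trade-off during training changes.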