From cfbfd5fbb998bb2723319f42f55b7012c760df6b Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 16 Apr 2024 13:58:46 +0800 Subject: [PATCH 1/6] [e_branchformer] simplified ctl --- wenet/e_branchformer/encoder.py | 302 +++++++------------------------- 1 file changed, 62 insertions(+), 240 deletions(-) diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py index 2d4c6097e..6b779c091 100644 --- a/wenet/e_branchformer/encoder.py +++ b/wenet/e_branchformer/encoder.py @@ -17,23 +17,18 @@ """Encoder definition.""" import torch -import torch.nn as nn -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.branchformer.cgmlp import ConvolutionalGatingMLP -from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward -from wenet.utils.mask import make_pad_mask -from wenet.utils.mask import add_optional_chunk_mask +from wenet.transformer.encoder import ConformerEncoder from wenet.utils.class_utils import ( WENET_ATTENTION_CLASSES, - WENET_EMB_CLASSES, - WENET_SUBSAMPLE_CLASSES, - WENET_ACTIVATION_CLASSES, + WENET_MLP_CLASSES, ) -class EBranchformerEncoder(nn.Module): +class EBranchformerEncoder(ConformerEncoder): """E-Branchformer encoder module.""" def __init__( @@ -42,20 +37,18 @@ def __init__( output_size: int = 256, attention_heads: int = 4, linear_units: int = 2048, - attention_layer_type: str = "rel_selfattn", + selfattention_layer_type: str = "rel_selfattn", pos_enc_layer_type: str = "rel_pos", activation_type: str = "swish", cgmlp_linear_units: int = 2048, cgmlp_conv_kernel: int = 31, use_linear_after_conv: bool = False, gate_activation: str = "identity", - merge_method: str = "concat", num_blocks: int = 12, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.0, - input_layer: Optional[str] = "conv2d", - padding_idx: int = -1, + input_layer: str = "conv2d", stochastic_depth_rate: Union[float, List[float]] = 0.0, static_chunk_size: int = 0, use_dynamic_chunk: bool = False, @@ -65,18 +58,54 @@ def __init__( merge_conv_kernel: int = 3, use_ffn: bool = True, macaron_style: bool = True, + query_bias: bool = True, + key_bias: bool = True, + value_bias: bool = True, + conv_bias: bool = True, + gradient_checkpointing: bool = False, + use_sdpa: bool = False, + layer_norm_type: str = 'layer_norm', + norm_eps: float = 1e-5, + n_kv_head: Optional[int] = None, + head_dim: Optional[int] = None, + mlp_type: str = 'position_wise_feed_forward', + mlp_bias: bool = True, + n_expert: int = 8, + n_expert_activated: int = 2, ): - super().__init__() - activation = WENET_ACTIVATION_CLASSES[activation_type]() - self._output_size = output_size - - self.embed = WENET_SUBSAMPLE_CLASSES[input_layer]( - input_size, - output_size, - dropout_rate, - WENET_EMB_CLASSES[pos_enc_layer_type](output_size, - positional_dropout_rate), - ) + super().__init__(input_size, + output_size, + attention_heads, + linear_units, + num_blocks, + dropout_rate, + positional_dropout_rate, + attention_dropout_rate, + input_layer, + pos_enc_layer_type, + True, + static_chunk_size, + use_dynamic_chunk, + global_cmvn, + use_dynamic_left_chunk, + 1, + macaron_style, + selfattention_layer_type, + activation_type, + query_bias=query_bias, + key_bias=key_bias, + value_bias=value_bias, + conv_bias=conv_bias, + gradient_checkpointing=gradient_checkpointing, + use_sdpa=use_sdpa, + layer_norm_type=layer_norm_type, + norm_eps=norm_eps, + 
n_kv_head=n_kv_head, + head_dim=head_dim, + mlp_type=mlp_type, + mlp_bias=mlp_bias, + n_expert=n_expert, + n_expert_activated=n_expert_activated) encoder_selfattn_layer_args = ( attention_heads, @@ -90,12 +119,16 @@ def __init__( gate_activation, causal) # feed-forward module definition - positionwise_layer = PositionwiseFeedForward + mlp_class = WENET_MLP_CLASSES[mlp_type] + # feed-forward module definition positionwise_layer_args = ( output_size, linear_units, dropout_rate, - activation, + activation_type, + mlp_bias, + n_expert, + n_expert_activated, ) if isinstance(stochastic_depth_rate, float): @@ -108,12 +141,11 @@ def __init__( self.encoders = torch.nn.ModuleList([ EBranchformerEncoderLayer( output_size, - WENET_ATTENTION_CLASSES[attention_layer_type]( + WENET_ATTENTION_CLASSES[selfattention_layer_type]( *encoder_selfattn_layer_args), cgmlp_layer(*cgmlp_layer_args), - positionwise_layer( - *positionwise_layer_args) if use_ffn else None, - positionwise_layer(*positionwise_layer_args) + mlp_class(*positionwise_layer_args) if use_ffn else None, + mlp_class(*positionwise_layer_args) if use_ffn and macaron_style else None, dropout_rate, merge_conv_kernel=merge_conv_kernel, @@ -121,213 +153,3 @@ def __init__( stochastic_depth_rate=stochastic_depth_rate[lnum], ) for lnum in range(num_blocks) ]) - - self.after_norm = nn.LayerNorm(output_size) - self.static_chunk_size = static_chunk_size - self.global_cmvn = global_cmvn - self.use_dynamic_chunk = use_dynamic_chunk - self.use_dynamic_left_chunk = use_dynamic_left_chunk - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs: torch.Tensor, - ilens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - xs (torch.Tensor): Input tensor (B, T, D). - ilens (torch.Tensor): Input length (#batch). - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. 
- >=0: use num_decoding_left_chunks - <0: use all left chunks - - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - """ - - T = xs.size(1) - masks = ~make_pad_mask(ilens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks - - def forward_chunk( - self, - xs: torch.Tensor, - offset: int, - required_cache_size: int, - att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ Forward just one chunk - - Args: - xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), - where `time == (chunk_size - 1) * subsample_rate + \ - subsample.right_context + 1` - offset (int): current offset in encoder output time stamp - required_cache_size (int): cache size required for next chunk - compuation - >=0: actual cache size - <0: means all history cache is required - att_cache (torch.Tensor): cache tensor for KEY & VALUE in - transformer/conformer attention, with shape - (elayers, head, cache_t1, d_k * 2), where - `head * d_k == hidden-dim` and - `cache_t1 == chunk_size * num_decoding_left_chunks`. - cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, - (elayers, b=1, hidden-dim, cache_t2), where - `cache_t2 == cnn.lorder - 1` - - Returns: - torch.Tensor: output of current input xs, - with shape (b=1, chunk_size, hidden-dim). - torch.Tensor: new attention cache required for next chunk, with - dynamic shape (elayers, head, ?, d_k * 2) - depending on required_cache_size. - torch.Tensor: new conformer cnn cache required for next chunk, with - same shape as the original cnn_cache. 
- - """ - assert xs.size(0) == 1 - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) - # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) - elayers, cache_t1 = att_cache.size(0), att_cache.size(2) - chunk_size = xs.size(1) - attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding(offset=offset - cache_t1, - size=attention_key_size) - if required_cache_size < 0: - next_cache_start = 0 - elif required_cache_size == 0: - next_cache_start = attention_key_size - else: - next_cache_start = max(attention_key_size - required_cache_size, 0) - r_att_cache = [] - r_cnn_cache = [] - for i, layer in enumerate(self.encoders): - # NOTE(xcsong): Before layer.forward - # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), - # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) - xs, _, new_att_cache, new_cnn_cache = layer( - xs, - att_mask, - pos_emb, - att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache) - # NOTE(xcsong): After layer.forward - # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), - # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) - r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) - r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) - - xs = self.after_norm(xs) - - # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), - # ? may be larger than cache_t1, it depends on required_cache_size - r_att_cache = torch.cat(r_att_cache, dim=0) - # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) - r_cnn_cache = torch.cat(r_cnn_cache, dim=0) - - return (xs, r_att_cache, r_cnn_cache) - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - decoding_chunk_size: int, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Forward input chunk by chunk with chunk_size like a streaming - fashion - - Here we should pay special attention to computation cache in the - streaming style forward chunk by chunk. Three things should be taken - into account for computation in the current network: - 1. transformer/conformer encoder layers output cache - 2. convolution in conformer - 3. convolution in subsampling - - However, we don't implement subsampling cache for: - 1. We can control subsampling module to output the right result by - overlapping input instead of cache left context, even though it - wastes some computation, but subsampling only takes a very - small fraction of computation in the whole model. - 2. Typically, there are several covolution layers with subsampling - in subsampling module, it is tricky and complicated to do cache - with different convolution layers with different subsampling - rate. - 3. Currently, nn.Sequential is used to stack all the convolution - layers in subsampling, we need to rewrite it to make it work - with cache, which is not prefered. 
- Args: - xs (torch.Tensor): (1, max_len, dim) - chunk_size (int): decoding chunk size - """ - assert decoding_chunk_size > 0 - # The model is trained by static or dynamic chunk - assert self.static_chunk_size > 0 or self.use_dynamic_chunk - subsampling = self.embed.subsampling_rate - context = self.embed.right_context + 1 # Add current frame - stride = subsampling * decoding_chunk_size - decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.size(1) - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - outputs = [] - offset = 0 - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - - # Feed forward overlap input step by step - for cur in range(0, num_frames - context + 1, stride): - end = min(cur + decoding_window, num_frames) - chunk_xs = xs[:, cur:end, :] - (y, att_cache, - cnn_cache) = self.forward_chunk(chunk_xs, offset, - required_cache_size, att_cache, - cnn_cache) - outputs.append(y) - offset += y.size(1) - ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), - device=ys.device, - dtype=torch.bool) - return ys, masks From 5743ac02b4dcbd410d86126fbd1736a4557e1c1d Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 16 Apr 2024 14:18:55 +0800 Subject: [PATCH 2/6] try to fix ut --- examples/aishell/s0/conf/train_ebranchformer.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml index 5136f1ad9..0387bd554 100644 --- a/examples/aishell/s0/conf/train_ebranchformer.yaml +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -18,7 +18,7 @@ encoder_conf: activation_type: 'swish' causal: false pos_enc_layer_type: 'rel_pos' - attention_layer_type: 'rel_selfattn' + self_attention_layer_type: 'rel_selfattn' # decoder related decoder: transformer From 0bfd672a661229b275807d0e5b6bf56e067b3237 Mon Sep 17 00:00:00 2001 From: Mddct Date: Tue, 16 Apr 2024 14:23:41 +0800 Subject: [PATCH 3/6] try to fix ut --- examples/aishell/s0/conf/train_ebranchformer.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/aishell/s0/conf/train_ebranchformer.yaml b/examples/aishell/s0/conf/train_ebranchformer.yaml index 0387bd554..0e789dda3 100644 --- a/examples/aishell/s0/conf/train_ebranchformer.yaml +++ b/examples/aishell/s0/conf/train_ebranchformer.yaml @@ -18,7 +18,7 @@ encoder_conf: activation_type: 'swish' causal: false pos_enc_layer_type: 'rel_pos' - self_attention_layer_type: 'rel_selfattn' + selfattention_layer_type: 'rel_selfattn' # decoder related decoder: transformer From 08e8970af407f9177708bd804285ead836cf5933 Mon Sep 17 00:00:00 2001 From: Mddct Date: Wed, 17 Apr 2024 12:30:07 +0800 Subject: [PATCH 4/6] fix activation --- wenet/e_branchformer/encoder.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py index 6b779c091..47f72abcf 100644 --- a/wenet/e_branchformer/encoder.py +++ b/wenet/e_branchformer/encoder.py @@ -23,6 +23,7 @@ from wenet.branchformer.cgmlp import ConvolutionalGatingMLP from wenet.transformer.encoder import ConformerEncoder from wenet.utils.class_utils import ( + WENET_ACTIVATION_CLASSES, WENET_ATTENTION_CLASSES, WENET_MLP_CLASSES, ) @@ -111,6 +112,15 @@ def __init__( attention_heads, output_size, attention_dropout_rate, + attention_heads, + output_size, + attention_dropout_rate, + query_bias, 
+ key_bias, + value_bias, + use_sdpa, + n_kv_head, + head_dim, ) cgmlp_layer = ConvolutionalGatingMLP @@ -120,12 +130,12 @@ def __init__( # feed-forward module definition mlp_class = WENET_MLP_CLASSES[mlp_type] - # feed-forward module definition + activation = WENET_ACTIVATION_CLASSES[activation_type]() positionwise_layer_args = ( output_size, linear_units, dropout_rate, - activation_type, + activation, mlp_bias, n_expert, n_expert_activated, From e8cacab22943a1ff85abbb39944e5e037d2c0489 Mon Sep 17 00:00:00 2001 From: Mddct Date: Wed, 17 Apr 2024 12:34:27 +0800 Subject: [PATCH 5/6] fix att args --- wenet/e_branchformer/encoder.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py index 47f72abcf..36d279fb9 100644 --- a/wenet/e_branchformer/encoder.py +++ b/wenet/e_branchformer/encoder.py @@ -109,9 +109,6 @@ def __init__( n_expert_activated=n_expert_activated) encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, attention_heads, output_size, attention_dropout_rate, From 52e6f2a6d3d08b1d6b730b6483cfab27eae07269 Mon Sep 17 00:00:00 2001 From: Mddct Date: Wed, 17 Apr 2024 13:52:01 +0800 Subject: [PATCH 6/6] e-branformer works --- wenet/branchformer/encoder.py | 9 ++-- wenet/branchformer/encoder_layer.py | 2 - wenet/e_branchformer/encoder.py | 40 +++++++++------ wenet/e_branchformer/encoder_layer.py | 74 +++++++++++++++------------ 4 files changed, 72 insertions(+), 53 deletions(-) diff --git a/wenet/branchformer/encoder.py b/wenet/branchformer/encoder.py index 1c67a91d2..fab8bd4e3 100644 --- a/wenet/branchformer/encoder.py +++ b/wenet/branchformer/encoder.py @@ -126,9 +126,12 @@ def __init__( WENET_ATTENTION_CLASSES[selfattention_layer_type]( *encoder_selfattn_layer_args) if use_attn else None, cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None, - dropout_rate, merge_method, cgmlp_weight[lnum], - attn_branch_drop_rate[lnum], stochastic_depth_rate[lnum], - gradient_checkpointing) for lnum in range(num_blocks) + dropout_rate, + merge_method, + cgmlp_weight[lnum], + attn_branch_drop_rate[lnum], + stochastic_depth_rate[lnum], + ) for lnum in range(num_blocks) ]) @torch.jit.ignore(drop=True) diff --git a/wenet/branchformer/encoder_layer.py b/wenet/branchformer/encoder_layer.py index 0cbd2e6f9..9b011dd66 100644 --- a/wenet/branchformer/encoder_layer.py +++ b/wenet/branchformer/encoder_layer.py @@ -46,7 +46,6 @@ def __init__( cgmlp_weight: float = 0.5, attn_branch_drop_rate: float = 0.0, stochastic_depth_rate: float = 0.0, - gradient_checkpointing: bool = False, ): super().__init__() assert (attn is not None) or ( @@ -106,7 +105,6 @@ def __init__( raise ValueError(f"unknown merge method: {merge_method}") else: self.merge_proj = torch.nn.Identity() - self.gradient_checkpointing = gradient_checkpointing def _forward( self, diff --git a/wenet/e_branchformer/encoder.py b/wenet/e_branchformer/encoder.py index 36d279fb9..f6272eefe 100644 --- a/wenet/e_branchformer/encoder.py +++ b/wenet/e_branchformer/encoder.py @@ -18,6 +18,7 @@ import torch from typing import List, Optional, Union +from wenet.branchformer.encoder import LayerDropModuleList from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.branchformer.cgmlp import ConvolutionalGatingMLP @@ -145,18 +146,27 @@ def __init__( f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " f"should be equal to num_blocks ({num_blocks})") - self.encoders = torch.nn.ModuleList([ - EBranchformerEncoderLayer( - 
output_size, - WENET_ATTENTION_CLASSES[selfattention_layer_type]( - *encoder_selfattn_layer_args), - cgmlp_layer(*cgmlp_layer_args), - mlp_class(*positionwise_layer_args) if use_ffn else None, - mlp_class(*positionwise_layer_args) - if use_ffn and macaron_style else None, - dropout_rate, - merge_conv_kernel=merge_conv_kernel, - causal=causal, - stochastic_depth_rate=stochastic_depth_rate[lnum], - ) for lnum in range(num_blocks) - ]) + self.encoders = LayerDropModuleList( + p=stochastic_depth_rate, + modules=[ + EBranchformerEncoderLayer( + output_size, + WENET_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args), + cgmlp_layer(*cgmlp_layer_args), + mlp_class(*positionwise_layer_args) if use_ffn else None, + mlp_class(*positionwise_layer_args) + if use_ffn and macaron_style else None, + dropout_rate, + merge_conv_kernel=merge_conv_kernel, + causal=causal, + stochastic_depth_rate=stochastic_depth_rate[lnum], + ) for lnum in range(num_blocks) + ]) + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, xs: torch.Tensor, + chunk_masks: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor) -> torch.Tensor: + return self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) diff --git a/wenet/e_branchformer/encoder_layer.py b/wenet/e_branchformer/encoder_layer.py index dba232383..4b3eef2c1 100644 --- a/wenet/e_branchformer/encoder_layer.py +++ b/wenet/e_branchformer/encoder_layer.py @@ -88,7 +88,7 @@ def __init__( self.merge_proj = torch.nn.Linear(size + size, size) self.stochastic_depth_rate = stochastic_depth_rate - def forward( + def _forward( self, x: torch.Tensor, mask: torch.Tensor, @@ -96,39 +96,8 @@ def forward( mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + stoch_layer_coeff: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute encoded features. - - Args: - x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). - mask (torch.Tensor): Mask tensor for the input (#batch, time, time). - pos_emb (torch.Tensor): positional encoding, must not be None - for BranchformerEncoderLayer. - mask_pad (torch.Tensor): batch padding mask used for conv module. - (#batch, 1,time), (0, 0, 0) means fake mask. - att_cache (torch.Tensor): Cache tensor of the KEY & VALUE - (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. - cnn_cache (torch.Tensor): Convolution cache in cgmlp layer - (#batch=1, size, cache_t2) - - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time, time. - torch.Tensor: att_cache tensor, - (#batch=1, head, cache_t1 + time, d_k * 2). - torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). - """ - - stoch_layer_coeff = 1.0 - skip_layer = False - # with stochastic depth, residual connection `x + f(x)` becomes - # `x <- x + 1 / (1 - p) * f(x)` at training time. 
- if self.training and self.stochastic_depth_rate > 0: - skip_layer = torch.rand(1).item() < self.stochastic_depth_rate - stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) - - if skip_layer: - return x, mask, att_cache, cnn_cache if self.feed_forward_macaron is not None: residual = x @@ -173,3 +142,42 @@ def forward( x = self.norm_final(x) return x, mask, new_att_cache, new_cnn_cache + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (Union[Tuple, torch.Tensor]): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time, time). + pos_emb (torch.Tensor): positional encoding, must not be None + for BranchformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in cgmlp layer + (#batch=1, size, cache_t2) + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time. + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + stoch_layer_coeff = 1.0 + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. + if self.training: + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + return self._forward(x, mask, pos_emb, mask_pad, att_cache, cnn_cache, + stoch_layer_coeff)
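
Usage note (editorial, not part of the series): after these patches EBranchformerEncoder is built on top of ConformerEncoder, and the constructor/config key is selfattention_layer_type rather than attention_layer_type (patches 2/6 and 3/6 rename it in the aishell config). The sketch below is illustrative only — input_size=80 assumes an 80-dim fbank front end, and the other values are defaults — it is not taken from the patches themselves.

import torch
from wenet.e_branchformer.encoder import EBranchformerEncoder

# Illustrative values only; input_size=80 assumes 80-dim fbank features.
encoder = EBranchformerEncoder(
    input_size=80,
    output_size=256,
    attention_heads=4,
    linear_units=2048,
    selfattention_layer_type='rel_selfattn',  # renamed key (was attention_layer_type)
    pos_enc_layer_type='rel_pos',
    mlp_type='position_wise_feed_forward',    # resolved via WENET_MLP_CLASSES
)

xs = torch.randn(2, 100, 80)       # (batch, time, feat)
xs_lens = torch.tensor([100, 80])  # valid lengths per utterance
out, masks = encoder(xs, xs_lens)  # forward() is now inherited from ConformerEncoder
print(out.shape)                   # roughly (2, 100 // 4, 256) after conv2d subsampling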
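
Behaviour note on PATCH 6/6 (editorial): layer skipping is moved out of the layer and into the encoder-level LayerDropModuleList, while a surviving layer only rescales its residual branches by 1 / (1 - p) during training. The toy module below is a minimal stand-in sketching that convention, not the real EBranchformerEncoderLayer.

import torch

class ToyStochasticDepthLayer(torch.nn.Module):
    """Minimal stand-in that mirrors the residual scaling convention."""

    def __init__(self, size: int, stochastic_depth_rate: float = 0.1):
        super().__init__()
        self.ffn = torch.nn.Linear(size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With stochastic depth, x + f(x) becomes x + 1 / (1 - p) * f(x) at
        # training time; dropping the layer entirely is left to the
        # encoder-level LayerDropModuleList, not to the layer itself.
        coeff = 1.0
        if self.training:
            coeff = 1.0 / (1.0 - self.stochastic_depth_rate)
        return x + coeff * self.ffn(x)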