From 5cf75ed5839953470fece7e651f26ffbc354d4b7 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Fri, 22 Mar 2024 09:03:42 +0800
Subject: [PATCH 1/2] [paraformer] fsdp fix submodule call

---
 .../convert_paraformer_to_wenet_config_and_ckpt.py | 4 +++-
 wenet/paraformer/layers.py                         | 6 +-----
 wenet/paraformer/paraformer.py                     | 4 +---
 3 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
index abd9ff16b..a51af2bc2 100644
--- a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
+++ b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
@@ -118,7 +118,7 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
     configs['encoder_conf']['pos_enc_layer_type'] = 'abs_pos_paraformer'
 
     configs['ctc_conf'] = {}
-    configs['ctc_conf']['ctc_blank_id'] = 0
+    configs['ctc_conf']['ctc_blank_id'] = 0.3
 
     configs['dataset_conf'] = {}
     configs['dataset_conf']['filter_conf'] = {}
@@ -186,6 +186,8 @@ def convert_to_wenet_state_dict(args, wenet_model_path):
             wenet_name = wenet_name.replace('predictor.', 'predictor.tp_')
         elif wenet_name.startswith('predictor.blstm'):
             wenet_name = wenet_name.replace('predictor.', 'predictor.tp_')
+        elif wenet_name == 'decoder.embed.0.weight':
+            wenet_name = 'embed.weight'
 
         wenet_state_dict[wenet_name] = checkpoint[name].float()
 
diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py
index 20ecf25dc..ff5b849dc 100644
--- a/wenet/paraformer/layers.py
+++ b/wenet/paraformer/layers.py
@@ -410,11 +410,7 @@ def __init__(
             normalize_before,
             src_attention,
             gradient_checkpointing=gradient_checkpointing)
-        del self.embed
-        self.embed = torch.nn.Sequential(
-            torch.nn.Embedding(vocab_size, encoder_output_size))
-
-        del self.decoders
+        del self.embed, self.decoders
         self.decoders = torch.nn.ModuleList([
             SanmDecoderLayer(
                 encoder_output_size,
diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py
index 612d19ed7..64b3587ec 100644
--- a/wenet/paraformer/paraformer.py
+++ b/wenet/paraformer/paraformer.py
@@ -141,9 +141,7 @@ def __init__(self,
         self.sampler = sampler
         self.sampling_ratio = sampling_ratio
         if sampler:
-            self.embed = self.decoder.embed
-        else:
-            del self.decoder.embed
+            self.embed = torch.nn.Embedding(vocab_size, encoder.output_size())
         # NOTE(Mddct): add eos in tail of labels for predictor
         # eg:
         #     gt: 你 好 we@@ net

From 5734825ded14fe6dfda752b4f5d802050b7a6f87 Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Fri, 22 Mar 2024 14:22:50 +0800
Subject: [PATCH 2/2] fix blank_id

---
 wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
index a51af2bc2..6dee02b08 100644
--- a/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
+++ b/wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
@@ -118,7 +118,7 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
     configs['encoder_conf']['pos_enc_layer_type'] = 'abs_pos_paraformer'
 
     configs['ctc_conf'] = {}
-    configs['ctc_conf']['ctc_blank_id'] = 0.3
+    configs['ctc_conf']['ctc_blank_id'] = 0
 
     configs['dataset_conf'] = {}
     configs['dataset_conf']['filter_conf'] = {}
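
Background on PATCH 1/2 (illustration only, not part of the patches): FSDP
flattens and shards the parameters of each wrapped module and only gathers
them while that module's own forward/backward runs. The pre-patch code
aliased the decoder's embedding from the top-level model
(self.embed = self.decoder.embed) and invoked it directly, a submodule call
that bypasses the decoder's FSDP wrapper. The patch instead deletes the
embedding inside SanmDecoder and gives Paraformer its own torch.nn.Embedding.
Below is a minimal sketch of the before/after module layout in plain PyTorch;
Decoder, ModelBefore, and ModelAfter are illustrative names, not the real
wenet classes, and no actual FSDP wrapping is performed:

import torch


class Decoder(torch.nn.Module):

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        # Pre-patch SanmDecoder kept its own embedding like this
        # (a one-element Sequential, hence the '.0.' in its weight key).
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(vocab_size, dim))


class ModelBefore(torch.nn.Module):

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.decoder = Decoder(vocab_size, dim)
        # Alias into a child module. If decoder is later wrapped by FSDP,
        # its flattened/sharded weights are only materialized inside
        # decoder.forward, so calling this alias from the outer model can
        # touch parameters that have not been gathered.
        self.embed = self.decoder.embed


class ModelAfter(torch.nn.Module):

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.decoder = Decoder(vocab_size, dim)
        # Post-patch layout: the outer model owns the embedding itself, so
        # no cross-module submodule call is needed (the real patch also
        # deletes the decoder's embed entirely).
        self.embed = torch.nn.Embedding(vocab_size, dim)


tokens = torch.randint(0, 10, (2, 5))
print(ModelBefore(10, 4).embed(tokens).shape)  # torch.Size([2, 5, 4])
print(ModelAfter(10, 4).embed(tokens).shape)   # torch.Size([2, 5, 4])

Moving the embedding also moves its weight in the state dict, from
decoder.embed.0.weight (index 0 of the old Sequential) to embed.weight,
which is exactly the key remap added to the checkpoint converter in this
patch. PATCH 2/2 then restores ctc_blank_id to the integer 0: the blank id
is a vocabulary index, and the 0.3 introduced in PATCH 1/2 was accidental.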