[recipe] refine yaml for librispeech #2227

Merged
merged 2 commits into from
Dec 13, 2023

Conversation

xingchensong
Member

follow #2205

@xingchensong
Member Author

LibriSpeech model conversion script (the sos id changes from vocab size - 1 to 2)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-12-13] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>

import torch

old_state = torch.load('/mnt/d/BaiduSyncdisk/downloads/ckpt/20210610_u2pp_conformer_exp_librispeech/final.pt')
new_state = {}
change_list = ['decoder.left_decoder.output_layer.weight',
               'decoder.left_decoder.output_layer.bias',
               'decoder.left_decoder.embed.0.weight',
               'decoder.right_decoder.output_layer.weight',
               'decoder.right_decoder.output_layer.bias',
               'decoder.right_decoder.embed.0.weight',
               'ctc.ctc_lo.weight',
               'ctc.ctc_lo.bias']
for key in old_state.keys():
    if key in change_list:
        print("processing {}, {}".format(key, old_state[key].size()))
        tensor = old_state[key]
        new_tensor = torch.zeros_like(tensor)
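        # Keep ids 0 and 1, move the last entry (old sos/eos id) to index 2,
        # and shift the entries that used to sit at 2..N-2 up by one.
        if len(tensor.size()) == 2:  # weight
            new_tensor[:2, :] = tensor[:2, :]
            new_tensor[2, :] = tensor[-1, :]
            new_tensor[3:, :] = tensor[2:-1, :]
        elif len(tensor.size()) == 1:  # bias
            new_tensor[:2] = tensor[:2]
            new_tensor[2] = tensor[-1]
            new_tensor[3:] = tensor[2:-1]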
        else:
            raise NotImplementedError
        new_state[key] = new_tensor
    elif "concat_linear" in key:
        # concat_linear parameters are left out of the converted checkpoint
        continue
    else:
        new_state[key] = old_state[key]

torch.save(new_state, "/mnt/d/BaiduSyncdisk/downloads/ckpt/20210610_u2pp_conformer_exp_librispeech/final.sos2.pt")

old_units = '/mnt/d/BaiduSyncdisk/downloads/ckpt/20210610_u2pp_conformer_exp_librispeech/units.txt'
new_units = '/mnt/d/BaiduSyncdisk/downloads/ckpt/20210610_u2pp_conformer_exp_librispeech/units.sos2.txt'

with open(old_units, "r") as fin, open(new_units, "w") as fout:
    lines = fin.readlines()
    fout.write(lines[0])
    fout.write(lines[1])
    fout.write("<sos/eos> 2\n")
    for line in lines[2:-1]:
        line = line.strip().split()
        token, token_id = line[0], line[1]
        fout.write("{} {}\n".format(token, int(token_id) + 1))
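
To make the reordering in the script above easier to check, here is a minimal sketch of the same index mapping on a toy 6-entry vocabulary, in plain Python without torch. The helper names `move_last_to_index_2` and `remap_units`, the toy tokens, and the row labels are all made up for illustration; they are not part of the PR.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def move_last_to_index_2(rows: Sequence[T]) -> List[T]:
    """Keep rows 0 and 1, move the last row to index 2, shift the rest by one."""
    rows = list(rows)
    if len(rows) < 3:
        raise ValueError("need at least 3 rows")
    return rows[:2] + [rows[-1]] + rows[2:-1]


def remap_units(lines: Sequence[str]) -> List[str]:
    """Apply the same mapping to 'token id' lines of a units file."""
    out = [lines[0].strip(), lines[1].strip(), "<sos/eos> 2"]
    # The last line (old sos/eos) is skipped; every other id moves up by one.
    for line in lines[2:-1]:
        token, token_id = line.split()
        out.append("{} {}".format(token, int(token_id) + 1))
    return out


# Hypothetical toy vocabulary: sos/eos sits at the last id (vocab size - 1).
old_units = ["<blank> 0", "<unk> 1", "a 2", "b 3", "c 4", "<sos/eos> 5"]
# Stand-ins for the rows of an output-layer weight or an embedding table.
old_rows = ["row0", "row1", "row2", "row3", "row4", "row5"]

new_units = remap_units(old_units)
new_rows = move_last_to_index_2(old_rows)
print(new_units)
print(new_rows)
```

In this sketch the row that used to belong to `<sos/eos>` (`row5`) ends up at index 2, matching the new `<sos/eos> 2` entry, so each token keeps the parameters it had before the conversion.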

Decoding passes:
image

@robin1001 robin1001 merged commit b08ec35 into main Dec 13, 2023
4 checks passed
@robin1001 robin1001 deleted the xcsong-libri branch December 13, 2023 08:29
@xingchensong
Member Author

After the change, the results match:

image

image
