Skip to content

Commit

Permalink
[bin/recognize.py] support numworkers and compute dtype (#2379)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Mar 4, 2024
1 parent 77d951b commit 8179fe1
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def get_args():
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--dtype',
type=str,
default='float32',
choices=['float16', 'flaot32', 'bfloat16'],
help='model\'s dtype')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--beam_size',
type=int,
Expand Down Expand Up @@ -206,7 +215,9 @@ def main():
test_conf,
partition=False)

test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
test_data_loader = DataLoader(test_dataset,
batch_size=None,
num_workers=args.num_workers)

# Init asr model from configs
args.jit = False
Expand All @@ -216,6 +227,12 @@ def main():
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
dtype = torch.float32
if args.dtype == 'float16':
dtype = torch.float16
elif args.dtype == 'bfloat16':
dtype = torch.bfloat16
logging.info("compute dtype is {}".format(dtype))

context_graph = None
if 'decoding-graph' in args.context_bias_mode:
Expand All @@ -237,38 +254,43 @@ def main():
file_name = os.path.join(dir_name, 'text')
files[mode] = open(file_name, 'w')
max_format_len = max([len(mode) for mode in args.modes])
with torch.no_grad():
for batch_idx, batch in enumerate(test_data_loader):
keys = batch["keys"]
feats = batch["feats"].to(device)
target = batch["target"].to(device)
feats_lengths = batch["feats_lengths"].to(device)
target_lengths = batch["target_lengths"].to(device)
infos = {"tasks": batch["tasks"], "langs": batch["langs"]}
results = model.decode(
args.modes,
feats,
feats_lengths,
args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
ctc_weight=args.ctc_weight,
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight,
context_graph=context_graph,
blank_id=blank_id,
blank_penalty=args.blank_penalty,
length_penalty=args.length_penalty,
infos=infos)
for i, key in enumerate(keys):
for mode, hyps in results.items():
tokens = hyps[i].tokens
line = '{} {}'.format(key, tokenizer.detokenize(tokens)[0])
logging.info('{} {}'.format(mode.ljust(max_format_len),
line))
files[mode].write(line + '\n')
for mode, f in files.items():
f.close()

with torch.cuda.amp.autocast(enabled=True,
dtype=dtype,
cache_enabled=False):
with torch.no_grad():
for batch_idx, batch in enumerate(test_data_loader):
keys = batch["keys"]
feats = batch["feats"].to(device)
target = batch["target"].to(device)
feats_lengths = batch["feats_lengths"].to(device)
target_lengths = batch["target_lengths"].to(device)
infos = {"tasks": batch["tasks"], "langs": batch["langs"]}
results = model.decode(
args.modes,
feats,
feats_lengths,
args.beam_size,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
ctc_weight=args.ctc_weight,
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight,
context_graph=context_graph,
blank_id=blank_id,
blank_penalty=args.blank_penalty,
length_penalty=args.length_penalty,
infos=infos)
for i, key in enumerate(keys):
for mode, hyps in results.items():
tokens = hyps[i].tokens
line = '{} {}'.format(key,
tokenizer.detokenize(tokens)[0])
logging.info('{} {}'.format(mode.ljust(max_format_len),
line))
files[mode].write(line + '\n')
for mode, f in files.items():
f.close()


if __name__ == '__main__':
Expand Down

0 comments on commit 8179fe1

Please sign in to comment.