Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

iOS runtime #1549

Merged
merged 13 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions runtime/core/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@ set(decoder_srcs
ctc_endpoint.cc
)

if(NOT TORCH AND NOT ONNX AND NOT XPU)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU!!!")
if(NOT TORCH AND NOT ONNX AND NOT XPU AND NOT IOS)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU or IOS!!!")
endif()
if(TORCH)
list(APPEND decoder_srcs torch_asr_model.cc)
endif()
if(ONNX)
list(APPEND decoder_srcs onnx_asr_model.cc)
endif()
if(IOS)
list(APPEND decoder_srcs ios_asr_model.cc)
endif()

add_library(decoder STATIC ${decoder_srcs})
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
Expand Down
233 changes: 233 additions & 0 deletions runtime/core/decoder/ios_asr_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
pengzhendong marked this conversation as resolved.
Show resolved Hide resolved
// 2022 Dan Ma (1067837450@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Ma-Dan marked this conversation as resolved.
Show resolved Hide resolved
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include "decoder/ios_asr_model.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <stdexcept>

#include "torch/script.h"

namespace wenet {

void IosAsrModel::Read(const std::string& model_path) {
torch::DeviceType device = at::kCPU;
torch::jit::script::Module model = torch::jit::load(model_path, device);
model_ = std::make_shared<TorchModule>(std::move(model));
torch::NoGradGuard no_grad;
model_->eval();
torch::jit::IValue o1 = model_->run_method("subsampling_rate");
CHECK_EQ(o1.isInt(), true);
subsampling_rate_ = o1.toInt();
torch::jit::IValue o2 = model_->run_method("right_context");
CHECK_EQ(o2.isInt(), true);
right_context_ = o2.toInt();
torch::jit::IValue o3 = model_->run_method("sos_symbol");
CHECK_EQ(o3.isInt(), true);
sos_ = o3.toInt();
torch::jit::IValue o4 = model_->run_method("eos_symbol");
CHECK_EQ(o4.isInt(), true);
eos_ = o4.toInt();
torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder");
CHECK_EQ(o5.isBool(), true);
is_bidirectional_decoder_ = o5.toBool();

VLOG(1) << "Torch Model Info:";
VLOG(1) << "\tsubsampling_rate " << subsampling_rate_;
VLOG(1) << "\tright context " << right_context_;
VLOG(1) << "\tsos " << sos_;
VLOG(1) << "\teos " << eos_;
VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
}

// Copy constructor used by Copy() to create clones for parallel decoding.
// Copies the static model info and shares the TorchScript module pointer;
// per-utterance state (caches, encoder outputs) is deliberately NOT copied.
IosAsrModel::IosAsrModel(const IosAsrModel& other) {
  // 1. Init the model info
  right_context_ = other.right_context_;
  subsampling_rate_ = other.subsampling_rate_;
  sos_ = other.sos_;
  eos_ = other.eos_;
  is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
  chunk_size_ = other.chunk_size_;
  num_left_chunks_ = other.num_left_chunks_;
  offset_ = other.offset_;
  // 2. Model copy, just copy the model ptr since:
  //    PyTorch allows using multiple CPU threads during TorchScript model
  //    inference, please see https://pytorch.org/docs/stable/notes/cpu_
  //    threading_torchscript_inference.html
  model_ = other.model_;

  // NOTE(Binbin Zhang):
  // inner states for forward (att_cache_, cnn_cache_, encoder_outs_,
  // cached_feature_) are not copied here; Copy() calls Reset() on the clone.
}

std::shared_ptr<AsrModel> IosAsrModel::Copy() const {
auto asr_model = std::make_shared<IosAsrModel>(*this);
// Reset the inner states for new decoding
asr_model->Reset();
return asr_model;
}

// Clears all per-utterance decoding state so a new utterance can start.
void IosAsrModel::Reset() {
  offset_ = 0;
  // torch::zeros() already yields a prvalue, so assigning it directly moves;
  // the previous std::move wrapper was redundant (clang-tidy
  // performance-move-const-arg).
  att_cache_ = torch::zeros({0, 0, 0, 0});
  cnn_cache_ = torch::zeros({0, 0, 0, 0});
  encoder_outs_.clear();
  cached_feature_.clear();
}

void IosAsrModel::ForwardEncoderFunc(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* out_prob) {
// 1. Prepare libtorch required data, splice cached_feature_ and chunk_feats
// The first dimension is for batchsize, which is 1.
int num_frames = cached_feature_.size() + chunk_feats.size();
const int feature_dim = chunk_feats[0].size();
torch::Tensor feats =
torch::zeros({1, num_frames, feature_dim}, torch::kFloat);
for (size_t i = 0; i < cached_feature_.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(cached_feature_[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][i] = std::move(row);
}
for (size_t i = 0; i < chunk_feats.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(chunk_feats[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][cached_feature_.size() + i] = std::move(row);
}

// 2. Encoder chunk forward
int required_cache_size = chunk_size_ * num_left_chunks_;
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs = {feats, offset_, required_cache_size,
att_cache_, cnn_cache_};

// Refer interfaces in wenet/transformer/asr_model.py
auto outputs =
model_->get_method("forward_encoder_chunk")(inputs).toTuple()->elements();
CHECK_EQ(outputs.size(), 3);
torch::Tensor chunk_out = outputs[0].toTensor();
att_cache_ = outputs[1].toTensor();
cnn_cache_ = outputs[2].toTensor();
offset_ += chunk_out.size(1);

// The first dimension of returned value is for batchsize, which is 1
torch::Tensor ctc_log_probs =
model_->run_method("ctc_activation", chunk_out).toTensor()[0];
encoder_outs_.push_back(std::move(chunk_out));

// Copy to output
int num_outputs = ctc_log_probs.size(0);
int output_dim = ctc_log_probs.size(1);
out_prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*out_prob)[i].resize(output_dim);
memcpy((*out_prob)[i].data(), ctc_log_probs[i].data_ptr(),
sizeof(float) * output_dim);
}
}

// Sums the decoder log-probability of every token in `hyp`, then adds the
// log-probability of `eos` at the position after the last token.
float IosAsrModel::ComputeAttentionScore(const torch::Tensor& prob,
                                         const std::vector<int>& hyp,
                                         int eos) {
  auto log_probs = prob.accessor<float, 2>();
  float total = 0.0f;
  size_t pos = 0;
  for (int token : hyp) {
    total += log_probs[pos][token];
    ++pos;
  }
  total += log_probs[pos][eos];
  return total;
}

// Rescores n-best CTC hypotheses with the attention decoder.
// hyps:            token-id sequences to rescore (without sos/eos);
// reverse_weight:  weight of the right-to-left decoder score in [0, 1];
// rescoring_score: resized to hyps.size(), one combined score per hypothesis.
void IosAsrModel::AttentionRescoring(
    const std::vector<std::vector<int>>& hyps, float reverse_weight,
    std::vector<float>* rescoring_score) {
  CHECK(rescoring_score != nullptr);
  const int num_hyps = hyps.size();
  rescoring_score->resize(num_hyps, 0.0f);

  // Nothing to rescore without hypotheses or without any encoder output.
  // (Merged the two original early returns; loop indices are now int to
  // avoid signed/unsigned comparison with num_hyps.)
  if (num_hyps == 0 || encoder_outs_.empty()) {
    return;
  }

  torch::NoGradGuard no_grad;
  // Step 1: Prepare input for libtorch — pad all hypotheses into a
  // (num_hyps, max_hyps_len) tensor with a leading sos token each.
  torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong);
  int max_hyps_len = 0;
  for (int i = 0; i < num_hyps; ++i) {
    int length = hyps[i].size() + 1;  // +1 for the prepended sos token
    max_hyps_len = std::max(length, max_hyps_len);
    hyps_length[i] = static_cast<int64_t>(length);
  }
  torch::Tensor hyps_tensor =
      torch::zeros({num_hyps, max_hyps_len}, torch::kLong);
  for (int i = 0; i < num_hyps; ++i) {
    const std::vector<int>& hyp = hyps[i];
    hyps_tensor[i][0] = sos_;
    for (size_t j = 0; j < hyp.size(); ++j) {
      hyps_tensor[i][j + 1] = hyp[j];
    }
  }

  // Step 2: Forward attention decoder by hyps and corresponding encoder_outs_
  torch::Tensor encoder_out = torch::cat(encoder_outs_, 1);
  auto outputs = model_
                     ->run_method("forward_attention_decoder", hyps_tensor,
                                  hyps_length, encoder_out, reverse_weight)
                     .toTuple()
                     ->elements();

  auto probs = outputs[0].toTensor();    // left-to-right decoder log-probs
  auto r_probs = outputs[1].toTensor();  // right-to-left decoder log-probs

  CHECK_EQ(probs.size(0), num_hyps);
  CHECK_EQ(probs.size(1), max_hyps_len);

  // Step 3: Compute rescoring score
  for (int i = 0; i < num_hyps; ++i) {
    const std::vector<int>& hyp = hyps[i];
    // left-to-right decoder score
    float score = ComputeAttentionScore(probs[i], hyp, eos_);
    // Optional: Used for right to left score
    float r_score = 0.0f;
    if (is_bidirectional_decoder_ && reverse_weight > 0) {
      CHECK_EQ(r_probs.size(0), num_hyps);
      CHECK_EQ(r_probs.size(1), max_hyps_len);
      // Score the reversed hypothesis against the right-to-left decoder.
      std::vector<int> r_hyp(hyp.size());
      std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
      r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_);
    }

    // combined left-to-right and right-to-left score
    (*rescoring_score)[i] =
        score * (1 - reverse_weight) + r_score * reverse_weight;
  }
}

} // namespace wenet
63 changes: 63 additions & 0 deletions runtime/core/decoder/ios_asr_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
Ma-Dan marked this conversation as resolved.
Show resolved Hide resolved
// 2022 Dan Ma (1067837450@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef DECODER_IOS_ASR_MODEL_H_
#define DECODER_IOS_ASR_MODEL_H_

#include <memory>
#include <string>
#include <vector>

#include "torch/script.h"

#include "decoder/asr_model.h"
#include "utils/utils.h"

namespace wenet {

// Libtorch-backed ASR model for the iOS runtime. Wraps a TorchScript module
// and implements the streaming encoder + attention-rescoring interface of
// AsrModel. Instances for parallel decoding share one TorchScript module
// via Copy().
class IosAsrModel : public AsrModel {
 public:
  using TorchModule = torch::jit::script::Module;
  IosAsrModel() = default;
  // Copies model info and shares the module pointer; per-utterance decoding
  // state is not copied.
  IosAsrModel(const IosAsrModel& other);
  // Loads the TorchScript model from model_path and reads its exported
  // attributes (subsampling rate, right context, sos/eos, bidirectionality).
  void Read(const std::string& model_path);
  // Access to the underlying TorchScript module (shared with clones).
  std::shared_ptr<TorchModule> torch_model() const { return model_; }
  // Clears offset, attention/conv caches, encoder outputs, cached features.
  void Reset() override;
  // Rescores n-best hypotheses with the attention decoder; writes one
  // combined left-to-right / right-to-left score per hypothesis.
  void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                          float reverse_weight,
                          std::vector<float>* rescoring_score) override;
  // Returns a reset clone that shares the TorchScript module.
  std::shared_ptr<AsrModel> Copy() const override;

 protected:
  // Runs one encoder chunk + CTC activation; fills ctc_prob with per-frame
  // log-probabilities and appends the encoder output to encoder_outs_.
  void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
                          std::vector<std::vector<float>>* ctc_prob) override;

  // Sums decoder log-probs over hyp tokens plus eos at the final position.
  float ComputeAttentionScore(const torch::Tensor& prob,
                              const std::vector<int>& hyp, int eos);

 private:
  std::shared_ptr<TorchModule> model_ = nullptr;
  // Encoder outputs accumulated per utterance, consumed by rescoring.
  std::vector<torch::Tensor> encoder_outs_;
  // transformer/conformer attention cache
  torch::Tensor att_cache_ = torch::zeros({0, 0, 0, 0});
  // conformer-only conv_module cache
  torch::Tensor cnn_cache_ = torch::zeros({0, 0, 0, 0});
};

} // namespace wenet

#endif // DECODER_IOS_ASR_MODEL_H_
Empty file.
Loading