Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修改FunASR实时识别框架,实时识别时2pass模式下支持框架层面返回句子级别的时间戳,单位毫秒 #2216

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions runtime/onnxruntime/include/audio.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class DLLAPI Audio {
int offset = 0;
int speech_start=-1, speech_end=0;
int speech_offline_start=-1;
int64_t start = 0;
int64_t end = 0;

int seg_sample = MODEL_SAMPLE_RATE/1000;
bool LoadPcmwavOnline(const char* buf, int n_file_len, int32_t* sampling_rate);
Expand Down
2 changes: 2 additions & 0 deletions runtime/onnxruntime/include/funasrruntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const char* FunASRGetStamp(FUNASR_RESULT result);
_FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result);
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
// Returns the `start` value stored in a recognition result; returns 0 when `result` is null.
// NOTE(review): presumably the VAD segment start offset in milliseconds — confirm against Audio::Split.
_FUNASRAPI const int64_t FunASRGetTpassStart(FUNASR_RESULT result);
// Returns the `end` value stored in a recognition result; returns 0 when `result` is null.
// NOTE(review): presumably the VAD segment end offset in milliseconds — confirm against Audio::Split.
_FUNASRAPI const int64_t FunASRGetTpassEnd(FUNASR_RESULT result);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
Expand Down
14 changes: 14 additions & 0 deletions runtime/onnxruntime/src/audio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,10 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
}
}
}else{

int sample_rate = 16000; // sample_rate 是音频的采样率 这里固定为16000 Hz
float segment_duration = (static_cast<float>(seg_sample) / sample_rate) * 1000; // 每个分段的持续时间(毫秒)

for(auto vad_segment: vad_segments){
int speech_start_i=-1, speech_end_i=-1;
if(vad_segment[0] != -1){
Expand Down Expand Up @@ -1325,6 +1329,12 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
frame = nullptr;
}

//设置开始时间和结束时间
float start_time = speech_start_i * segment_duration; // 开始时间(毫秒)
float end_time = speech_end_i * segment_duration; // 结束时间(毫秒)
// 转换为 int64_t 类型并赋值给类的成员变量
this->start = static_cast<int64_t>(start_time);
this->end = static_cast<int64_t>(end_time);
speech_start = -1;
speech_offline_start = -1;
// [70, -1]
Expand All @@ -1350,6 +1360,8 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
}
}

float start_time = speech_start_i * segment_duration; // 仅有开始时间
this->start = static_cast<int64_t>(start_time);
}else if(speech_end_i != -1){ // [-1,100]
if(speech_start == -1 || speech_offline_start == -1){
LOG(ERROR) <<"Vad start is null while vad end is available. Set vad start 0" ;
Expand Down Expand Up @@ -1399,6 +1411,8 @@ void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYP
frame = nullptr;
}
}
float end_time = speech_end_i * segment_duration; // 仅有结束时间
this->end = static_cast<int64_t>(end_time);
speech_start = -1;
speech_offline_start = -1;
}
Expand Down
2 changes: 2 additions & 0 deletions runtime/onnxruntime/src/commonfunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ typedef struct
std::string stamp_sents;
std::string tpass_msg;
float snippet_time;
int64_t start = 0;
int64_t end = 0;
}FUNASR_RECOG_RESULT;

typedef struct
Expand Down
19 changes: 19 additions & 0 deletions runtime/onnxruntime/src/funasrruntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,8 @@
p_result->snippet_time = audio->GetTimeLen();

audio->Split(vad_online_handle, chunk_len, input_finished, mode);
p_result->start = audio->start;
p_result->end = audio->end;

funasr::AudioFrame* frame = nullptr;
while(audio->FetchChunck(frame) > 0){
Expand Down Expand Up @@ -695,6 +697,23 @@
return p_result->tpass_msg.c_str();
}

// Reads the `start` field out of a recognition result handle.
// `result` is treated as a funasr::FUNASR_RECOG_RESULT*; a null handle yields 0.
_FUNASRAPI const int64_t FunASRGetTpassStart(FUNASR_RESULT result)
{
	const funasr::FUNASR_RECOG_RESULT* recog = (const funasr::FUNASR_RECOG_RESULT*)result;
	// Guard against a missing result before touching its fields.
	return recog == nullptr ? 0 : recog->start;
}
// Reads the `end` field out of a recognition result handle.
// `result` is treated as a funasr::FUNASR_RECOG_RESULT*; a null handle yields 0.
_FUNASRAPI const int64_t FunASRGetTpassEnd(FUNASR_RESULT result)
{
	const funasr::FUNASR_RECOG_RESULT* recog = (const funasr::FUNASR_RECOG_RESULT*)result;
	// Guard against a missing result before touching its fields.
	return recog == nullptr ? 0 : recog->end;
}

_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
{
funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
Expand Down
39 changes: 36 additions & 3 deletions runtime/websocket/bin/websocket-server-2pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
#include <thread>
#include <utility>
#include <vector>
#include <iostream>
#include <chrono>

extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;

// Returns the current std::chrono::system_clock time as a count of whole
// milliseconds measured from that clock's epoch.
int64_t getCurrentTimeMillis() {
  using std::chrono::duration_cast;
  using std::chrono::milliseconds;
  using std::chrono::system_clock;

  const system_clock::duration since_epoch = system_clock::now().time_since_epoch();
  return duration_cast<milliseconds>(since_epoch).count();
}

context_ptr WebSocketServer::on_tls_init(tls_mode mode,
websocketpp::connection_hdl hdl,
std::string& s_certfile,
Expand Down Expand Up @@ -57,7 +65,13 @@ context_ptr WebSocketServer::on_tls_init(tls_mode mode,
return ctx;
}

nlohmann::json handle_result(FUNASR_RESULT result) {
nlohmann::json handle_result(FUNASR_RESULT result, websocketpp::connection_hdl& hdl, std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,std::owner_less<websocketpp::connection_hdl>>& data_map) {
std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
auto it = data_map.find(hdl);
if (it != data_map.end()) {
data_msg = it->second;
}

websocketpp::lib::error_code ec;
nlohmann::json jsonresult;
jsonresult["text"] = "";
Expand All @@ -67,12 +81,26 @@ nlohmann::json handle_result(FUNASR_RESULT result) {
LOG(INFO) << "online_res :" << tmp_online_msg;
jsonresult["text"] = tmp_online_msg;
jsonresult["mode"] = "2pass-online";

// 如果是第一句话的第一个实时结果或新的句子开始
if (!data_msg->is_sentence_started) {
data_msg->start_time = FunASRGetTpassStart(result); // 记录句子的开始时间
data_msg->is_sentence_started = true;
}
}

data_msg->end_time = FunASRGetTpassEnd(result); // 记录句子的结束时间

std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
if (tmp_tpass_msg != "") {
LOG(INFO) << "offline results : " << tmp_tpass_msg;
jsonresult["text"] = tmp_tpass_msg;
jsonresult["mode"] = "2pass-offline";

// 句子结束,记录结束时间
jsonresult["start_time"] = data_msg->start_time;
jsonresult["end_time"] = data_msg->end_time;
data_msg->is_sentence_started = false; // 重置句子状态
}

std::string tmp_stamp_msg = FunASRGetStamp(result);
Expand All @@ -98,6 +126,7 @@ nlohmann::json handle_result(FUNASR_RESULT result) {
}
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(
std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,std::owner_less<websocketpp::connection_hdl>>& data_map,
std::vector<char>& buffer,
websocketpp::connection_hdl& hdl,
nlohmann::json& msg,
Expand Down Expand Up @@ -158,7 +187,7 @@ void WebSocketServer::do_decoder(
}
if (Result) {
websocketpp::lib::error_code ec;
nlohmann::json jsonresult = handle_result(Result);
nlohmann::json jsonresult = handle_result(Result, hdl, data_map);
jsonresult["wav_name"] = wav_name;
jsonresult["is_final"] = false;
if (jsonresult["text"] != "") {
Expand Down Expand Up @@ -200,7 +229,7 @@ void WebSocketServer::do_decoder(
}
if (Result) {
websocketpp::lib::error_code ec;
nlohmann::json jsonresult = handle_result(Result);
nlohmann::json jsonresult = handle_result(Result, hdl, data_map);
jsonresult["wav_name"] = wav_name;
jsonresult["is_final"] = true;
if (is_ssl) {
Expand Down Expand Up @@ -263,6 +292,8 @@ void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
std::make_shared<std::vector<std::vector<std::string>>>(2);
data_msg->strand_ = std::make_shared<asio::io_context::strand>(io_decoder_);

data_msg->is_sentence_started = false;

data_map.emplace(hdl, data_msg);
}catch (std::exception const& e) {
std::cerr << "Error: " << e.what() << std::endl;
Expand Down Expand Up @@ -501,6 +532,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
msg_data->strand_->post(
std::bind(&WebSocketServer::do_decoder, this,
data_map,
std::move(*(sample_data_p.get())), std::move(hdl),
std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
std::move(hotwords_embedding_),
Expand Down Expand Up @@ -550,6 +582,7 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
msg_data->strand_->post(
std::bind(&WebSocketServer::do_decoder, this,
data_map,
std::move(subvector), std::move(hdl),
std::ref(msg_data->msg),
std::ref(*(punc_cache_p.get())),
Expand Down
10 changes: 8 additions & 2 deletions runtime/websocket/bin/websocket-server-2pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ typedef struct {
std::string online_res = "";
std::string tpass_res = "";
std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order
FUNASR_DEC_HANDLE decoder_handle=nullptr;
FUNASR_DEC_HANDLE decoder_handle=nullptr;

bool is_sentence_started = false;
int64_t start_time = 0;
int64_t end_time = 0;
} FUNASR_MESSAGE;

// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
Expand Down Expand Up @@ -114,7 +118,9 @@ class WebSocketServer {
server_->clear_access_channels(websocketpp::log::alevel::all);
}
}
void do_decoder(std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
void do_decoder(std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,std::owner_less<websocketpp::connection_hdl>>& data_map,
std::vector<char>& buffer,
websocketpp::connection_hdl& hdl,
nlohmann::json& msg,
std::vector<std::vector<std::string>>& punc_cache,
std::vector<std::vector<float>> &hotwords_embedding,
Expand Down