From 8f496586431eab72f757dfd24eddfda18553e427 Mon Sep 17 00:00:00 2001 From: user01 Date: Thu, 7 Dec 2023 10:49:55 +0800 Subject: [PATCH] [lint] auto format all by pre-commit, including c++, python --- .clang-format | 93 +++++++++++++++ .../v2/local/choose_utts_to_combine.py | 18 +-- examples/sre/v2/local/filter_utt_accd_dur.py | 1 - examples/sre/v2/local/generate_sre_aug.py | 1 - examples/sre/v2/local/make_system_sad.py | 28 +++-- examples/voxconverse/v1/diar/clusterer.py | 60 ++++++---- examples/voxconverse/v1/diar/clusterer_gpu.py | 22 +++- examples/voxconverse/v1/diar/make_rttm.py | 9 +- .../voxconverse/v1/sad/make_oracle_sad.py | 13 +- .../voxconverse/v1/sad/make_system_sad.py | 21 ++-- examples/voxconverse/v2/diar/extract_emb.py | 9 +- examples/voxconverse/v2/diar/make_fbank.py | 9 +- examples/voxconverse/v2/diar/make_rttm.py | 9 +- .../voxconverse/v2/diar/spectral_clusterer.py | 12 +- runtime/core/bin/asv_main.cc | 28 ++--- runtime/core/bin/extract_emb_main.cc | 16 +-- runtime/core/frontend/feature_pipeline.cc | 1 - runtime/core/frontend/wav.h | 4 +- runtime/core/speaker/bpu_speaker_model.cc | 53 +++++---- runtime/core/speaker/bpu_speaker_model.h | 13 +- runtime/core/speaker/onnx_speaker_model.cc | 29 +++-- runtime/core/speaker/speaker_engine.cc | 67 +++++------ runtime/core/speaker/speaker_engine.h | 19 ++- runtime/core/speaker/speaker_model.h | 3 +- runtime/core/utils/utils.cc | 10 +- runtime/core/utils/utils.h | 2 +- .../server/diarization_gpu/client/client.py | 23 ++-- .../model_repo/clusterer/1/model.py | 8 +- .../diarization_gpu/model_repo/run/1/model.py | 111 ++++++++++-------- runtime/server/x86_gpu/client/client.py | 27 +++-- .../server/x86_gpu/client/generate_input.py | 27 +++-- .../model_repo/feature_extractor/1/model.py | 8 +- setup.py | 8 +- tools/make_raw_list.py | 5 +- tools/make_shard_list.py | 8 +- tools/onnx2horizonbin.py | 75 +++++++----- tools/vector_mean.py | 9 +- tools/wav2dur.py | 1 + wespeaker/bin/adapt_plda.py | 19 +-- wespeaker/bin/average_model.py | 13 +- wespeaker/bin/eval_plda.py | 14 ++- wespeaker/bin/export_onnx.py | 31 +++-- wespeaker/bin/export_onnx_bpu.py | 28 +++-- wespeaker/bin/extract_deprecated.py | 3 +- wespeaker/bin/infer_onnx.py | 8 +- wespeaker/bin/score.py | 10 +- wespeaker/bin/score_norm.py | 10 +- wespeaker/bin/train.py | 6 +- wespeaker/bin/train_deprecated.py | 17 +-- wespeaker/bin/train_plda.py | 2 +- wespeaker/cli/speaker.py | 6 +- wespeaker/dataset/dataset.py | 40 ++++--- wespeaker/dataset/dataset_deprecated.py | 20 ++-- wespeaker/dataset/processor.py | 17 ++- wespeaker/models/campplus.py | 27 +++-- wespeaker/models/convert_repvgg.py | 4 +- wespeaker/models/ecapa_tdnn.py | 30 +++-- wespeaker/models/pooling_layers.py | 6 +- wespeaker/models/repvgg.py | 14 +-- wespeaker/models/resnet.py | 21 ++-- wespeaker/models/speaker_model.py | 1 + wespeaker/models/tdnn.py | 3 +- .../ssl/bin/average_contrastive_model.py | 12 +- wespeaker/ssl/bin/average_dino_model.py | 12 +- wespeaker/ssl/dataset/dataset.py | 12 +- wespeaker/ssl/utils/contrastive_executor.py | 3 +- wespeaker/utils/dataset_utils_deprecated.py | 3 +- wespeaker/utils/executor.py | 13 +- wespeaker/utils/file_utils.py | 1 + wespeaker/utils/plda/kaldi_utils.py | 6 +- wespeaker/utils/plda/two_cov_plda.py | 49 ++++---- 71 files changed, 779 insertions(+), 542 deletions(-) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..91dcbc07 --- /dev/null +++ b/.clang-format @@ -0,0 +1,93 @@ +--- +Language: Cpp +# BasedOnStyle: Google +AccessModifierOffset: -1 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Auto +TabWidth: 8 +UseTab: Never +... diff --git a/examples/cnceleb/v2/local/choose_utts_to_combine.py b/examples/cnceleb/v2/local/choose_utts_to_combine.py index fd4a7a4c..f718abb1 100755 --- a/examples/cnceleb/v2/local/choose_utts_to_combine.py +++ b/examples/cnceleb/v2/local/choose_utts_to_combine.py @@ -54,12 +54,11 @@ "because this script tries to merge utterances from the " "same speaker as much as possible, and also needs to produce" "an output utt2spk map.") -parser.add_argument( - "utt2dur_in", - type=str, - metavar="", - help="Filename of [input] utterance-to-duration map, with lines like 'utt1 1.23'." -) +parser.add_argument("utt2dur_in", + type=str, + metavar="", + help="Filename of [input] utterance-to-duration map, " + "with lines like 'utt1 1.23'.") parser.add_argument( "utt2utts_out", type=str, @@ -70,9 +69,10 @@ "utt2spk_out", type=str, metavar="", - help="Filename of [output] utt2spk map, which maps new utterances to original " - "speakers. If utterances were combined across speakers, we map the new " - "utterance to the speaker that contributed the most to them.") + help="Filename of [output] utt2spk map, which maps new utterances to " + "original speakers. If utterances were combined across speakers, " + "we map the new utterance to the speaker that contributed the most to them." +) parser.add_argument( "utt2dur_out", type=str, diff --git a/examples/sre/v2/local/filter_utt_accd_dur.py b/examples/sre/v2/local/filter_utt_accd_dur.py index 2541e396..fff80d97 100644 --- a/examples/sre/v2/local/filter_utt_accd_dur.py +++ b/examples/sre/v2/local/filter_utt_accd_dur.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import fire diff --git a/examples/sre/v2/local/generate_sre_aug.py b/examples/sre/v2/local/generate_sre_aug.py index 29bda825..e5fdfb14 100644 --- a/examples/sre/v2/local/generate_sre_aug.py +++ b/examples/sre/v2/local/generate_sre_aug.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os import fire diff --git a/examples/sre/v2/local/make_system_sad.py b/examples/sre/v2/local/make_system_sad.py index 84d7d7df..1c629ea4 100644 --- a/examples/sre/v2/local/make_system_sad.py +++ b/examples/sre/v2/local/make_system_sad.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os + os.environ["OMP_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" @@ -35,20 +35,21 @@ def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--repo-path', required=True, + parser.add_argument('--repo-path', + required=True, help='VAD model repo path') parser.add_argument('--scp', required=True, help='wav scp') - parser.add_argument('--min-duration', required=True, - type=float, help='min duration') + parser.add_argument('--min-duration', + required=True, + type=float, + help='min duration') args = parser.parse_args() return args @functools.lru_cache(maxsize=1) -def load_wav( - wav_rxfilename, -): +def load_wav(wav_rxfilename, ): """ This function reads audio file and return data in pytorch tensor. "lru_cache" holds recently loaded audio so that can be called many times on the same audio file. @@ -57,7 +58,8 @@ def load_wav( """ if wav_rxfilename.endswith('|'): # input piped command - p = subprocess.Popen(wav_rxfilename[:-1], shell=True, + p = subprocess.Popen(wav_rxfilename[:-1], + shell=True, stdout=subprocess.PIPE) data, samplerate = torchaudio.load(io.BytesIO(p.stdout.read())) elif wav_rxfilename == '-': @@ -82,8 +84,11 @@ def read_scp(scp): return utt_wav_pair -def silero_vad(utt_wav_pair, repo_path, min_duration, - sampling_rate=8000, threshold=0.25): +def silero_vad(utt_wav_pair, + repo_path, + min_duration, + sampling_rate=8000, + threshold=0.25): def module_from_file(module_name, file_path): spec = importlib.util.spec_from_file_location(module_name, file_path) @@ -102,8 +107,7 @@ def module_from_file(module_name, file_path): wav, sr = load_wav(wav) assert sr == sampling_rate speech_timestamps = utils_vad.get_speech_timestamps( - wav, model, sampling_rate=sampling_rate, - threshold=threshold) + wav, model, sampling_rate=sampling_rate, threshold=threshold) vad_result = "" for item in speech_timestamps: diff --git a/examples/voxconverse/v1/diar/clusterer.py b/examples/voxconverse/v1/diar/clusterer.py index 7ed033c4..0ecfa573 100644 --- a/examples/voxconverse/v1/diar/clusterer.py +++ b/examples/voxconverse/v1/diar/clusterer.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os + os.environ["OMP_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" @@ -42,20 +42,28 @@ def get_args(): parser.add_argument('--scp', required=True, help='wav scp') parser.add_argument('--segments', required=True, help='vad segments') parser.add_argument('--output', required=True, help='output label file') - parser.add_argument('--source', required=True, - help='onnx model') - parser.add_argument('--device', default='cuda', + parser.add_argument('--source', required=True, help='onnx model') + parser.add_argument('--device', + default='cuda', help='inference device type: cpu or cuda') - parser.add_argument('--batch-size', type=int, default=96, + parser.add_argument('--batch-size', + type=int, + default=96, help='batch size for embedding extraction') args = parser.parse_args() return args -def compute_embeddings(scp, segments, source, device, - batch_size, sampling_rate=16000, - window_secs=1.50, period_secs=0.75, frame_shift=10): +def compute_embeddings(scp, + segments, + source, + device, + batch_size, + sampling_rate=16000, + window_secs=1.50, + period_secs=0.75, + frame_shift=10): def read_segments(segments): utt_to_segments = OrderedDict() @@ -97,13 +105,12 @@ def subsegment(wav, segments, window_fs, period_fs): for (seg, begin, end) in segments: seg_begin = int(begin * sampling_rate) seg_end = int(end * sampling_rate) - seg_signal = signal[seg_begin: seg_end + 1, :] + seg_signal = signal[seg_begin:seg_end + 1, :] seg_length = seg_end - seg_begin if seg_length <= window_fs: subseg = seg + "-{:08d}-{:08d}".format( - 0, - int(seg_length / sampling_rate * 1000 // frame_shift)) + 0, int(seg_length / sampling_rate * 1000 // frame_shift)) subseg_signal = repeat_to_fill(seg_signal, window_fs) subsegs.append(subseg) @@ -116,15 +123,19 @@ def subsegment(wav, segments, window_fs, period_fs): int(subseg_begin / sampling_rate * 1000 / frame_shift), int(subseg_end / sampling_rate * 1000 / frame_shift)) subseg_signal = repeat_to_fill( - seg_signal[subseg_begin: subseg_end + 1, :], window_fs) + seg_signal[subseg_begin:subseg_end + 1, :], window_fs) subsegs.append(subseg) subseg_signals.append(subseg_signal) return subsegs, subseg_signals - def compute_fbank(wavs, num_mel_bins=80, frame_length=25, - frame_shift=10, dither=0.0, sample_frequency=16000): + def compute_fbank(wavs, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + sample_frequency=16000): feats = [] for wav in wavs: @@ -155,14 +166,15 @@ def init_session(source, device): opts = ort.SessionOptions() opts.inter_op_num_threads = 1 opts.intra_op_num_threads = 1 - session = ort.InferenceSession(source, sess_options=opts, + session = ort.InferenceSession(source, + sess_options=opts, providers=providers) return session def extract_embeddings(wavs, batch_size): embeddings = [] for i in range(0, wavs.size(0), batch_size): - batch_wavs = wavs[i: i + batch_size, :] + batch_wavs = wavs[i:i + batch_size, :] batch_feats = compute_fbank(batch_wavs) batch_embs = session.run(input_feed={'feats': batch_feats.numpy()}, output_names=['embs'])[0].squeeze() @@ -193,8 +205,8 @@ def extract_embeddings(wavs, batch_size): segments = utt_to_segments[utt] # Extract wav data using sliding window with overlap for each utterance - utt_subsegs, utt_subseg_signals = subsegment(wav, segments, - window_fs, period_fs) + utt_subsegs, utt_subseg_signals = subsegment(wav, segments, window_fs, + period_fs) # Convert a list of Tensor to a Tensor utt_subseg_signals = torch.stack(utt_subseg_signals).squeeze(-1) @@ -256,8 +268,8 @@ def kmeans(data): # Compute Laplacian laplacian_matrix = laplacian(pruned_similarity_matrix) # Compute spectral embeddings - spectral_embeddings = spectral(laplacian_matrix, num_spks, - min_num_spks, max_num_spks) + spectral_embeddings = spectral(laplacian_matrix, num_spks, min_num_spks, + max_num_spks) # Assign class labels labels = kmeans(spectral_embeddings) @@ -268,8 +280,7 @@ def main(): args = get_args() print('Segmenting and extracting speaker embeddings') - subsegs_list, embeddings_list = compute_embeddings(args.scp, - args.segments, + subsegs_list, embeddings_list = compute_embeddings(args.scp, args.segments, args.source, args.device, args.batch_size) @@ -279,7 +290,10 @@ def main(): with cf.ProcessPoolExecutor() as executor, open(args.output, 'w') as f: for (subsegs, labels) in zip(subsegs_list, executor.map(cluster, embeddings_list)): - [print(subseg, label, file=f) for (subseg, label) in zip(subsegs, labels)] + [ + print(subseg, label, file=f) + for (subseg, label) in zip(subsegs, labels) + ] if __name__ == '__main__': diff --git a/examples/voxconverse/v1/diar/clusterer_gpu.py b/examples/voxconverse/v1/diar/clusterer_gpu.py index d85a0022..d6b90ccb 100644 --- a/examples/voxconverse/v1/diar/clusterer_gpu.py +++ b/examples/voxconverse/v1/diar/clusterer_gpu.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os + os.environ["OMP_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" @@ -26,7 +27,12 @@ import scipy import torch -def cluster_gpu(embeddings, p=.01, num_spks=None, min_num_spks=1, max_num_spks=20): + +def cluster_gpu(embeddings, + p=.01, + num_spks=None, + min_num_spks=1, + max_num_spks=20): # Define utility functions def cosine_similarity(M): M = M / cp.linalg.norm(M, axis=1, keepdims=True) @@ -78,13 +84,14 @@ def kmeans(data): # Compute Laplacian laplacian_matrix = laplacian(pruned_similarity_matrix) # Compute spectral embeddings - spectral_embeddings = spectral(laplacian_matrix, num_spks, - min_num_spks, max_num_spks) + spectral_embeddings = spectral(laplacian_matrix, num_spks, min_num_spks, + max_num_spks) # Assign class labels labels = kmeans(spectral_embeddings) return labels + def test_time(): a = np.random.rand(1000, 256) @@ -106,11 +113,11 @@ def with_cuda(x, count): elapsed_time = timer() - start print("CPU Time: {}".format(elapsed_time)) + def main(): args = get_args() print('Segmenting and extracting speaker embeddings') - subsegs_list, embeddings_list = compute_embeddings(args.scp, - args.segments, + subsegs_list, embeddings_list = compute_embeddings(args.scp, args.segments, args.source, args.device, args.batch_size) @@ -123,7 +130,10 @@ def main(): for i in embeddings_list: labels_list.append(cluster_gpu(cp.asarray(i))) for (subsegs, labels) in zip(subsegs_list, labels_list): - [print(subseg, label, file=f) for (subseg, label) in zip(subsegs, labels)] + [ + print(subseg, label, file=f) + for (subseg, label) in zip(subsegs, labels) + ] if __name__ == '__main__': diff --git a/examples/voxconverse/v1/diar/make_rttm.py b/examples/voxconverse/v1/diar/make_rttm.py index cbba941b..3b5e1733 100644 --- a/examples/voxconverse/v1/diar/make_rttm.py +++ b/examples/voxconverse/v1/diar/make_rttm.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse from collections import OrderedDict def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--labels', required=True, + parser.add_argument('--labels', + required=True, help='class labels generated by clusterer') - parser.add_argument('--channel', type=int, default=1, + parser.add_argument('--channel', + type=int, + default=1, help='channel number in RTTM format') args = parser.parse_args() return args - def read_labels(labels, frame_shift=10): utt_to_subseg_labels = OrderedDict() for line in open(labels, 'r'): diff --git a/examples/voxconverse/v1/sad/make_oracle_sad.py b/examples/voxconverse/v1/sad/make_oracle_sad.py index b2ab4888..2c8fe1dd 100644 --- a/examples/voxconverse/v1/sad/make_oracle_sad.py +++ b/examples/voxconverse/v1/sad/make_oracle_sad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse from collections import OrderedDict @@ -20,8 +19,10 @@ def get_args(): parser = argparse.ArgumentParser(description='') parser.add_argument('--rttm', required=True, help='reference rttm') - parser.add_argument('--min-duration', required=True, - type=float, help='min duration') + parser.add_argument('--min-duration', + required=True, + type=float, + help='min duration') args = parser.parse_args() return args @@ -67,6 +68,7 @@ def merge_segments(utt_to_segments, min_duration): return utt_to_merged_segments + def main(): args = get_args() @@ -76,8 +78,9 @@ def main(): segments_line_spec = "{}-{:08d}-{:08d} {} {:.3f} {:.3f}" for utt, segments in utt_to_merged_segments.items(): for (begin, end) in segments: - print(segments_line_spec.format( - utt, int(begin * 1000), int(end * 1000), utt, begin, end)) + print( + segments_line_spec.format(utt, int(begin * 1000), + int(end * 1000), utt, begin, end)) if __name__ == '__main__': diff --git a/examples/voxconverse/v1/sad/make_system_sad.py b/examples/voxconverse/v1/sad/make_system_sad.py index f220bfd1..70a49fc8 100644 --- a/examples/voxconverse/v1/sad/make_system_sad.py +++ b/examples/voxconverse/v1/sad/make_system_sad.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os + os.environ["OMP_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" @@ -31,11 +31,14 @@ def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--repo-path', required=True, + parser.add_argument('--repo-path', + required=True, help='VAD model repo path') parser.add_argument('--scp', required=True, help='wav scp') - parser.add_argument('--min-duration', required=True, - type=float, help='min duration') + parser.add_argument('--min-duration', + required=True, + type=float, + help='min duration') args = parser.parse_args() return args @@ -50,8 +53,11 @@ def read_scp(scp): return utt_wav_pair -def silero_vad(utt_wav_pair, repo_path, min_duration, - sampling_rate=16000, threshold=0.25): +def silero_vad(utt_wav_pair, + repo_path, + min_duration, + sampling_rate=16000, + threshold=0.25): def module_from_file(module_name, file_path): spec = importlib.util.spec_from_file_location(module_name, file_path) @@ -69,8 +75,7 @@ def module_from_file(module_name, file_path): wav = utils_vad.read_audio(wav, sampling_rate=sampling_rate) speech_timestamps = utils_vad.get_speech_timestamps( - wav, model, sampling_rate=sampling_rate, - threshold=threshold) + wav, model, sampling_rate=sampling_rate, threshold=threshold) vad_result = "" for item in speech_timestamps: diff --git a/examples/voxconverse/v2/diar/extract_emb.py b/examples/voxconverse/v2/diar/extract_emb.py index 80074ec1..c2c140d1 100644 --- a/examples/voxconverse/v2/diar/extract_emb.py +++ b/examples/voxconverse/v2/diar/extract_emb.py @@ -127,11 +127,10 @@ def get_args(): type=float, default=0.75, help='the shift seconds in embedding extraction') - parser.add_argument( - '--subseg-cmn', - default=True, - type=lambda x: x.lower() == 'true', - help='do cmn after or before fbank sub-segmentation') + parser.add_argument('--subseg-cmn', + default=True, + type=lambda x: x.lower() == 'true', + help='do cmn after or before fbank sub-segmentation') args = parser.parse_args() return args diff --git a/examples/voxconverse/v2/diar/make_fbank.py b/examples/voxconverse/v2/diar/make_fbank.py index 90e6e141..619f8b1d 100644 --- a/examples/voxconverse/v2/diar/make_fbank.py +++ b/examples/voxconverse/v2/diar/make_fbank.py @@ -95,11 +95,10 @@ def get_args(): parser.add_argument('--ark-path', required=True, help='path to store feat ark') - parser.add_argument( - '--subseg-cmn', - default=True, - type=lambda x: x.lower() == 'true', - help='do cmn after or before fbank sub-segmentation') + parser.add_argument('--subseg-cmn', + default=True, + type=lambda x: x.lower() == 'true', + help='do cmn after or before fbank sub-segmentation') args = parser.parse_args() return args diff --git a/examples/voxconverse/v2/diar/make_rttm.py b/examples/voxconverse/v2/diar/make_rttm.py index cbba941b..3b5e1733 100644 --- a/examples/voxconverse/v2/diar/make_rttm.py +++ b/examples/voxconverse/v2/diar/make_rttm.py @@ -12,23 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse from collections import OrderedDict def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--labels', required=True, + parser.add_argument('--labels', + required=True, help='class labels generated by clusterer') - parser.add_argument('--channel', type=int, default=1, + parser.add_argument('--channel', + type=int, + default=1, help='channel number in RTTM format') args = parser.parse_args() return args - def read_labels(labels, frame_shift=10): utt_to_subseg_labels = OrderedDict() for line in open(labels, 'r'): diff --git a/examples/voxconverse/v2/diar/spectral_clusterer.py b/examples/voxconverse/v2/diar/spectral_clusterer.py index ec2b377b..769421e0 100644 --- a/examples/voxconverse/v2/diar/spectral_clusterer.py +++ b/examples/voxconverse/v2/diar/spectral_clusterer.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - import os + os.environ["OMP_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" @@ -80,8 +80,8 @@ def kmeans(data): # Compute Laplacian laplacian_matrix = laplacian(pruned_similarity_matrix) # Compute spectral embeddings - spectral_embeddings = spectral(laplacian_matrix, num_spks, - min_num_spks, max_num_spks) + spectral_embeddings = spectral(laplacian_matrix, num_spks, min_num_spks, + max_num_spks) # Assign class labels labels = kmeans(spectral_embeddings) @@ -118,6 +118,7 @@ def get_args(): return args + def main(): args = get_args() @@ -127,7 +128,10 @@ def main(): with cf.ProcessPoolExecutor() as executor, open(args.output, 'w') as f: for (subsegs, labels) in zip(subsegs_list, executor.map(cluster, embeddings_list)): - [print(subseg, label, file=f) for (subseg, label) in zip(subsegs, labels)] + [ + print(subseg, label, file=f) + for (subseg, label) in zip(subsegs, labels) + ] if __name__ == '__main__': diff --git a/runtime/core/bin/asv_main.cc b/runtime/core/bin/asv_main.cc index ea45c1fd..f61a6ddb 100644 --- a/runtime/core/bin/asv_main.cc +++ b/runtime/core/bin/asv_main.cc @@ -13,13 +13,12 @@ // limitations under the License. #include + #include "frontend/wav.h" -#include "utils/utils.h" #include "gflags/gflags.h" -#include "utils/timer.h" - #include "speaker/speaker_engine.h" - +#include "utils/timer.h" +#include "utils/utils.h" DEFINE_string(enroll_wav, "", "First wav as enroll wav."); DEFINE_string(test_wav, "", "Second wav as test wav."); @@ -31,8 +30,6 @@ DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(embedding_size, 256, "embedding size"); DEFINE_int32(SamplesPerChunk, 32000, "samples of one chunk"); - - int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -41,8 +38,8 @@ int main(int argc, char* argv[]) { LOG(INFO) << FLAGS_speaker_model_path; LOG(INFO) << "Init model ..."; auto speaker_engine = std::make_shared( - FLAGS_speaker_model_path, FLAGS_fbank_dim, FLAGS_sample_rate, - FLAGS_embedding_size, FLAGS_SamplesPerChunk); + FLAGS_speaker_model_path, FLAGS_fbank_dim, FLAGS_sample_rate, + FLAGS_embedding_size, FLAGS_SamplesPerChunk); int embedding_size = speaker_engine->EmbeddingSize(); LOG(INFO) << "embedding size: " << embedding_size; // read enroll wav/pcm data @@ -52,26 +49,21 @@ int main(int argc, char* argv[]) { // NOTE(cdliang): memory allocation std::vector enroll_embs(embedding_size, 0); int enroll_wave_dur = static_cast(static_cast(enroll_samples) / - data_reader->sample_rate() * 1000); + data_reader->sample_rate() * 1000); LOG(INFO) << enroll_wave_dur; - speaker_engine->ExtractEmbedding(enroll_data, - enroll_samples, - &enroll_embs); + speaker_engine->ExtractEmbedding(enroll_data, enroll_samples, &enroll_embs); // test wav auto test_data_reader = wenet::ReadAudioFile(FLAGS_test_wav); int16_t* test_data = const_cast(test_data_reader->data()); int test_samples = test_data_reader->num_sample(); std::vector test_embs(embedding_size, 0); int test_wave_dur = static_cast(static_cast(test_samples) / - test_data_reader->sample_rate() * 1000); + test_data_reader->sample_rate() * 1000); LOG(INFO) << test_wave_dur; - speaker_engine->ExtractEmbedding(test_data, - test_samples, - &test_embs); + speaker_engine->ExtractEmbedding(test_data, test_samples, &test_embs); float cosine_score; LOG(INFO) << "compute score ..."; - cosine_score = speaker_engine->CosineSimilarity(enroll_embs, - test_embs); + cosine_score = speaker_engine->CosineSimilarity(enroll_embs, test_embs); LOG(INFO) << "Cosine socre: " << cosine_score; if (cosine_score >= FLAGS_threshold) { LOG(INFO) << "It's the same speaker!"; diff --git a/runtime/core/bin/extract_emb_main.cc b/runtime/core/bin/extract_emb_main.cc index 0bb3f824..d551a106 100644 --- a/runtime/core/bin/extract_emb_main.cc +++ b/runtime/core/bin/extract_emb_main.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include -#include #include +#include +#include #include "frontend/wav.h" -#include "utils/utils.h" -#include "utils/timer.h" #include "speaker/speaker_engine.h" +#include "utils/timer.h" +#include "utils/utils.h" DEFINE_string(wav_list, "", "input wav scp"); DEFINE_string(result, "", "output embedding file"); @@ -38,8 +38,8 @@ int main(int argc, char* argv[]) { // init model LOG(INFO) << "Init model ..."; auto speaker_engine = std::make_shared( - FLAGS_speaker_model_path, FLAGS_fbank_dim, FLAGS_sample_rate, - FLAGS_embedding_size, FLAGS_SamplesPerChunk); + FLAGS_speaker_model_path, FLAGS_fbank_dim, FLAGS_sample_rate, + FLAGS_embedding_size, FLAGS_SamplesPerChunk); int embedding_size = speaker_engine->EmbeddingSize(); LOG(INFO) << "embedding size: " << embedding_size; // read wav.scp @@ -58,11 +58,11 @@ int main(int argc, char* argv[]) { if (!FLAGS_result.empty()) { result.open(FLAGS_result, std::ios::out); } - std::ostream &buffer = FLAGS_result.empty() ? std::cout : result; + std::ostream& buffer = FLAGS_result.empty() ? std::cout : result; int total_waves_dur = 0; int total_extract_time = 0; - for (auto &wav : waves) { + for (auto& wav : waves) { auto data_reader = wenet::ReadAudioFile(wav.second); CHECK_EQ(data_reader->sample_rate(), 16000); int16_t* data = const_cast(data_reader->data()); diff --git a/runtime/core/frontend/feature_pipeline.cc b/runtime/core/frontend/feature_pipeline.cc index 20a47d9c..c59d6f73 100644 --- a/runtime/core/frontend/feature_pipeline.cc +++ b/runtime/core/frontend/feature_pipeline.cc @@ -46,7 +46,6 @@ void FeaturePipeline::AcceptWaveform(const std::vector& wav) { finish_condition_.notify_one(); } - void FeaturePipeline::AcceptWaveform(const std::vector& wav) { std::vector float_wav(wav.size()); for (size_t i = 0; i < wav.size(); i++) { diff --git a/runtime/core/frontend/wav.h b/runtime/core/frontend/wav.h index 90abf475..29931e65 100644 --- a/runtime/core/frontend/wav.h +++ b/runtime/core/frontend/wav.h @@ -22,12 +22,12 @@ #include #include -#include #include #include +#include -#include "glog/logging.h" #include "gflags/gflags.h" +#include "glog/logging.h" DEFINE_int32(pcm_sample_rate, 16000, "pcm data sample rate"); diff --git a/runtime/core/speaker/bpu_speaker_model.cc b/runtime/core/speaker/bpu_speaker_model.cc index 08d4163e..c99f08bd 100644 --- a/runtime/core/speaker/bpu_speaker_model.cc +++ b/runtime/core/speaker/bpu_speaker_model.cc @@ -15,25 +15,25 @@ #ifdef USE_BPU #include "speaker/bpu_speaker_model.h" -#include + #include -#include "glog/logging.h" +#include #include "easy_dnn/data_structure.h" #include "easy_dnn/model_manager.h" #include "easy_dnn/task_manager.h" +#include "glog/logging.h" using hobot::easy_dnn::ModelManager; using hobot::easy_dnn::Task; using hobot::easy_dnn::TaskManager; - namespace wespeaker { void BpuSpeakerModel::AllocMemory( - std::vector>* input_dnn_tensor_array, - std::vector>* output_dnn_tensor_array, - Model* model) { + std::vector>* input_dnn_tensor_array, + std::vector>* output_dnn_tensor_array, + Model* model) { int32_t input_counts = model->GetInputCount(); LOG(INFO) << "input_counts: " << input_counts; input_dnn_tensor_array->resize(input_counts); @@ -45,8 +45,7 @@ void BpuSpeakerModel::AllocMemory( if (input->properties.tensorType != hbDNNDataType::HB_DNN_TENSOR_TYPE_F32) { LOG(FATAL) << "Input data type must be float32"; } - hbSysAllocCachedMem(&(input->sysMem[0]), - input->properties.alignedByteSize); + hbSysAllocCachedMem(&(input->sysMem[0]), input->properties.alignedByteSize); } // stage-2: output int32_t output_counts = model->GetOutputCount(); @@ -78,9 +77,9 @@ void BpuSpeakerModel::Read(const std::string& model_path) { // Model_path is bin model egs: speaker_resnet34.bin ret_code = model_manager->Load(models, model_path); if (ret_code != 0) { - LOG(FATAL) << "easydn error code: " - << ", error loading bpu model speaker_model.bin at " - << model_path; + LOG(FATAL) << "easydn error code: " + << ", error loading bpu model speaker_model.bin at " + << model_path; } // get model handle speaker_dnn_handle_ = model_manager->GetModel([](Model* model) { @@ -98,8 +97,8 @@ BpuSpeakerModel::BpuSpeakerModel(const std::string& model_path) { } void BpuSpeakerModel::ExtractEmbedding( - const std::vector>& chunk_feats, - std::vector* embed) { + const std::vector>& chunk_feats, + std::vector* embed) { // reset input && output Reset(); // chunk_feats: [198, 80] @@ -123,10 +122,10 @@ void BpuSpeakerModel::ExtractEmbedding( infer_task.reset(); hbSysFlushMem(&(output_dnn_[0]->sysMem[0]), HB_SYS_MEM_CACHE_INVALIDATE); - int output_dim = \ - output_dnn_[0]->properties.validShape.dimensionSize[1]; // 256 - const float* raw_data = \ - reinterpret_cast(output_dnn_[0]->sysMem[0].virAddr); + int output_dim = + output_dnn_[0]->properties.validShape.dimensionSize[1]; // 256 + const float* raw_data = + reinterpret_cast(output_dnn_[0]->sysMem[0].virAddr); embed->reserve(output_dim); // NOTE(cdliang): default output_node = 1 for (int idx = 0, i = 0; i < output_dim; i++) { @@ -136,15 +135,17 @@ void BpuSpeakerModel::ExtractEmbedding( void BpuSpeakerModel::Reset() { auto set_to_zero = - [](std::vector>& input_dnn_tensor_array, - std::vector>& output_dnn_tensor_array) { - for (auto& tensor : input_dnn_tensor_array) { - memset(tensor->sysMem[0].virAddr, 0, tensor->properties.alignedByteSize); - } - for (auto& tensor : output_dnn_tensor_array) { - memset(tensor->sysMem[0].virAddr, 0, tensor->properties.alignedByteSize); - } - }; + [](std::vector>& input_dnn_tensor_array, + std::vector>& output_dnn_tensor_array) { + for (auto& tensor : input_dnn_tensor_array) { + memset(tensor->sysMem[0].virAddr, 0, + tensor->properties.alignedByteSize); + } + for (auto& tensor : output_dnn_tensor_array) { + memset(tensor->sysMem[0].virAddr, 0, + tensor->properties.alignedByteSize); + } + }; set_to_zero(input_dnn_, output_dnn_); } } // namespace wespeaker diff --git a/runtime/core/speaker/bpu_speaker_model.h b/runtime/core/speaker/bpu_speaker_model.h index 82a5d8d5..ef5482db 100644 --- a/runtime/core/speaker/bpu_speaker_model.h +++ b/runtime/core/speaker/bpu_speaker_model.h @@ -17,17 +17,17 @@ #ifdef USE_BPU -#include -#include #include +#include +#include #include "easy_dnn/data_structure.h" #include "easy_dnn/model.h" #include "speaker/speaker_model.h" -using hobot::easy_dnn::Model; using hobot::easy_dnn::DNNTensor; +using hobot::easy_dnn::Model; namespace wespeaker { @@ -38,11 +38,12 @@ class BpuSpeakerModel : public SpeakerModel { ~BpuSpeakerModel() = default; void ExtractEmbedding(const std::vector>& chunk_feats, std::vector* embed) override; + private: void AllocMemory( - std::vector>* input_dnn_tensor_array, - std::vector>* output_dnn_tensor_array, - Model* model); + std::vector>* input_dnn_tensor_array, + std::vector>* output_dnn_tensor_array, + Model* model); void Read(const std::string& model_path); void Reset(); std::vector> input_dnn_; diff --git a/runtime/core/speaker/onnx_speaker_model.cc b/runtime/core/speaker/onnx_speaker_model.cc index d83ddaac..3e7c6eb0 100644 --- a/runtime/core/speaker/onnx_speaker_model.cc +++ b/runtime/core/speaker/onnx_speaker_model.cc @@ -16,14 +16,14 @@ #include -#include "speaker/onnx_speaker_model.h" #include "glog/logging.h" +#include "speaker/onnx_speaker_model.h" #include "utils/utils.h" namespace wespeaker { -Ort::Env OnnxSpeakerModel::env_ = Ort::Env( - ORT_LOGGING_LEVEL_WARNING, "OnnxModel"); +Ort::Env OnnxSpeakerModel::env_ = + Ort::Env(ORT_LOGGING_LEVEL_WARNING, "OnnxModel"); Ort::SessionOptions OnnxSpeakerModel::session_options_ = Ort::SessionOptions(); void OnnxSpeakerModel::InitEngineThreads(int num_threads) { @@ -32,22 +32,22 @@ void OnnxSpeakerModel::InitEngineThreads(int num_threads) { #ifdef USE_GPU void OnnxSpeakerModel::SetGpuDeviceId(int gpu_id) { - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA( - session_options_, gpu_id)); + Ort::ThrowOnError( + OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_, gpu_id)); } #endif OnnxSpeakerModel::OnnxSpeakerModel(const std::string& model_path) { session_options_.SetGraphOptimizationLevel( GraphOptimizationLevel::ORT_ENABLE_EXTENDED); - // 1. Load sessions - #ifdef _MSC_VER +// 1. Load sessions +#ifdef _MSC_VER speaker_session_ = std::make_shared( - env_, ToWString(model_path).c_str(), session_options_); - #else - speaker_session_ = std::make_shared( - env_, model_path.c_str(), session_options_); - #endif + env_, ToWString(model_path).c_str(), session_options_); +#else + speaker_session_ = std::make_shared(env_, model_path.c_str(), + session_options_); +#endif // 2. Model info Ort::AllocatorWithDefaultOptions allocator; // 2.1. input info @@ -69,10 +69,9 @@ OnnxSpeakerModel::OnnxSpeakerModel(const std::string& model_path) { } void OnnxSpeakerModel::ExtractEmbedding( - const std::vector>& feats, - std::vector* embed) { + const std::vector>& feats, std::vector* embed) { Ort::MemoryInfo memory_info = - Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); // prepare onnx required data unsigned int num_frames = feats.size(); unsigned int feat_dim = feats[0].size(); diff --git a/runtime/core/speaker/speaker_engine.cc b/runtime/core/speaker/speaker_engine.cc index b63a6195..388ffd5a 100644 --- a/runtime/core/speaker/speaker_engine.cc +++ b/runtime/core/speaker/speaker_engine.cc @@ -12,25 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "speaker/speaker_engine.h" #include #include #include #include -#include "speaker/speaker_engine.h" - #ifdef USE_ONNX - #include "speaker/onnx_speaker_model.h" +#include "speaker/onnx_speaker_model.h" #elif USE_BPU - #include "speaker/bpu_speaker_model.h" +#include "speaker/bpu_speaker_model.h" #endif namespace wespeaker { -SpeakerEngine::SpeakerEngine(const std::string& model_path, - const int feat_dim, - const int sample_rate, - const int embedding_size, +SpeakerEngine::SpeakerEngine(const std::string& model_path, const int feat_dim, + const int sample_rate, const int embedding_size, const int SamplesPerChunk) { // NOTE(cdliang): default num_threads = 1 const int kNumGemmThreads = 1; @@ -41,26 +38,24 @@ SpeakerEngine::SpeakerEngine(const std::string& model_path, LOG(INFO) << "per_chunk_samples: " << per_chunk_samples_; sample_rate_ = sample_rate; LOG(INFO) << "Sample rate: " << sample_rate_; - feature_config_ = std::make_shared( - feat_dim, sample_rate); - feature_pipeline_ = \ - std::make_shared(*feature_config_); + feature_config_ = + std::make_shared(feat_dim, sample_rate); + feature_pipeline_ = + std::make_shared(*feature_config_); feature_pipeline_->Reset(); #ifdef USE_ONNX OnnxSpeakerModel::InitEngineThreads(kNumGemmThreads); - #ifdef USE_GPU +#ifdef USE_GPU // NOTE(cdliang): default gpu_id = 0 OnnxSpeakerModel::SetGpuDeviceId(0); - #endif +#endif model_ = std::make_shared(model_path); #elif USE_BPU model_ = std::make_shared(model_path); #endif } -int SpeakerEngine::EmbeddingSize() { - return embedding_size_; -} +int SpeakerEngine::EmbeddingSize() { return embedding_size_; } void SpeakerEngine::ApplyMean(std::vector>* feat, unsigned int feat_dim) { @@ -70,7 +65,7 @@ void SpeakerEngine::ApplyMean(std::vector>* feat, std::plus<>{}); } std::transform(mean.begin(), mean.end(), mean.begin(), - [&](const float d) {return d / feat->size();}); + [&](const float d) { return d / feat->size(); }); for (auto& i : *feat) { std::transform(i.begin(), i.end(), mean.begin(), i.begin(), std::minus<>{}); } @@ -82,12 +77,13 @@ void SpeakerEngine::ApplyMean(std::vector>* feat, // Extract audio features chunk by chunk, with 198 frames for each chunk. // If the last chunk is less than 198 frames, // concatenate the head frame to the tail. -void SpeakerEngine::ExtractFeature(const int16_t* data, int data_size, +void SpeakerEngine::ExtractFeature( + const int16_t* data, int data_size, std::vector>>* chunks_feat) { if (data != nullptr) { std::vector> chunk_feat; - feature_pipeline_->AcceptWaveform(std::vector( - data, data + data_size)); + feature_pipeline_->AcceptWaveform( + std::vector(data, data + data_size)); if (per_chunk_samples_ <= 0) { // full mode feature_pipeline_->Read(feature_pipeline_->num_frames(), &chunk_feat); @@ -96,11 +92,11 @@ void SpeakerEngine::ExtractFeature(const int16_t* data, int data_size, chunk_feat.clear(); } else { // NOTE(cdliang): extract feature with chunk by chunk - int num_chunk_frames_ = 1 + (( - per_chunk_samples_ - sample_rate_ / 1000 * 25) / - (sample_rate_ / 1000 * 10)); - int chunk_num = std::ceil( - feature_pipeline_->num_frames() / num_chunk_frames_); + int num_chunk_frames_ = + 1 + ((per_chunk_samples_ - sample_rate_ / 1000 * 25) / + (sample_rate_ / 1000 * 10)); + int chunk_num = + std::ceil(feature_pipeline_->num_frames() / num_chunk_frames_); chunks_feat->reserve(chunk_num); chunk_feat.reserve(num_chunk_frames_); while (feature_pipeline_->NumQueuedFrames() >= num_chunk_frames_) { @@ -119,12 +115,13 @@ void SpeakerEngine::ExtractFeature(const int16_t* data, int data_size, chunk_feat.insert(chunk_feat.end(), chunk_feat.begin(), chunk_feat.begin() + last_frames); } - chunk_feat.insert(chunk_feat.end(), chunk_feat.begin(), - chunk_feat.begin() + num_chunk_frames_ - chunk_feat.size()); + chunk_feat.insert( + chunk_feat.end(), chunk_feat.begin(), + chunk_feat.begin() + num_chunk_frames_ - chunk_feat.size()); } else { - chunk_feat.insert(chunk_feat.end(), - (*chunks_feat)[0].begin(), - (*chunks_feat)[0].begin() + num_chunk_frames_ - chunk_feat.size()); + chunk_feat.insert(chunk_feat.end(), (*chunks_feat)[0].begin(), + (*chunks_feat)[0].begin() + num_chunk_frames_ - + chunk_feat.size()); } CHECK_EQ(chunk_feat.size(), num_chunk_frames_); chunks_feat->emplace_back(chunk_feat); @@ -162,10 +159,10 @@ float SpeakerEngine::CosineSimilarity(const std::vector& emb1, const std::vector& emb2) { CHECK_EQ(emb1.size(), emb2.size()); float dot = std::inner_product(emb1.begin(), emb1.end(), emb2.begin(), 0.0); - float emb1_sum = std::inner_product(emb1.begin(), emb1.end(), - emb1.begin(), 0.0); - float emb2_sum = std::inner_product(emb2.begin(), emb2.end(), - emb2.begin(), 0.0); + float emb1_sum = + std::inner_product(emb1.begin(), emb1.end(), emb1.begin(), 0.0); + float emb2_sum = + std::inner_product(emb2.begin(), emb2.end(), emb2.begin(), 0.0); dot /= std::max(std::sqrt(emb1_sum) * std::sqrt(emb2_sum), std::numeric_limits::epsilon()); return dot; diff --git a/runtime/core/speaker/speaker_engine.h b/runtime/core/speaker/speaker_engine.h index 0b8b45cc..219bbb3f 100644 --- a/runtime/core/speaker/speaker_engine.h +++ b/runtime/core/speaker/speaker_engine.h @@ -15,38 +15,35 @@ #ifndef SPEAKER_SPEAKER_ENGINE_H_ #define SPEAKER_SPEAKER_ENGINE_H_ +#include #include #include -#include #include "frontend/feature_pipeline.h" #include "speaker/speaker_model.h" - namespace wespeaker { class SpeakerEngine { public: - explicit SpeakerEngine(const std::string& model_path, - const int feat_dim, - const int sample_rate, - const int embedding_size, + explicit SpeakerEngine(const std::string& model_path, const int feat_dim, + const int sample_rate, const int embedding_size, const int SamplesPerChunk); // return embedding_size int EmbeddingSize(); // extract fbank - void ExtractFeature(const int16_t* data, int data_size, - std::vector>>* chunks_feat); + void ExtractFeature( + const int16_t* data, int data_size, + std::vector>>* chunks_feat); // extract embedding void ExtractEmbedding(const int16_t* data, int data_size, std::vector* avg_emb); float CosineSimilarity(const std::vector& emb1, - const std::vector& emb2); + const std::vector& emb2); private: - void ApplyMean(std::vector>* feats, - unsigned int feat_dim); + void ApplyMean(std::vector>* feats, unsigned int feat_dim); std::shared_ptr model_ = nullptr; std::shared_ptr feature_config_ = nullptr; std::shared_ptr feature_pipeline_ = nullptr; diff --git a/runtime/core/speaker/speaker_model.h b/runtime/core/speaker/speaker_model.h index 77348e1a..08e7f815 100644 --- a/runtime/core/speaker/speaker_model.h +++ b/runtime/core/speaker/speaker_model.h @@ -15,8 +15,9 @@ #ifndef SPEAKER_SPEAKER_MODEL_H_ #define SPEAKER_SPEAKER_MODEL_H_ -#include #include +#include + #include "utils/utils.h" namespace wespeaker { diff --git a/runtime/core/utils/utils.cc b/runtime/core/utils/utils.cc index bd4c1b96..cd4c870e 100644 --- a/runtime/core/utils/utils.cc +++ b/runtime/core/utils/utils.cc @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include -#include #include -#include -#include #include -#include +#include +#include -#include "utils/utils.h" #include "glog/logging.h" +#include "utils/utils.h" namespace wespeaker { diff --git a/runtime/core/utils/utils.h b/runtime/core/utils/utils.h index b1aa002b..9f52815e 100644 --- a/runtime/core/utils/utils.h +++ b/runtime/core/utils/utils.h @@ -15,9 +15,9 @@ #ifndef UTILS_UTILS_H_ #define UTILS_UTILS_H_ -#include #include #include +#include namespace wespeaker { diff --git a/runtime/server/diarization_gpu/client/client.py b/runtime/server/diarization_gpu/client/client.py index 24723e32..a2d034eb 100644 --- a/runtime/server/diarization_gpu/client/client.py +++ b/runtime/server/diarization_gpu/client/client.py @@ -25,6 +25,7 @@ class SpeakerClient(object): + def __init__(self, triton_client, model_name, protocol_client): self.triton_client = triton_client self.protocol_client = protocol_client @@ -38,8 +39,10 @@ def recognize(self, wav_path, client_index): cur_length = len(waveform) input = np.zeros((1, cur_length), dtype=np.float32) input[0][0:cur_length] = waveform[0:cur_length] - inputs = [self.protocol_client.InferInput("input", input.shape, - np_to_triton_dtype(input.dtype))] + inputs = [ + self.protocol_client.InferInput("input", input.shape, + np_to_triton_dtype(input.dtype)) + ] inputs[0].set_data_from_numpy(input) outputs = [grpcclient.InferRequestedOutput("LABELS")] response = self.triton_client.infer(self.model_name, @@ -113,10 +116,12 @@ def single_job(li): if not os.path.exists(dir_name) and (dir_name != ''): os.makedirs(dir_name) seg_writer = open(os.path.join(FLAGS.output_directory, - 'rttm' + str(idx)), 'w', encoding="utf-8") + 'rttm' + str(idx)), + 'w', + encoding="utf-8") - with grpcclient.InferenceServerClient(url=FLAGS.url, - verbose=FLAGS.verbose) as triton_client: + with grpcclient.InferenceServerClient( + url=FLAGS.url, verbose=FLAGS.verbose) as triton_client: protocol_client = grpcclient speech_client = SpeakerClient(triton_client, FLAGS.model_name, protocol_client) @@ -132,11 +137,9 @@ def single_job(li): end = rttms[i][1] label = int(rttms[i][2]) channel = 1 - seg_writer.write(spec.format(utt, - channel, - begin, - end - begin, - label) + '\n') + seg_writer.write( + spec.format(utt, channel, begin, end - begin, label) + + '\n') seg_writer.flush() return predictions diff --git a/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py b/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py index 0b879bd3..ad3b6816 100644 --- a/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py +++ b/runtime/server/diarization_gpu/model_repo/clusterer/1/model.py @@ -57,8 +57,12 @@ def initialize(self, args): self.output0_dtype = pb_utils.triton_string_to_numpy( output0_config['data_type']) - def cluster_gpu(self, embeddings, p=.01, num_spks=None, - min_num_spks=1, max_num_spks=20): + def cluster_gpu(self, + embeddings, + p=.01, + num_spks=None, + min_num_spks=1, + max_num_spks=20): # Define utility functions def cosine_similarity(M): M = M / cp.linalg.norm(M, axis=1, keepdims=True) diff --git a/runtime/server/diarization_gpu/model_repo/run/1/model.py b/runtime/server/diarization_gpu/model_repo/run/1/model.py index 86401b4e..ef73734d 100644 --- a/runtime/server/diarization_gpu/model_repo/run/1/model.py +++ b/runtime/server/diarization_gpu/model_repo/run/1/model.py @@ -36,8 +36,8 @@ def initialize(self, args): self.device = "cpu" # Get OUTPUT0 configuration - output0_config = pb_utils.get_output_config_by_name(model_config, - "LABELS") + output0_config = pb_utils.get_output_config_by_name( + model_config, "LABELS") # Convert Triton types to numpy types self.output0_dtype = pb_utils.triton_string_to_numpy( output0_config['data_type']) @@ -59,8 +59,8 @@ def prepare_chunks(self, for current_start_sample in range(0, audio_length_samples, window_size_samples): - chunk = wav[current_start_sample: - current_start_sample + window_size_samples] + chunk = wav[current_start_sample:current_start_sample + + window_size_samples] if len(chunk) < window_size_samples: chunk = torch.nn.functional.pad( chunk, (0, int(window_size_samples - len(chunk)))) @@ -68,7 +68,9 @@ def prepare_chunks(self, chunks.append(speech_prob) return chunks - def get_timestamps(self, speech_probs, audio_length_samples, + def get_timestamps(self, + speech_probs, + audio_length_samples, sr: int = 16000, threshold: float = 0.5, min_duration: float = 0.255, @@ -115,19 +117,21 @@ def get_timestamps(self, speech_probs, audio_length_samples, for i, speech in enumerate(speeches): if i == 0: - speech['start'] = int(max(0, - speech['start'] - speech_pad_samples)) + speech['start'] = int( + max(0, speech['start'] - speech_pad_samples)) if i != len(speeches) - 1: silence_duration = speeches[i + 1]['start'] - speech['end'] if silence_duration < 2 * speech_pad_samples: speech['end'] += int(silence_duration // 2) speeches[i + 1]['start'] = int( - max(0, speeches[i + 1]['start'] - silence_duration // 2)) + max(0, + speeches[i + 1]['start'] - silence_duration // 2)) else: speech['end'] += int(speech_pad_samples) else: - speech['end'] = int(min(audio_length_samples, speech['end'] - + speech_pad_samples)) + speech['end'] = int( + min(audio_length_samples, + speech['end'] + speech_pad_samples)) vad_result = [] for item in speeches: begin = item['start'] / sr @@ -138,11 +142,15 @@ def get_timestamps(self, speech_probs, audio_length_samples, vad_result.append(item) return vad_result - def subsegment(self, wav, segments, wav_idx, + def subsegment(self, + wav, + segments, + wav_idx, window_fs: float = 1.50, period_fs: float = 0.75, sr: int = 16000, frame_shift: int = 10): + def repeat_to_fill(x, window_fs): length = x.size(0) num = (window_fs + length - 1) // length @@ -162,13 +170,14 @@ def repeat_to_fill(x, window_fs): for segment in segments: seg_begin = int(segment['start'] * sr) seg_end = int(segment['end'] * sr) - seg_signal = wav[seg_begin: seg_end + 1] + seg_signal = wav[seg_begin:seg_end + 1] seg_length = seg_end - seg_begin if seg_length <= window_fs: - subseg = [wav_idx, seg_idx, - segment['start'], segment['end'], 0, - int(seg_length / sr * 1000 // frame_shift)] + subseg = [ + wav_idx, seg_idx, segment['start'], segment['end'], 0, + int(seg_length / sr * 1000 // frame_shift) + ] subseg_signal = repeat_to_fill(seg_signal, window_fs) subsegs.append(subseg) @@ -178,12 +187,13 @@ def repeat_to_fill(x, window_fs): max_subseg_begin = seg_length - window_fs + period_fs for subseg_begin in range(0, max_subseg_begin, period_fs): subseg_end = min(subseg_begin + window_fs, seg_length) - subseg = [wav_idx, seg_idx, - segment['start'], segment['end'], - int(subseg_begin / sr * 1000 / frame_shift), - int(subseg_end / sr * 1000 / frame_shift)] + subseg = [ + wav_idx, seg_idx, segment['start'], segment['end'], + int(subseg_begin / sr * 1000 / frame_shift), + int(subseg_end / sr * 1000 / frame_shift) + ] subseg_signal = repeat_to_fill( - seg_signal[subseg_begin: subseg_end + 1], window_fs) + seg_signal[subseg_begin:subseg_end + 1], window_fs) subsegs.append(subseg) subseg_signals.append(subseg_signal) @@ -196,8 +206,10 @@ def read_labels(self, subseg_ids, label, frame_shift=10): new_sort = {} for i, subseg in enumerate(subseg_ids): (utt, seg_idx, begin_ms, end_ms, begin_frames, end_frames) = subseg - begin = (int(begin_ms * 1000) + int(begin_frames) * frame_shift) / 1000.0 - end = (int(begin_ms * 1000) + int(end_frames) * frame_shift) / 1000.0 + begin = (int(begin_ms * 1000) + + int(begin_frames) * frame_shift) / 1000.0 + end = (int(begin_ms * 1000) + + int(end_frames) * frame_shift) / 1000.0 new_sort[seg_idx] = (begin, end, label[i]) utt_to_subseg_labels = list(dict(sorted(new_sort.items())).values()) return utt_to_subseg_labels @@ -280,7 +292,8 @@ async def execute(self, requests): out_segs = [] for speech_prob, speech_len in zip(reshape_probs, total_lens): segments = self.get_timestamps(speech_prob, - speech_len, threshold=0.36) + speech_len, + threshold=0.36) out_segs.append(segments) total_subsegments = [] @@ -289,8 +302,7 @@ async def execute(self, requests): wav_idx = 0 for waveform, segments in zip(total_wavs, out_segs): - subsegs, subseg_signals = self.subsegment(waveform, - segments, + subsegs, subseg_signals = self.subsegment(waveform, segments, wav_idx) total_subsegments.extend(subseg_signals) total_subsegment_ids.extend(subsegs) @@ -298,25 +310,25 @@ async def execute(self, requests): inference_response_awaits = [] for wavs in total_subsegments: - input_tensor_spk0 = pb_utils.Tensor.from_dlpack("WAV", - to_dlpack(wavs)) + input_tensor_spk0 = pb_utils.Tensor.from_dlpack( + "WAV", to_dlpack(wavs)) input_tensors_spk = [input_tensor_spk0] - inference_request = pb_utils.InferenceRequest(model_name='speaker', - requested_output_names=['EMBEDDINGS'], - inputs=input_tensors_spk) + inference_request = pb_utils.InferenceRequest( + model_name='speaker', + requested_output_names=['EMBEDDINGS'], + inputs=input_tensors_spk) inference_response_awaits.append(inference_request.async_exec()) - inference_responses = await asyncio.gather( - *inference_response_awaits) + inference_responses = await asyncio.gather(*inference_response_awaits) for inference_response in inference_responses: if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response. - error().message()) + raise pb_utils.TritonModelException( + inference_response.error().message()) else: - batched_result = pb_utils.get_output_tensor_by_name(inference_response, - 'EMBEDDINGS') + batched_result = pb_utils.get_output_tensor_by_name( + inference_response, 'EMBEDDINGS') total_embds.extend(from_dlpack(batched_result.to_dlpack())) out_embds = list() @@ -338,26 +350,26 @@ async def execute(self, requests): "EMBEDDINGS", to_dlpack(torch.unsqueeze(embd, 0))) input_tensors_spk = [input_tensor_embds0] - inference_request = pb_utils.InferenceRequest(model_name='clusterer', - requested_output_names=['LABELS'], - request_id=str(i), - inputs=input_tensors_spk) + inference_request = pb_utils.InferenceRequest( + model_name='clusterer', + requested_output_names=['LABELS'], + request_id=str(i), + inputs=input_tensors_spk) inference_response_awaits.append(inference_request.async_exec()) - inference_responses = await asyncio.gather( - *inference_response_awaits) + inference_responses = await asyncio.gather(*inference_response_awaits) i = 0 results = [] for inference_response in inference_responses: if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response. - error().message()) + raise pb_utils.TritonModelException( + inference_response.error().message()) else: - result = pb_utils.get_output_tensor_by_name(inference_response, - 'LABELS').as_numpy()[0] - utt_to_subseg_labels = self.read_labels(out_time_info[i], - result) + result = pb_utils.get_output_tensor_by_name( + inference_response, 'LABELS').as_numpy()[0] + utt_to_subseg_labels = self.read_labels( + out_time_info[i], result) i += 1 rttm = self.merge_segments(utt_to_subseg_labels) if len(rttm) > 0: @@ -368,7 +380,8 @@ async def execute(self, requests): for b in batch_count: sents = np.array(results[st:st + b]) out0 = pb_utils.Tensor("LABELS", sents.astype(self.output0_dtype)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0]) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0]) responses.append(inference_response) st += b return responses diff --git a/runtime/server/x86_gpu/client/client.py b/runtime/server/x86_gpu/client/client.py index c9b70b47..11a13c58 100644 --- a/runtime/server/x86_gpu/client/client.py +++ b/runtime/server/x86_gpu/client/client.py @@ -23,7 +23,9 @@ import os import kaldiio + class SpeakerClient(object): + def __init__(self, triton_client, model_name, protocol_client): self.triton_client = triton_client self.protocol_client = protocol_client @@ -43,8 +45,10 @@ def recognize(self, wav_path, client_index): print(len(waveform) // 16000) input = np.zeros((1, cur_length), dtype=np.float32) input[0][0:cur_length] = waveform[0:cur_length] - inputs = [self.protocol_client.InferInput("WAV", input.shape, - np_to_triton_dtype(input.dtype))] + inputs = [ + self.protocol_client.InferInput("WAV", input.shape, + np_to_triton_dtype(input.dtype)) + ] inputs[0].set_data_from_numpy(input) outputs = [grpcclient.InferRequestedOutput("EMBEDDINGS")] response = self.triton_client.infer(self.model_name, @@ -88,13 +92,11 @@ def recognize(self, wav_path, client_index): required=False, default=None, help='data dir will be append to audio file if given') - parser.add_argument( - '--audio_file', - type=str, - required=False, - default=None, - help='single wav file' - ) + parser.add_argument('--audio_file', + type=str, + required=False, + default=None, + help='single wav file') FLAGS = parser.parse_args() @@ -122,11 +124,12 @@ def single_job(li): if not os.path.exists(dir_name) and (dir_name != ''): os.makedirs(dir_name) - embed_ark = os.path.abspath(dir_name) + "/xvector_{:0>3d}.ark".format(idx) + embed_ark = os.path.abspath(dir_name) + "/xvector_{:0>3d}.ark".format( + idx) embed_scp = embed_ark[:-3] + "scp" - with grpcclient.InferenceServerClient(url=FLAGS.url, - verbose=FLAGS.verbose) as triton_client: + with grpcclient.InferenceServerClient( + url=FLAGS.url, verbose=FLAGS.verbose) as triton_client: protocol_client = grpcclient speech_client = SpeakerClient(triton_client, FLAGS.model_name, protocol_client) diff --git a/runtime/server/x86_gpu/client/generate_input.py b/runtime/server/x86_gpu/client/generate_input.py index 7439ab6a..07dee208 100644 --- a/runtime/server/x86_gpu/client/generate_input.py +++ b/runtime/server/x86_gpu/client/generate_input.py @@ -5,19 +5,16 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument( - '--audio_file', - type=str, - default=None, - help='single wav file' - ) + parser.add_argument('--audio_file', + type=str, + default=None, + help='single wav file') parser.add_argument( '--seconds', type=float, required=False, default=None, - help='how long of the audio will be used as test sample' - ) + help='how long of the audio will be used as test sample') FLAGS = parser.parse_args() wav_file = FLAGS.audio_file @@ -27,13 +24,19 @@ num_samples = int(FLAGS.second * sample_rate) seconds = FLAGS.seconds if seconds < true_length: - waveform = waveform[0: num_samples] + waveform = waveform[0:num_samples] else: temp = np.zeros(num_samples, dtype=np.float32) - temp[0: len(waveform)] = waveform[:] + temp[0:len(waveform)] = waveform[:] waveform = temp - data = {"data": [{"WAV": {"shape": [len(waveform)], - "content": waveform.tolist()}}]} + data = { + "data": [{ + "WAV": { + "shape": [len(waveform)], + "content": waveform.tolist() + } + }] + } json.dump(data, open("input.json", "w")) diff --git a/runtime/server/x86_gpu/model_repo/feature_extractor/1/model.py b/runtime/server/x86_gpu/model_repo/feature_extractor/1/model.py index 696bf491..4f9e74a1 100644 --- a/runtime/server/x86_gpu/model_repo/feature_extractor/1/model.py +++ b/runtime/server/x86_gpu/model_repo/feature_extractor/1/model.py @@ -19,7 +19,9 @@ from typing import List import json + class Fbank(torch.nn.Module): + def __init__(self, opts): super(Fbank, self).__init__() self.fbank = kaldifeat.Fbank(opts) @@ -124,7 +126,9 @@ def execute(self, requests): for b in batch_count: batch_speech = features[idx:idx + b] idx += b - out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(batch_speech)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0]) + out0 = pb_utils.Tensor.from_dlpack("speech", + to_dlpack(batch_speech)) + inference_response = pb_utils.InferenceResponse( + output_tensors=[out0]) responses.append(inference_response) return responses diff --git a/setup.py b/setup.py index 50de585a..3ae566a1 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,9 @@ name="wespeaker", install_requires=requirements, packages=find_packages(), - entry_points={"console_scripts": [ - "wespeaker = wespeaker.cli.speaker:main", - ]}, + entry_points={ + "console_scripts": [ + "wespeaker = wespeaker.cli.speaker:main", + ] + }, ) diff --git a/tools/make_raw_list.py b/tools/make_raw_list.py index f3e373ac..50ee9e77 100644 --- a/tools/make_raw_list.py +++ b/tools/make_raw_list.py @@ -21,7 +21,10 @@ def get_args(): parser = argparse.ArgumentParser(description='') - parser.add_argument('--vad_file', type=str, help='vad file', default='non_exist') + parser.add_argument('--vad_file', + type=str, + help='vad file', + default='non_exist') parser.add_argument('wav_file', help='wav file') parser.add_argument('utt2spk_file', help='utt2spk file') parser.add_argument('raw_list', help='output raw list file') diff --git a/tools/make_shard_list.py b/tools/make_shard_list.py index e64f5274..12970297 100644 --- a/tools/make_shard_list.py +++ b/tools/make_shard_list.py @@ -103,7 +103,8 @@ def write_tar_file(data_list, tar_file, index=0, total=1): ts = time.time() if wav.endswith('|'): - p = subprocess.Popen(wav[:-1], shell=True, + p = subprocess.Popen(wav[:-1], + shell=True, stdout=subprocess.PIPE) data = p.stdout.read() else: @@ -148,7 +149,10 @@ def get_args(): parser.add_argument('--shuffle', action='store_true', help='whether to shuffle data') - parser.add_argument('--vad_file', type=str, help='vad file', default='non_exist') + parser.add_argument('--vad_file', + type=str, + help='vad file', + default='non_exist') parser.add_argument('wav_file', help='wav file') parser.add_argument('utt2spk_file', help='utt2spk file') parser.add_argument('shards_dir', help='output shards dir') diff --git a/tools/onnx2horizonbin.py b/tools/onnx2horizonbin.py index a905c91e..c69538d4 100644 --- a/tools/onnx2horizonbin.py +++ b/tools/onnx2horizonbin.py @@ -32,7 +32,6 @@ print('Please install hbdk,horizon_nn,horizon_tc_ui !') sys.exit(1) - logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) @@ -40,15 +39,14 @@ def make_calibration_data(args, conf, cal_data_dir): conf['shuffle'] = True logger.info(conf) - dataset = Dataset( - args.cali_data_type, - args.cali_datalist, - conf, - spk2id_dict={}, - whole_utt=False, - reverb_lmdb_file=None, - noise_lmdb_file=None, - repeat_dataset=False) + dataset = Dataset(args.cali_data_type, + args.cali_datalist, + conf, + spk2id_dict={}, + whole_utt=False, + reverb_lmdb_file=None, + noise_lmdb_file=None, + repeat_dataset=False) dataloader = DataLoader(dataset, shuffle=False, batch_size=1, @@ -57,11 +55,10 @@ def make_calibration_data(args, conf, cal_data_dir): if batch_idx % 100 == 0: logger.info("processed {} samples.".format(batch_idx)) assert len(batch['key']) == 1 - key = batch['key'][0] # [B, key] + key = batch['key'][0] # [B, key] feat = batch['feat'] feat = feat.unsqueeze(1).numpy() - feats_save_path = os.path.join(cal_data_dir, - '{}.bin'.format(key)) + feats_save_path = os.path.join(cal_data_dir, '{}.bin'.format(key)) feat.tofile(feats_save_path) @@ -133,30 +130,48 @@ def generate_config(args): output_dir = os.path.realpath(args.output_dir) speaker_onnx_path = os.path.realpath(args.onnx_path) speaker_log_path = os.path.join(output_dir, 'hb_makertbin_output_speaker') - speaker_config = template.format( - os.path.realpath(args.onnx_path), "speaker", speaker_log_path, - args.input_name, args.input_shape, - "cal_data_dir", args.calibration_type, args.extra_ops_run_on_cpu, "") + speaker_config = template.format(os.path.realpath(args.onnx_path), + "speaker", speaker_log_path, + args.input_name, args.input_shape, + "cal_data_dir", args.calibration_type, + args.extra_ops_run_on_cpu, "") with open(output_dir + "/config_speaker.yaml", "w") as speaker_yaml: speaker_yaml.write(speaker_config) def get_args(): - parser = argparse.ArgumentParser(description='convert onnx to horizon .bin') + parser = argparse.ArgumentParser( + description='convert onnx to horizon .bin') parser.add_argument('--config', required=True, help='config file') parser.add_argument('--output_dir', required=True, help='output directory') - parser.add_argument('--cali_datalist', type=str, default=None, + parser.add_argument('--cali_datalist', + type=str, + default=None, help='make calibration data') - parser.add_argument('--cali_data_type', type=str, default=None, + parser.add_argument('--cali_data_type', + type=str, + default=None, help='make calibration data') - parser.add_argument('--extra_ops_run_on_cpu', type=str, default="", + parser.add_argument('--extra_ops_run_on_cpu', + type=str, + default="", help='extra operations running on cpu.') - parser.add_argument('--calibration_type', type=str, default='default', + parser.add_argument('--calibration_type', + type=str, + default='default', help='kl / max / default.') - parser.add_argument('--onnx_path', type=str, required=True, + parser.add_argument('--onnx_path', + type=str, + required=True, help='onnx model (float)') - parser.add_argument('--input_name', type=str, required=True, help='input name') - parser.add_argument('--input_shape', type=str, required=True, help='input shape') + parser.add_argument('--input_name', + type=str, + required=True, + help='input name') + parser.add_argument('--input_shape', + type=str, + required=True, + help='input shape') return parser @@ -191,9 +206,7 @@ def get_args(): output_dir = os.path.realpath(args.output_dir) logger.info("Stage-3: Make speaker.bin") - os.system( - "cd {} && mkdir -p hb_makertbin_log_speaker".format(output_dir) + - " && cd hb_makertbin_log_speaker &&" + - " hb_mapper makertbin --model-type \"onnx\" --config \"{}\"".format( - output_dir + "/config_speaker.yaml") - ) + os.system("cd {} && mkdir -p hb_makertbin_log_speaker".format(output_dir) + + " && cd hb_makertbin_log_speaker &&" + + " hb_mapper makertbin --model-type \"onnx\" --config \"{}\"". + format(output_dir + "/config_speaker.yaml")) diff --git a/tools/vector_mean.py b/tools/vector_mean.py index 4d0cdc51..f7e59021 100644 --- a/tools/vector_mean.py +++ b/tools/vector_mean.py @@ -55,17 +55,12 @@ def compute_vector_mean(spk2utt, xvector_scp, spk_xvector_ark): if __name__ == '__main__': parser = argparse.ArgumentParser(description='compute the mean of vector') - parser.add_argument('--spk2utt', - type=str, - default='', - help='spk2utt file') + parser.add_argument('--spk2utt', type=str, default='', help='spk2utt file') parser.add_argument('--xvector_scp', type=str, default='', help='xvector file (kaldi format)') - parser.add_argument('--spk_xvector_ark', - type=str, - default='') + parser.add_argument('--spk_xvector_ark', type=str, default='') args = parser.parse_args() compute_vector_mean(args.spk2utt, args.xvector_scp, args.spk_xvector_ark) diff --git a/tools/wav2dur.py b/tools/wav2dur.py index 1bcc1b69..b53a7fe1 100755 --- a/tools/wav2dur.py +++ b/tools/wav2dur.py @@ -4,6 +4,7 @@ import sys import torchaudio + torchaudio.set_audio_backend("sox_io") scp = sys.argv[1] diff --git a/wespeaker/bin/adapt_plda.py b/wespeaker/bin/adapt_plda.py index 3d99baa5..c0313535 100644 --- a/wespeaker/bin/adapt_plda.py +++ b/wespeaker/bin/adapt_plda.py @@ -17,34 +17,39 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse from wespeaker.utils.plda.two_cov_plda import TwoCovPLDA if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--adp_scp', '-ad', + parser.add_argument('--adp_scp', + '-ad', type=str, required=True, help='Data for unlabeled adaptation.') - parser.add_argument('--across_class_scale', '-as', + parser.add_argument('--across_class_scale', + '-as', type=float, help='Scaling factor for across class covariance.', default=0.5) - parser.add_argument('--within_class_scale', '-ws', + parser.add_argument('--within_class_scale', + '-ws', type=float, help='Scaling factor for withn class covariance.', default=0.5) - parser.add_argument('--mdl_org', '-mo', + parser.add_argument('--mdl_org', + '-mo', type=str, required=True, help='Original PLDA mdl.') - parser.add_argument('--mdl_adp', '-ma', + parser.add_argument('--mdl_adp', + '-ma', type=str, required=True, help='Adapted PLDA mdl.') - parser.add_argument('--mdl_format', '-mf', + parser.add_argument('--mdl_format', + '-mf', type=str, default='wespeaker', help='Format of the model wespeaker/kaldi') diff --git a/wespeaker/bin/average_model.py b/wespeaker/bin/average_model.py index cc1181d7..a419acf4 100644 --- a/wespeaker/bin/average_model.py +++ b/wespeaker/bin/average_model.py @@ -35,10 +35,11 @@ def get_args(): default=0, type=int, help='min epoch used for averaging model') - parser.add_argument('--max_epoch', - default=65536, # Big enough - type=int, - help='max epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') args = parser.parse_args() print(args) return args @@ -47,8 +48,8 @@ def get_args(): def main(): args = get_args() - path_list = glob.glob( - '{}/[!avg][!final][!convert]*.pt'.format(args.src_path)) + path_list = glob.glob('{}/[!avg][!final][!convert]*.pt'.format( + args.src_path)) path_list = sorted( path_list, key=lambda p: int(re.findall(r"(?<=model_)\d*(?=.pt)", p)[0])) diff --git a/wespeaker/bin/eval_plda.py b/wespeaker/bin/eval_plda.py index 4b10fe8c..1481e874 100644 --- a/wespeaker/bin/eval_plda.py +++ b/wespeaker/bin/eval_plda.py @@ -22,14 +22,20 @@ type=str, default='2cov', help='which type of plda to use, 2cov|kaldi') - parser.add_argument('--enroll_scp_path', type=str, help='enroll embeddings') - parser.add_argument('--indomain_scp_path', type=str, + parser.add_argument('--enroll_scp_path', + type=str, + help='enroll embeddings') + parser.add_argument('--indomain_scp_path', + type=str, help='embeddings to compute meanvec') parser.add_argument('--test_scp_path', type=str, help='test embeddings') - parser.add_argument('--utt2spk', type=str, + parser.add_argument('--utt2spk', + type=str, help='utt2spk for the enroll speakers') parser.add_argument('--model_path', type=str, help='pretrained plda path') - parser.add_argument('--score_path', type=str, help='score file to write to') + parser.add_argument('--score_path', + type=str, + help='score file to write to') parser.add_argument('--trial', type=str, help='trial file to score upon') args = parser.parse_args() diff --git a/wespeaker/bin/export_onnx.py b/wespeaker/bin/export_onnx.py index 269373cf..014c4738 100644 --- a/wespeaker/bin/export_onnx.py +++ b/wespeaker/bin/export_onnx.py @@ -31,7 +31,9 @@ def get_args(): parser.add_argument('--config', required=True, help='config file') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--output_model', required=True, help='output file') - parser.add_argument('--mean_vec', required=False, default=None, + parser.add_argument('--mean_vec', + required=False, + default=None, help='mean vector') args = parser.parse_args() return args @@ -54,6 +56,7 @@ def main(): mean_vec = torch.zeros(embed_dim, dtype=torch.float32) class Model(nn.Module): + def __init__(self, model, mean_vec=None): super(Model, self).__init__() self.model = model @@ -75,15 +78,23 @@ def forward(self, feats): num_frms = configs['dataset_args'].get('num_frms', 200) dummy_input = torch.ones(1, num_frms, feat_dim) - torch.onnx.export( - model, dummy_input, - args.output_model, - do_constant_folding=True, - verbose=False, - opset_version=14, - input_names=['feats'], - output_names=['embs'], - dynamic_axes={'feats': {0: 'B', 1: 'T'}, 'embs': {0: 'B'}}) + torch.onnx.export(model, + dummy_input, + args.output_model, + do_constant_folding=True, + verbose=False, + opset_version=14, + input_names=['feats'], + output_names=['embs'], + dynamic_axes={ + 'feats': { + 0: 'B', + 1: 'T' + }, + 'embs': { + 0: 'B' + } + }) # You may further generate tensorrt engine: # trtexec --onnx=avg_model.onnx --minShapes=feats:1x200x80 \ diff --git a/wespeaker/bin/export_onnx_bpu.py b/wespeaker/bin/export_onnx_bpu.py index f5fcdc62..9cd3d297 100644 --- a/wespeaker/bin/export_onnx_bpu.py +++ b/wespeaker/bin/export_onnx_bpu.py @@ -33,15 +33,21 @@ def get_args(): parser.add_argument('--config', required=True, help='config file') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--output_model', required=True, help='output file') - parser.add_argument('--mean_vec', required=False, default=None, + parser.add_argument('--mean_vec', + required=False, + default=None, help='mean vector') # NOTE(cdliang): for horizon bpu, the shape of input is fixed. - parser.add_argument('--num_frames', type=int, required=True, help="num frames") + parser.add_argument('--num_frames', + type=int, + required=True, + help="num frames") args = parser.parse_args() return args class Model(nn.Module): + def __init__(self, model, mean_vec=None): super(Model, self).__init__() self.model = model @@ -49,7 +55,7 @@ def __init__(self, model, mean_vec=None): def forward(self, feats): # NOTE(cdliang): for horizion x3pi, input shape is [NHWC] - feats = feats.squeeze(1) # [B, 1, T, F] -> [B, T, F] + feats = feats.squeeze(1) # [B, 1, T, F] -> [B, T, F] outputs = self.model(feats) # embed or (embed_a, embed_b) embeds = outputs[-1] if isinstance(outputs, tuple) else outputs embeds = embeds - self.mean_vec @@ -78,14 +84,14 @@ def main(): feat_dim = configs['model_args'].get('feat_dim', 80) static_input = torch.ones(1, 1, args.num_frames, feat_dim) - torch.onnx.export( - model, static_input, - args.output_model, - do_constant_folding=True, - verbose=False, - opset_version=11, - input_names=['feats'], - output_names=['embs']) + torch.onnx.export(model, + static_input, + args.output_model, + do_constant_folding=True, + verbose=False, + opset_version=11, + input_names=['feats'], + output_names=['embs']) if __name__ == '__main__': diff --git a/wespeaker/bin/extract_deprecated.py b/wespeaker/bin/extract_deprecated.py index 78fa0fd8..4cff0226 100644 --- a/wespeaker/bin/extract_deprecated.py +++ b/wespeaker/bin/extract_deprecated.py @@ -70,7 +70,8 @@ def extract(config='conf/config.yaml', **kwargs): with torch.no_grad(): with kaldiio.WriteHelper('ark,scp:' + embed_ark + "," + embed_scp) as writer: - t_bar = tqdm(ncols=100, total=len(dataloader), + t_bar = tqdm(ncols=100, + total=len(dataloader), desc='extract_embed: ') for i, (utts, feats, _) in enumerate(dataloader): t_bar.update() diff --git a/wespeaker/bin/infer_onnx.py b/wespeaker/bin/infer_onnx.py index 2e94a26e..eb89efbe 100644 --- a/wespeaker/bin/infer_onnx.py +++ b/wespeaker/bin/infer_onnx.py @@ -63,12 +63,8 @@ def main(): feats = compute_fbank(wav_path) feats = feats.unsqueeze(0).numpy() # add batch dimension - embeddings = session.run( - output_names=['embs'], - input_feed={ - 'feats': feats - } - ) + embeddings = session.run(output_names=['embs'], + input_feed={'feats': feats}) print(embeddings[0].shape) diff --git a/wespeaker/bin/score.py b/wespeaker/bin/score.py index 13025711..b65209de 100644 --- a/wespeaker/bin/score.py +++ b/wespeaker/bin/score.py @@ -68,15 +68,11 @@ def trials_cosine_score(eval_scp_path='', w_f.write('{} {} {:.5f} {}\n'.format( segs[0], segs[1], cos_score, segs[2])) else: # enroll_name test_name - w_f.write('{} {} {:.5f}\n'.format( - segs[0], segs[1], cos_score)) + w_f.write('{} {} {:.5f}\n'.format(segs[0], segs[1], + cos_score)) -def main(exp_dir, - eval_scp_path, - cal_mean, - cal_mean_dir, - *trials): +def main(exp_dir, eval_scp_path, cal_mean, cal_mean_dir, *trials): if not cal_mean: print("Do not do mean normalization for evaluation embeddings.") mean_vec_path = None diff --git a/wespeaker/bin/score_norm.py b/wespeaker/bin/score_norm.py index cb2afd37..0856a936 100644 --- a/wespeaker/bin/score_norm.py +++ b/wespeaker/bin/score_norm.py @@ -24,8 +24,8 @@ def get_mean_std(emb, cohort, top_n): - emb = emb / np.sqrt(np.sum(emb ** 2, axis=1, keepdims=True)) - cohort = cohort / np.sqrt(np.sum(cohort ** 2, axis=1, keepdims=True)) + emb = emb / np.sqrt(np.sum(emb**2, axis=1, keepdims=True)) + cohort = cohort / np.sqrt(np.sum(cohort**2, axis=1, keepdims=True)) emb_cohort_score = np.matmul(emb, cohort.T) emb_cohort_score = np.sort(emb_cohort_score, axis=1)[:, ::-1] emb_cohort_score_topn = emb_cohort_score[:, :top_n] @@ -67,7 +67,7 @@ def main(score_norm_method, else: assert os.path.exists( mean_vec_path), "mean_vec file ({}) does not exist !!!".format( - mean_vec_path) + mean_vec_path) mean_vec = np.load(mean_vec_path) # get embedding @@ -105,8 +105,8 @@ def main(score_norm_method, normed_score = 0.5 * ( (score - enroll_mean[enroll_idx]) / enroll_std[enroll_idx] + (score - test_mean[test_idx]) / test_std[test_idx]) - fout.write('{} {} {:.5f} {}\n'.format( - line[0], line[1], normed_score, line[3])) + fout.write('{} {} {:.5f} {}\n'.format(line[0], line[1], + normed_score, line[3])) logging.info("Over!") diff --git a/wespeaker/bin/train.py b/wespeaker/bin/train.py index e6d66747..add26fd6 100644 --- a/wespeaker/bin/train.py +++ b/wespeaker/bin/train.py @@ -115,7 +115,8 @@ def train(config='conf/config.yaml', **kwargs): elif checkpoint is None: logger.info('Train model from scratch ...') # projection layer - configs['projection_args']['embed_dim'] = configs['model_args']['embed_dim'] + configs['projection_args']['embed_dim'] = configs['model_args'][ + 'embed_dim'] configs['projection_args']['num_class'] = len(spk2id_dict) configs['projection_args']['do_lm'] = configs.get('do_lm', False) if configs['data_type'] != 'feat' and configs['dataset_args'][ @@ -123,7 +124,8 @@ def train(config='conf/config.yaml', **kwargs): # diff speed is regarded as diff spk configs['projection_args']['num_class'] *= 3 if configs.get('do_lm', False): - logger.info('No speed perturb while doing large margin fine-tuning') + logger.info( + 'No speed perturb while doing large margin fine-tuning') configs['dataset_args']['speed_perturb'] = False projection = get_projection(configs['projection_args']) model.add_module("projection", projection) diff --git a/wespeaker/bin/train_deprecated.py b/wespeaker/bin/train_deprecated.py index 63a9c21b..bf1ca61e 100644 --- a/wespeaker/bin/train_deprecated.py +++ b/wespeaker/bin/train_deprecated.py @@ -119,7 +119,8 @@ def train(config='conf/config.yaml', **kwargs): else: logger.info('Train model from scratch...') # projection layer - configs['projection_args']['embed_dim'] = configs['model_args']['embed_dim'] + configs['projection_args']['embed_dim'] = configs['model_args'][ + 'embed_dim'] configs['projection_args']['num_class'] = len(spk2id_dict) if configs['feature_args']['raw_wav'] and configs['dataset_args'][ 'speed_perturb']: @@ -159,8 +160,9 @@ def train(config='conf/config.yaml', **kwargs): logger.info("loss criterion is: " + configs['loss']) configs['optimizer_args']['lr'] = configs['scheduler_args']['initial_lr'] - optimizer = getattr(torch.optim, configs['optimizer'])( - ddp_model.parameters(), **configs['optimizer_args']) + optimizer = getattr(torch.optim, + configs['optimizer'])(ddp_model.parameters(), + **configs['optimizer_args']) if rank == 0: logger.info("<== Optimizer ==>") logger.info("optimizer is: " + configs['optimizer']) @@ -168,8 +170,9 @@ def train(config='conf/config.yaml', **kwargs): # scheduler configs['scheduler_args']['num_epochs'] = configs['num_epochs'] configs['scheduler_args']['epoch_iter'] = len(train_dataloader) - scheduler = getattr(schedulers, configs['scheduler'])( - optimizer, **configs['scheduler_args']) + scheduler = getattr(schedulers, + configs['scheduler'])(optimizer, + **configs['scheduler_args']) if rank == 0: logger.info("<== Scheduler ==>") logger.info("scheduler is: " + configs['scheduler']) @@ -215,8 +218,8 @@ def train(config='conf/config.yaml', **kwargs): if epoch % configs['save_epoch_interval'] == 0 or epoch >= configs[ 'num_epochs'] - configs['num_avg']: save_checkpoint( - model, - os.path.join(model_dir, 'model_{}.pt'.format(epoch))) + model, os.path.join(model_dir, + 'model_{}.pt'.format(epoch))) if rank == 0: os.symlink('model_{}.pt'.format(configs['num_epochs']), diff --git a/wespeaker/bin/train_plda.py b/wespeaker/bin/train_plda.py index e030c94e..554978af 100644 --- a/wespeaker/bin/train_plda.py +++ b/wespeaker/bin/train_plda.py @@ -27,7 +27,7 @@ type=str, default='2cov', help='which type of plda to use, we only support ' - 'kaldi 2cov version currently') + 'kaldi 2cov version currently') parser.add_argument('--scp_path', type=str, help='the plda training embedding.scp file') diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py index 5fbcb85e..62ef3165 100644 --- a/wespeaker/cli/speaker.py +++ b/wespeaker/cli/speaker.py @@ -32,6 +32,7 @@ class Speaker: + def __init__(self, model_dir: str): config_path = os.path.join(model_dir, 'config.yaml') model_path = os.path.join(model_dir, 'avg_model.pt') @@ -187,7 +188,7 @@ def get_args(): help='audio file2, specifically for similarity task') parser.add_argument('--wav_scp', help='path to wav.scp, for extract and saving ' - 'kaldi-stype embeddings') + 'kaldi-stype embeddings') parser.add_argument('--resample_rate', type=int, default=16000, @@ -221,7 +222,8 @@ def main(): names, embeddings = model.extract_embedding_list(args.wav_scp) embed_ark = args.output_file + ".ark" embed_scp = args.output_file + ".scp" - with kaldiio.WriteHelper('ark,scp:' + embed_ark + "," + embed_scp) as writer: + with kaldiio.WriteHelper('ark,scp:' + embed_ark + "," + + embed_scp) as writer: for name, embedding in zip(names, embeddings): writer(name, embedding) elif args.task == 'similarity': diff --git a/wespeaker/dataset/dataset.py b/wespeaker/dataset/dataset.py index 8134257f..d4cf8a9d 100644 --- a/wespeaker/dataset/dataset.py +++ b/wespeaker/dataset/dataset.py @@ -38,7 +38,6 @@ def __init__(self, source, f, *args, **kw): def set_epoch(self, epoch): self.source.set_epoch(epoch) - def __iter__(self): """ Return an iterator over the source dataset processed by the given processor. @@ -103,7 +102,11 @@ def sample(self, data): class DataList(IterableDataset): - def __init__(self, lists, shuffle=True, partition=True, repeat_dataset=True): + def __init__(self, + lists, + shuffle=True, + partition=True, + repeat_dataset=True): self.lists = lists self.repeat_dataset = repeat_dataset self.sampler = DistributedSampler(shuffle, partition) @@ -171,14 +174,15 @@ def Dataset(data_type, filter_conf = configs.get('filter_args', {}) dataset = Processor(dataset, processor.filter, - frame_shift=configs['fbank_args'].get('frame_shift', 10), + frame_shift=configs['fbank_args'].get( + 'frame_shift', 10), data_type=data_type, - **filter_conf - ) + **filter_conf) # Local shuffle if shuffle: - dataset = Processor(dataset, processor.shuffle, **configs['shuffle_args']) + dataset = Processor(dataset, processor.shuffle, + **configs['shuffle_args']) # spk2id dataset = Processor(dataset, processor.spk_to_id, spk2id_dict) @@ -187,7 +191,8 @@ def Dataset(data_type, if not whole_utt: # random chunk chunk_len = num_frms = configs.get('num_frms', 200) - dataset = Processor(dataset, processor.random_chunk, chunk_len, 'feat') + dataset = Processor(dataset, processor.random_chunk, chunk_len, + 'feat') else: # resample resample_rate = configs.get('resample_rate', 16000) @@ -195,24 +200,28 @@ def Dataset(data_type, # speed perturb speed_perturb_flag = configs.get('speed_perturb', True) if speed_perturb_flag: - dataset = Processor(dataset, processor.speed_perturb, len(spk2id_dict)) + dataset = Processor(dataset, processor.speed_perturb, + len(spk2id_dict)) if not whole_utt: # random chunk num_frms = configs.get('num_frms', 200) frame_shift = configs['fbank_args'].get('frame_shift', 10) frame_length = configs['fbank_args'].get('frame_length', 25) - chunk_len = ((num_frms - 1) * frame_shift - + frame_length) * resample_rate // 1000 - dataset = Processor(dataset, processor.random_chunk, chunk_len, data_type) + chunk_len = ((num_frms - 1) * frame_shift + + frame_length) * resample_rate // 1000 + dataset = Processor(dataset, processor.random_chunk, chunk_len, + data_type) # add reverb & noise aug_prob = configs.get('aug_prob', 0.6) if (reverb_lmdb_file and noise_lmdb_file) and (aug_prob > 0.0): reverb_data = LmdbData(reverb_lmdb_file) noise_data = LmdbData(noise_lmdb_file) - dataset = Processor(dataset, processor.add_reverb_noise, reverb_data, - noise_data, resample_rate, aug_prob) + dataset = Processor(dataset, processor.add_reverb_noise, + reverb_data, noise_data, resample_rate, + aug_prob) # compute fbank - dataset = Processor(dataset, processor.compute_fbank, **configs['fbank_args']) + dataset = Processor(dataset, processor.compute_fbank, + **configs['fbank_args']) # apply cmvn dataset = Processor(dataset, processor.apply_cmvn) @@ -220,6 +229,7 @@ def Dataset(data_type, # spec augmentation spec_aug_flag = configs.get('spec_aug', True) if spec_aug_flag: - dataset = Processor(dataset, processor.spec_aug, **configs['spec_aug_args']) + dataset = Processor(dataset, processor.spec_aug, + **configs['spec_aug_args']) return dataset diff --git a/wespeaker/dataset/dataset_deprecated.py b/wespeaker/dataset/dataset_deprecated.py index f701656f..fda30ddc 100644 --- a/wespeaker/dataset/dataset_deprecated.py +++ b/wespeaker/dataset/dataset_deprecated.py @@ -23,7 +23,8 @@ import torchaudio.compliance.kaldi as kaldi from wespeaker.utils.file_utils import read_scp -from wespeaker.utils.dataset_utils_deprecated import (get_random_chunk, speed_perturb, +from wespeaker.utils.dataset_utils_deprecated import (get_random_chunk, + speed_perturb, spec_augmentation) @@ -31,11 +32,8 @@ class FeatList_LableDict_Dataset(Dataset): """ shuffle wav.scp/feats.scp, load all labels into cpu memory """ - def __init__(self, - data_list, - utt2spkid_dict, - whole_utt=False, - **kwargs): + + def __init__(self, data_list, utt2spkid_dict, whole_utt=False, **kwargs): super(FeatList_LableDict_Dataset, self).__init__() self.data_list = data_list self.length = len(data_list) @@ -117,9 +115,14 @@ def __len__(self): class Augment_Wav: + def __init__(self, musan_scp, rirs_scp): - self.noise_snr = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]} + self.noise_snr = { + 'noise': [0, 15], + 'speech': [13, 20], + 'music': [5, 15] + } self.num_noise = {'noise': [1, 1], 'speech': [3, 7], 'music': [1, 1]} self.rir_list = read_scp(rirs_scp) @@ -159,7 +162,8 @@ def additive_noise(self, noise_type, audio): self.noise_snr[noise_type][1]) noise_db = 10 * np.log10(np.mean(noise_audio**2) + 1e-4) noise_list.append( - np.sqrt(10**((audio_db - noise_db - noise_snr) / 10)) * noise_audio) + np.sqrt(10**((audio_db - noise_db - noise_snr) / 10)) * + noise_audio) return np.sum(np.stack(noise_list), axis=0) + audio diff --git a/wespeaker/dataset/processor.py b/wespeaker/dataset/processor.py index 676c02ab..33d23d68 100644 --- a/wespeaker/dataset/processor.py +++ b/wespeaker/dataset/processor.py @@ -125,6 +125,7 @@ def parse_raw(data): Returns: Iterable[{key, wav, spk, sample_rate}] """ + def read_audio(wav): if wav.endswith('|'): p = Popen(wav[:-1], shell=True, stdout=PIPE) @@ -156,7 +157,8 @@ def apply_vad(waveform, sample_rate, vad): try: waveform, sample_rate = read_audio(wav_file) if 'vad' in obj: - waveform, sample_rate = apply_vad(waveform, sample_rate, obj['vad']) + waveform, sample_rate = apply_vad(waveform, sample_rate, + obj['vad']) example = dict(key=key, spk=spk, wav=waveform, @@ -187,9 +189,7 @@ def parse_feat(data): spk = obj['spk'] try: feat = torch.from_numpy(kaldiio.load_mat(feat_ark)) - example = dict(key=key, - spk=spk, - feat=feat) + example = dict(key=key, spk=spk, feat=feat) yield example except Exception as ex: logging.warning('Failed to load {}'.format(feat_ark)) @@ -312,7 +312,8 @@ def get_random_chunk(data, chunk_len): else: # padding repeat_factor = chunk_len // data_len + 1 - repeat_shape = repeat_factor if len(data_shape) == 1 else (repeat_factor, 1) + repeat_shape = repeat_factor if len(data_shape) == 1 else ( + repeat_factor, 1) if type(data) == torch.Tensor: data = data.repeat(repeat_shape) else: # np.array @@ -326,8 +327,7 @@ def filter(data, min_num_frames=100, max_num_frames=800, frame_shift=10, - data_type='shard/raw/feat' - ): + data_type='shard/raw/feat'): """ Filter the utterance with very short duration and random chunk the utterance with very long duration. @@ -452,8 +452,7 @@ def add_reverb_noise(data, # Since the noise audio could be very long, it must be # chunked first before resampled (to save time) noise_audio = get_random_chunk( - noise_audio, - int(audio_len / resample_rate * noise_sr)) + noise_audio, int(audio_len / resample_rate * noise_sr)) noise_audio = signal.resample(noise_audio, audio_len) else: noise_audio = get_random_chunk(noise_audio, audio_len) diff --git a/wespeaker/models/campplus.py b/wespeaker/models/campplus.py index d5c25670..5917c0e8 100644 --- a/wespeaker/models/campplus.py +++ b/wespeaker/models/campplus.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - ''' This implementation is adapted from github repo: https://github.com/alibaba-damo-academy/3D-Speaker @@ -53,6 +52,7 @@ def get_nonlinear(config_str, channels): class TDNNLayer(nn.Module): + def __init__(self, in_channels, out_channels, @@ -83,6 +83,7 @@ def forward(self, x): class CAMLayer(nn.Module): + def __init__(self, bn_channels, out_channels, @@ -127,12 +128,14 @@ def seg_pooling(self, x, seg_len: int = 100, stype: str = 'avg'): raise ValueError('Wrong segment pooling type.') shape = seg.shape seg = seg.unsqueeze(-1).expand(shape[0], shape[1], shape[2], - seg_len).reshape(shape[0], shape[1], -1) + seg_len).reshape( + shape[0], shape[1], -1) seg = seg[..., :x.shape[-1]] return seg class CAMDenseTDNNLayer(nn.Module): + def __init__(self, in_channels, out_channels, @@ -167,6 +170,7 @@ def forward(self, x): class CAMDenseTDNNBlock(nn.ModuleList): + def __init__(self, num_layers, in_channels, @@ -197,6 +201,7 @@ def forward(self, x): class TransitLayer(nn.Module): + def __init__(self, in_channels, out_channels, @@ -213,6 +218,7 @@ def forward(self, x): class DenseLayer(nn.Module): + def __init__(self, in_channels, out_channels, @@ -233,6 +239,8 @@ def forward(self, x): '''Note: The stride used here is different from that in Resnet ''' + + class BasicResBlock(nn.Module): expansion = 1 @@ -260,8 +268,7 @@ def __init__(self, in_planes, planes, stride=1): self.expansion * planes, kernel_size=1, stride=(stride, 1), - bias=False), - nn.BatchNorm2d(self.expansion * planes)) + bias=False), nn.BatchNorm2d(self.expansion * planes)) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) @@ -272,11 +279,8 @@ def forward(self, x): class FCM(nn.Module): - def __init__(self, - block, - num_blocks, - m_channels=32, - feat_dim=80): + + def __init__(self, block, num_blocks, m_channels=32, feat_dim=80): super(FCM, self).__init__() self.in_planes = m_channels self.conv1 = nn.Conv2d(1, @@ -326,6 +330,7 @@ def forward(self, x): class CAMPPlus(nn.Module): + def __init__(self, feat_dim=80, embed_dim=512, @@ -336,7 +341,9 @@ def __init__(self, config_str='batchnorm-relu'): super(CAMPPlus, self).__init__() - self.head = FCM(block=BasicResBlock, num_blocks=[2, 2], feat_dim=feat_dim) + self.head = FCM(block=BasicResBlock, + num_blocks=[2, 2], + feat_dim=feat_dim) channels = self.head.out_channels self.xvector = nn.Sequential( diff --git a/wespeaker/models/convert_repvgg.py b/wespeaker/models/convert_repvgg.py index 3244bba0..1ee0c7e3 100644 --- a/wespeaker/models/convert_repvgg.py +++ b/wespeaker/models/convert_repvgg.py @@ -24,7 +24,8 @@ def convert(config='conf/config.yaml', **kwargs): configs = parse_config_or_kwargs(config, **kwargs) - speaker_model = get_speaker_model(configs['model'])(**configs['model_args']) + speaker_model = get_speaker_model( + configs['model'])(**configs['model_args']) configs['model_args']['deploy'] = True # save new configs for testing and deploying # NOTE: 'deploy': true @@ -47,5 +48,6 @@ def convert(config='conf/config.yaml', **kwargs): repvgg_model_convert(speaker_model, save_path=configs['save']) print("==> Saving convert model to '{}'".format(configs['save'])) + if __name__ == '__main__': fire.Fire(convert) diff --git a/wespeaker/models/ecapa_tdnn.py b/wespeaker/models/ecapa_tdnn.py index ae2a904f..bd95d879 100644 --- a/wespeaker/models/ecapa_tdnn.py +++ b/wespeaker/models/ecapa_tdnn.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - ''' This implementation is adapted from github repo: https://github.com/lawlict/ECAPA-TDNN. ''' @@ -22,16 +21,15 @@ import torch.nn as nn import torch.nn.functional as F import wespeaker.models.pooling_layers as pooling_layers - - - ''' Res2Conv1d + BatchNorm1d + ReLU ''' + class Res2Conv1dReluBn(nn.Module): """ in_channels == out_channels == channels """ + def __init__(self, channels, kernel_size=1, @@ -79,11 +77,12 @@ def forward(self, x): return out - ''' Conv1d + BatchNorm1d + ReLU ''' + class Conv1dReluBn(nn.Module): + def __init__(self, in_channels, out_channels, @@ -106,11 +105,12 @@ def forward(self, x): return self.bn(F.relu(self.conv(x))) - ''' The SE connection of 1D case. ''' + class SE_Connect(nn.Module): + def __init__(self, channels, se_bottleneck_dim=128): super().__init__() self.linear1 = nn.Linear(channels, se_bottleneck_dim) @@ -125,11 +125,12 @@ def forward(self, x): return out - ''' SE-Res2Block of the ECAPA-TDNN architecture. ''' + class SE_Res2Block(nn.Module): + def __init__(self, channels, kernel_size, stride, padding, dilation, scale): super().__init__() @@ -149,15 +150,14 @@ def __init__(self, channels, kernel_size, stride, padding, dilation, channels, kernel_size=1, stride=1, - padding=0), - SE_Connect(channels)) + padding=0), SE_Connect(channels)) def forward(self, x): return x + self.se_res2block(x) - class ECAPA_TDNN(nn.Module): + def __init__(self, channels=512, feat_dim=80, @@ -229,7 +229,10 @@ def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False): emb_bn=emb_bn) -def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False): +def ECAPA_TDNN_GLOB_c1024(feat_dim, + embed_dim, + pooling_func='ASTP', + emb_bn=False): return ECAPA_TDNN(channels=1024, feat_dim=feat_dim, embed_dim=embed_dim, @@ -246,7 +249,10 @@ def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False): emb_bn=emb_bn) -def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func='ASTP', emb_bn=False): +def ECAPA_TDNN_GLOB_c512(feat_dim, + embed_dim, + pooling_func='ASTP', + emb_bn=False): return ECAPA_TDNN(channels=512, feat_dim=feat_dim, embed_dim=embed_dim, diff --git a/wespeaker/models/pooling_layers.py b/wespeaker/models/pooling_layers.py index 47120eb9..29b319df 100644 --- a/wespeaker/models/pooling_layers.py +++ b/wespeaker/models/pooling_layers.py @@ -94,7 +94,11 @@ class ASTP(nn.Module): statistics pooling, first used in ECAPA_TDNN. """ - def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False, **kwargs): + def __init__(self, + in_dim, + bottleneck_dim=128, + global_context_att=False, + **kwargs): super(ASTP, self).__init__() self.in_dim = in_dim self.global_context_att = global_context_att diff --git a/wespeaker/models/repvgg.py b/wespeaker/models/repvgg.py index 577b87e5..d993931f 100644 --- a/wespeaker/models/repvgg.py +++ b/wespeaker/models/repvgg.py @@ -209,14 +209,12 @@ def get_custom_L2(self): l2_loss_eq_kernel = (eq_kernel**2 / (t3**2 + t1**2)).sum() return l2_loss_eq_kernel + l2_loss_circle - -# This func derives the equivalent kernel and bias in a DIFFERENTIABLE way. -# You can get the equivalent kernel and bias -# at any time and do whatever you want, -# for example, apply some penalties or constraints during training, -# just like you do to the other models. -# May be useful for quantization or pruning. - + # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way + # You can get the equivalent kernel and bias + # at any time and do whatever you want, + # for example, apply some penalties or constraints during training, + # just like you do to the other models. + # May be useful for quantization or pruning. def get_equivalent_kernel_bias(self): kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) diff --git a/wespeaker/models/resnet.py b/wespeaker/models/resnet.py index 2e903ba6..729382f1 100644 --- a/wespeaker/models/resnet.py +++ b/wespeaker/models/resnet.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - '''ResNet in PyTorch. Some modifications from the original architecture: @@ -59,8 +58,7 @@ def __init__(self, in_planes, planes, stride=1): self.expansion * planes, kernel_size=1, stride=stride, - bias=False), - nn.BatchNorm2d(self.expansion * planes)) + bias=False), nn.BatchNorm2d(self.expansion * planes)) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) @@ -97,8 +95,7 @@ def __init__(self, in_planes, planes, stride=1): self.expansion * planes, kernel_size=1, stride=stride, - bias=False), - nn.BatchNorm2d(self.expansion * planes)) + bias=False), nn.BatchNorm2d(self.expansion * planes)) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) @@ -110,6 +107,7 @@ def forward(self, x): class ResNet(nn.Module): + def __init__(self, block, num_blocks, @@ -149,11 +147,11 @@ def __init__(self, num_blocks[3], stride=2) - self.pool = getattr(pooling_layers, pooling_func)( - in_dim=self.stats_dim * block.expansion) + self.pool = getattr(pooling_layers, + pooling_func)(in_dim=self.stats_dim * + block.expansion) self.pool_out_dim = self.pool.get_out_dim() - self.seg_1 = nn.Linear(self.pool_out_dim, - embed_dim) + self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim) if self.two_emb_layer: self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False) self.seg_2 = nn.Linear(embed_dim, embed_dim) @@ -191,7 +189,6 @@ def forward(self, x): return torch.tensor(0.0), embed_a - def ResNet18(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=True): return ResNet(BasicBlock, [2, 2, 2, 2], feat_dim=feat_dim, @@ -250,9 +247,7 @@ def ResNet293(feat_dim, embed_dim, pooling_func='TSTP', two_emb_layer=True): if __name__ == '__main__': x = torch.zeros(10, 200, 80) - model = ResNet34(feat_dim=80, - embed_dim=256, - pooling_func='MQMHASTP') + model = ResNet34(feat_dim=80, embed_dim=256, pooling_func='MQMHASTP') model.eval() out = model(x) print(out[-1].size()) diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py index 0baeaac7..1d3e1d91 100644 --- a/wespeaker/models/speaker_model.py +++ b/wespeaker/models/speaker_model.py @@ -18,6 +18,7 @@ import wespeaker.models.repvgg as repvgg import wespeaker.models.campplus as campplus + def get_speaker_model(model_name: str): if model_name.startswith("XVEC"): return getattr(tdnn, model_name) diff --git a/wespeaker/models/tdnn.py b/wespeaker/models/tdnn.py index 9c724e8d..daf45d9a 100644 --- a/wespeaker/models/tdnn.py +++ b/wespeaker/models/tdnn.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """TDNN model for x-vector learning""" import torch @@ -21,6 +20,7 @@ class TdnnLayer(nn.Module): + def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0): """Define the TDNN layer, essentially 1-D convolution @@ -54,6 +54,7 @@ def forward(self, x): class XVEC(nn.Module): + def __init__(self, feat_dim=40, hid_dim=512, diff --git a/wespeaker/ssl/bin/average_contrastive_model.py b/wespeaker/ssl/bin/average_contrastive_model.py index 06e98a78..e45ea08e 100644 --- a/wespeaker/ssl/bin/average_contrastive_model.py +++ b/wespeaker/ssl/bin/average_contrastive_model.py @@ -36,10 +36,11 @@ def get_args(): default=0, type=int, help='min epoch used for averaging model') - parser.add_argument('--max_epoch', - default=65536, # Big enough - type=int, - help='max epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') args = parser.parse_args() print(args) return args @@ -61,7 +62,8 @@ def get_model_encoder_state_dict(state_dict): def main(): args = get_args() - path_list = glob.glob('{}/[!avg][!final][!convert]*.pt'.format(args.src_path)) + path_list = glob.glob('{}/[!avg][!final][!convert]*.pt'.format( + args.src_path)) path_list = sorted( path_list, key=lambda p: int(re.findall(r"(?<=model_)\d*(?=.pt)", p)[0])) diff --git a/wespeaker/ssl/bin/average_dino_model.py b/wespeaker/ssl/bin/average_dino_model.py index 115b2d30..c09290e5 100644 --- a/wespeaker/ssl/bin/average_dino_model.py +++ b/wespeaker/ssl/bin/average_dino_model.py @@ -36,10 +36,11 @@ def get_args(): default=0, type=int, help='min epoch used for averaging model') - parser.add_argument('--max_epoch', - default=65536, # Big enough - type=int, - help='max epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') args = parser.parse_args() print(args) return args @@ -57,7 +58,8 @@ def get_t_model_state_dict(state_dict): def main(): args = get_args() - path_list = glob.glob('{}/[!avg][!final][!convert]*.pt'.format(args.src_path)) + path_list = glob.glob('{}/[!avg][!final][!convert]*.pt'.format( + args.src_path)) path_list = sorted( path_list, key=lambda p: int(re.findall(r"(?<=model_)\d*(?=.pt)", p)[0])) diff --git a/wespeaker/ssl/dataset/dataset.py b/wespeaker/ssl/dataset/dataset.py index 96c20e78..36f895a8 100644 --- a/wespeaker/ssl/dataset/dataset.py +++ b/wespeaker/ssl/dataset/dataset.py @@ -99,14 +99,15 @@ def SSLDataset(data_type, filter_conf = configs.get('filter_args', {}) dataset = Processor(dataset, processor.filter, - frame_shift=configs['fbank_args'].get('frame_shift', 10), + frame_shift=configs['fbank_args'].get( + 'frame_shift', 10), data_type=data_type, - **filter_conf - ) + **filter_conf) # Local shuffle if shuffle: - dataset = Processor(dataset, processor.shuffle, **configs['shuffle_args']) + dataset = Processor(dataset, processor.shuffle, + **configs['shuffle_args']) # spk2id dataset = Processor(dataset, ssl_processor.spk_to_id, spk2id_dict) @@ -158,6 +159,7 @@ def SSLDataset(data_type, # spec augmentation spec_aug_flag = configs.get('spec_aug', True) if spec_aug_flag: - dataset = Processor(dataset, ssl_processor.spec_aug, **configs['spec_aug_args']) + dataset = Processor(dataset, ssl_processor.spec_aug, + **configs['spec_aug_args']) return dataset diff --git a/wespeaker/ssl/utils/contrastive_executor.py b/wespeaker/ssl/utils/contrastive_executor.py index 5cd287c5..2f9fcd26 100644 --- a/wespeaker/ssl/utils/contrastive_executor.py +++ b/wespeaker/ssl/utils/contrastive_executor.py @@ -53,8 +53,7 @@ def run_epoch(dataloader, # loss, acc loss_meter.add(loss.item()) - acc_meter.add(logits.cpu().detach().numpy(), - labels.cpu().numpy()) + acc_meter.add(logits.cpu().detach().numpy(), labels.cpu().numpy()) # updata the model optimizer.zero_grad() diff --git a/wespeaker/utils/dataset_utils_deprecated.py b/wespeaker/utils/dataset_utils_deprecated.py index e3b14498..bc91fc4e 100644 --- a/wespeaker/utils/dataset_utils_deprecated.py +++ b/wespeaker/utils/dataset_utils_deprecated.py @@ -27,7 +27,8 @@ def get_random_chunk(data, chunk_len): data = data[chunk_start:chunk_start + adjust_chunk_len] # padding if needed if adjust_chunk_len < chunk_len: - chunk_shape = chunk_len if len(data_shape) == 1 else (chunk_len, data.shape[1]) + chunk_shape = chunk_len if len(data_shape) == 1 else (chunk_len, + data.shape[1]) data = np.resize(data, chunk_shape) # repeating return data diff --git a/wespeaker/utils/executor.py b/wespeaker/utils/executor.py index 7935a1b7..93e868c3 100644 --- a/wespeaker/utils/executor.py +++ b/wespeaker/utils/executor.py @@ -60,8 +60,7 @@ def run_epoch(dataloader, # loss, acc loss_meter.add(loss.item()) - acc_meter.add(outputs.cpu().detach().numpy(), - targets.cpu().numpy()) + acc_meter.add(outputs.cpu().detach().numpy(), targets.cpu().numpy()) # updata the model optimizer.zero_grad() @@ -83,8 +82,8 @@ def run_epoch(dataloader, break logger.info( - tp.row((epoch, i + 1, scheduler.get_lr(), - margin_scheduler.get_margin()) + - (loss_meter.value()[0], acc_meter.value()[0]), - width=10, - style='grid')) + tp.row( + (epoch, i + 1, scheduler.get_lr(), margin_scheduler.get_margin()) + + (loss_meter.value()[0], acc_meter.value()[0]), + width=10, + style='grid')) diff --git a/wespeaker/utils/file_utils.py b/wespeaker/utils/file_utils.py index 44113547..044d3ce2 100644 --- a/wespeaker/utils/file_utils.py +++ b/wespeaker/utils/file_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + def read_scp(scp_file): """ scp_file: mostly 2 columns """ diff --git a/wespeaker/utils/plda/kaldi_utils.py b/wespeaker/utils/plda/kaldi_utils.py index 9410ff5f..98bdcab0 100644 --- a/wespeaker/utils/plda/kaldi_utils.py +++ b/wespeaker/utils/plda/kaldi_utils.py @@ -120,7 +120,8 @@ def read_sparse_vector(fd): _format = fd.read(3).decode() assert (_format == 'SV ') _, dim = np.frombuffer(fd.read(5), dtype='int8,int32', count=1)[0] - _, num_elems = np.frombuffer(fd.read(5), dtype='int8,int32', count=1)[0] + _, num_elems = np.frombuffer(fd.read(5), dtype='int8,int32', + count=1)[0] col = [] data = [] for j in range(num_elems): @@ -146,5 +147,6 @@ def read_sparse_vector(fd): cols += col all_data += data max_dim = max(dim, max_dim) - sparse_mat = csr_matrix((all_data, (rows, cols)), shape=(num_rows, max_dim)) + sparse_mat = csr_matrix((all_data, (rows, cols)), + shape=(num_rows, max_dim)) return sparse_mat diff --git a/wespeaker/utils/plda/two_cov_plda.py b/wespeaker/utils/plda/two_cov_plda.py index f2d1ba9d..b4d74317 100644 --- a/wespeaker/utils/plda/two_cov_plda.py +++ b/wespeaker/utils/plda/two_cov_plda.py @@ -29,11 +29,12 @@ M_LOG_2PI = 1.8378770664093454835606594728112 - -ClassInfo = collections.namedtuple('ClassInfo', ['weight', 'num_example', 'mu']) +ClassInfo = collections.namedtuple('ClassInfo', + ['weight', 'num_example', 'mu']) class PldaStats(object): + def __init__(self, dim): self.dim = dim self.num_example, self.num_classes = 0, 0 @@ -61,7 +62,11 @@ def add_samples(self, weight, spk_embeddings): class TwoCovPLDA: - def __init__(self, scp_file=None, utt2spk_file=None, embed_dim=256, + + def __init__(self, + scp_file=None, + utt2spk_file=None, + embed_dim=256, normalize_length=True): self.normalize_length = normalize_length self.dim = embed_dim @@ -80,8 +85,8 @@ def __init__(self, scp_file=None, utt2spk_file=None, embed_dim=256, self.W_stats = np.zeros((self.dim, self.dim)) self.W_count = 0 if scp_file is not None: - samples, self.embeddings_dict = get_data_for_plda(scp_file, - utt2spk_file) + samples, self.embeddings_dict = get_data_for_plda( + scp_file, utt2spk_file) train_mean_vec = samples.mean(0) for key, mat in self.embeddings_dict.items(): mat = np.vstack(mat) @@ -143,8 +148,8 @@ def get_output(self): def transform_embedding(self, embedding): transformed_embedding = np.matmul(self.transform, embedding) transformed_embedding += self.offset - normalization_factor = math.sqrt(self.dim) / np.linalg.norm( - transformed_embedding) + normalization_factor = math.sqrt( + self.dim) / np.linalg.norm(transformed_embedding) if self.normalize_length: transformed_embedding = normalization_factor * transformed_embedding return transformed_embedding @@ -157,19 +162,24 @@ def log_likelihood_ratio(self, transformed_train_embedding, sqdiff = transformed_test_embedding - mean sqdiff = np.power(sqdiff, 2.0) variance = 1.0 / variance - loglike_given_class = -0.5 * ( - logdet + M_LOG_2PI * self.dim + np.dot(sqdiff, variance)) + loglike_given_class = -0.5 * (logdet + M_LOG_2PI * self.dim + + np.dot(sqdiff, variance)) sqdiff = transformed_test_embedding sqdiff = np.power(sqdiff, 2.0) variance = self.psi + 1.0 logdet = np.sum(np.log(variance)) variance = 1.0 / variance - loglike_without_class = -0.5 * ( - logdet + M_LOG_2PI * self.dim + np.dot(sqdiff, variance)) + loglike_without_class = -0.5 * (logdet + M_LOG_2PI * self.dim + + np.dot(sqdiff, variance)) loglike_ratio = loglike_given_class - loglike_without_class return loglike_ratio - def eval_sv(self, enroll_scp, enroll_utt2spk, test_scp, trials, score_file, + def eval_sv(self, + enroll_scp, + enroll_utt2spk, + test_scp, + trials, + score_file, indomain_scp=None): """ Caculate the plda score @@ -187,9 +197,8 @@ def eval_sv(self, enroll_scp, enroll_utt2spk, test_scp, trials, score_file, if indomain_scp is not None: indomain_embeddings_dict = read_vec_scp_file(indomain_scp) - mean_vec = np.vstack( - list(indomain_embeddings_dict.values()) - ).mean(0) + mean_vec = np.vstack(list( + indomain_embeddings_dict.values())).mean(0) else: mean_vec = np.zeros(self.dim) @@ -212,13 +221,11 @@ def eval_sv(self, enroll_scp, enroll_utt2spk, test_scp, trials, score_file, with open(trials, 'r') as read_trials: for line in tqdm(read_trials): tokens = line.strip().split() - score = self.log_likelihood_ratio( - enrollspks[tokens[0]], - testspks[tokens[1]]) + score = self.log_likelihood_ratio(enrollspks[tokens[0]], + testspks[tokens[1]]) segs = line.strip().split() - output_line = ( - '{} {} {:.5f} {}\n'.format(segs[0], segs[1], score, - segs[2])) + output_line = ('{} {} {:.5f} {}\n'.format( + segs[0], segs[1], score, segs[2])) write_score.write(output_line) def adapt(self, adapt_scp, ac_scale=0.5, wc_scale=0.5):