diff --git a/README.md b/README.md index 2ea73a1e..c379ac40 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,29 @@ online feature extraction or loading pre-extracted features in kaldi-format. ## Installation +### Install Python package +``` sh +pip install git+https://github.com/wenet-e2e/wespeaker.git +``` +**Command-line usage** (use `-h` for parameters): + +``` sh +$ wespeaker --task embedding --audio_file audio.wav +$ wespeaker --task similarity --audio_file audio.wav --audio_file2 audio2.wav +$ wespeaker --task diarization --audio_file audio.wav # TODO +``` + +**Python programming usage**: + +``` python +import wespeaker + +model = wespeaker.load_model('chinese') +embedding = model.extract_embedding('audio.wav') +similarity = model.compute_similarity('audio1.wav', 'audio2.wav') +``` + +### Install for development & deployment * Clone this repo ``` sh git clone https://github.com/wenet-e2e/wespeaker.git diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..21fe914f --- /dev/null +++ b/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = wespeaker +version = 0.0.0 +license = Apache Software License +description = End-to-end speech speaker toolkit +long_description = file: README.md +classifiers = + License :: OSI Approved :: Apache Software License + Operating System :: OS Independent + Programming Language :: Python :: 3 + +[options] +packages = find: +include_package_data = True +python_requires = >= 3.8 diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..67f62933 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + +requirements = [ + "tqdm", + "onnxruntime>=1.12.0", + "python-speech-features>=0.6", + "scipy>=1.5.2", +] + +setup( + name="wespeaker", + install_requires=requirements, + packages=find_packages(), + entry_points={"console_scripts": [ + "wespeaker = wespeaker.cli.speaker:main", + ]}, +) diff --git a/wespeaker/__init__.py b/wespeaker/__init__.py new file mode 100644 index
00000000..fda4fac3 --- /dev/null +++ b/wespeaker/__init__.py @@ -0,0 +1 @@ +from wespeaker.cli.speaker import load_model # noqa diff --git a/wespeaker/cli/__init__.py b/wespeaker/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/wespeaker/cli/hub.py b/wespeaker/cli/hub.py new file mode 100644 index 00000000..3af6a3ca --- /dev/null +++ b/wespeaker/cli/hub.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022 Mddct(hamddct@gmail.com) +# 2023 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import requests +import sys +from pathlib import Path +from urllib.request import urlretrieve + +import tqdm + + +def download(url: str, dest: str): + """ download from url to dest + """ + print('Downloading {} to {}'.format(url, dest)) + + def progress_hook(t): + last_b = [0] + + def update_to(b=1, bsize=1, tsize=None): + if tsize not in (None, -1): + t.total = tsize + displayed = t.update((b - last_b[0]) * bsize) + last_b[0] = b + return displayed + + return update_to + + # *.tar.gz + name = url.split('?')[0].split('/')[-1] + with tqdm.tqdm(unit='B', + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=(name)) as t: + urlretrieve(url, filename=dest, reporthook=progress_hook(t), data=None) + t.total = t.n + + +class Hub(object): + Assets = { + "chinese": "cnceleb_resnet34.onnx", + "english": "voxceleb_resnet221_LM.onnx", + } + + def __init__(self) -> None: + pass + + @staticmethod + def get_model(lang: str) -> str: + if lang not in Hub.Assets.keys(): + print('ERROR: Unsupported lang {} !!!'.format(lang)) + sys.exit(1) + model = Hub.Assets[lang] + model_path = os.path.join(Path.home(), ".wespeaker", model) + if not os.path.exists(model_path): + if not os.path.exists(os.path.dirname(model_path)): + os.makedirs(os.path.dirname(model_path)) + response = requests.get( + "https://modelscope.cn/api/v1/datasets/wenet/wespeaker_pretrained_models/oss/tree" # noqa + ) + model_info = next(data for data in response.json()["Data"] + if data["Key"] == model) + model_url = model_info['Url'] + download(model_url, model_path) + return model_path diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py new file mode 100644 index 00000000..a05aa5a8 --- /dev/null +++ b/wespeaker/cli/speaker.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +import numpy as np +import onnxruntime as ort +import scipy.io.wavfile as wav +from numpy.linalg import norm +from python_speech_features import fbank + +from wespeaker.cli.hub import Hub + + +class Speaker: + def __init__(self, model_path: str): + self.session = ort.InferenceSession(model_path) + + def extract_embedding(self, audio_path: str): + sample_rate, pcm = wav.read(audio_path) + # TODO(Binbin Zhang): verify the feat + feats, _ = fbank(pcm, + sample_rate, + nfilt=80, + lowfreq=20, + winfunc=np.hamming) + feats = np.log(feats) + feats = np.expand_dims(feats, axis=0).astype(np.float32) + outputs = self.session.run(None, {"feats": feats}) + embedding = outputs[0][0] + return embedding + + def compute_similarity(self, audio_path1: str, audio_path2) -> float: + e1 = self.extract_embedding(audio_path1) + e2 = self.extract_embedding(audio_path2) + s = np.dot(e1, e2) / (norm(e1) * norm(e2)) + return s + + # TODO(Chengdong Liang): Add implementation + def register(self, audio_path: str): + pass + + # TODO(Chengdong Liang): Add implementation + def recognize(self, audio_path: str): + pass + + +def load_model(language: str) -> Speaker: + model_path = Hub.get_model(language) + return Speaker(model_path) + + +def get_args(): + parser = argparse.ArgumentParser(description='') + parser.add_argument('-t', + '--task', + choices=[ + 'embedding', + 'similarity', + 'diarization', + ], + default='embedding', + help='task type') + parser.add_argument('-l', + '--language', + choices=[ + 'chinese', + 'english', + ], + default='chinese', + 
help='language type') + parser.add_argument('--audio_file', help='audio file') + parser.add_argument('--audio_file2', + help='audio file2, for similarity task') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + model = load_model(args.language) + if args.task == 'embedding': + print(model.extract_embedding(args.audio_file)) + elif args.task == 'similarity': + print(model.compute_similarity(args.audio_file, args.audio_file2)) + elif args.task == 'diarization': + # TODO(Chengdong Liang): Add diarization support + pass + else: + print('Unsupported task {}'.format(args.task)) + sys.exit(-1) + + +if __name__ == '__main__': + main()