diff --git a/docs/python_package.md b/docs/python_package.md index 9c72ba0f..98cccaa2 100644 --- a/docs/python_package.md +++ b/docs/python_package.md @@ -21,6 +21,8 @@ $ wespeaker --task embedding --audio_file audio.wav --output_file embedding.txt $ wespeaker --task embedding_kaldi --wav_scp wav.scp --output_file /path/to/embedding $ wespeaker --task similarity --audio_file audio.wav --audio_file2 audio2.wav $ wespeaker --task diarization --audio_file audio.wav +$ wespeaker --task diarization --audio_file audio.wav --device cuda:0 # use CUDA on Windows/Linux +$ wespeaker --task diarization --audio_file audio.wav --device mps # use Metal Performance Shaders on MacOS ``` You can specify the following parameters. (use `-h` for details) @@ -33,7 +35,7 @@ You can specify the following parameters. (use `-h` for details) - diarization_list: apply speaker diarization for a kaldi-style wav.scp * `-l` or `--language`: use Chinese/English speaker models * `-p` or `--pretrain`: the path of pretrained model, `avg_model.pt` and `config.yaml` should be contained -* `-g` or `--gpu`: use GPU for inference, number $< 0$ means using CPU +* `--device`: set pytorch device, `cpu`, `cuda`, `cuda:0` or `mps` * `--campplus`: use [`campplus_cn_common_200k` of damo](https://www.modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) * `--eres2net`: @@ -69,14 +71,14 @@ which can either be the ones we provided and trained by yourself. import wespeaker model = wespeaker.load_model('chinese') -# set_gpu to enable the cuda inference, number < 0 means using CPU -model.set_gpu(0) +# set the device on which tensors are or will be allocated. +model.set_device('cuda:0') # embedding/embedding_kaldi/similarity/diarization embedding = model.extract_embedding('audio.wav') utt_names, embeddings = model.extract_embedding_list('wav.scp') similarity = model.compute_similarity('audio1.wav', 'audio2.wav') -diar_result = model.diarize('audio.wav') +diar_result = model.diarize('audio.wav', 'give_this_utt_a_name') # register and recognize model.register('spk1', 'spk1_audio1.wav') diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py index be064ac1..67030556 100644 --- a/wespeaker/cli/speaker.py +++ b/wespeaker/cli/speaker.py @@ -72,11 +72,7 @@ def set_resample_rate(self, resample_rate: int): def set_vad(self, apply_vad: bool): self.apply_vad = apply_vad - def set_gpu(self, device_id: int): - if device_id >= 0: - device = 'cuda:{}'.format(device_id) - else: - device = 'cpu' + def set_device(self, device: str): self.device = torch.device(device) self.model = self.model.to(self.device) @@ -304,7 +300,7 @@ def main(): model = load_model_local(args.pretrain) model.set_resample_rate(args.resample_rate) model.set_vad(args.vad) - model.set_gpu(args.gpu) + model.set_device(args.device) model.set_diarization_params(min_duration=args.diar_min_duration, window_secs=args.diar_window_secs, period_secs=args.diar_period_secs, diff --git a/wespeaker/cli/utils.py b/wespeaker/cli/utils.py index 3d5125f2..34eb26b0 100644 --- a/wespeaker/cli/utils.py +++ b/wespeaker/cli/utils.py @@ -52,11 +52,12 @@ def get_args(): type=str, default="", help='model directory') - parser.add_argument('-g', - '--gpu', - type=int, - default=-1, - help='which gpu to use (number <0 means using cpu)') + parser.add_argument('--device', + type=str, + default='cpu', + help="device type (most commonly cpu or cuda," + "but also potentially mps, xpu, xla or meta)" + "and optional device ordinal for the device type.") parser.add_argument('--audio_file', help='audio file') parser.add_argument('--audio_file2', help='audio file2, specifically for similarity task') diff --git a/wespeaker/diar/extract_emb.py b/wespeaker/diar/extract_emb.py index da16fc85..6e0f4cfa 100644 --- a/wespeaker/diar/extract_emb.py +++ b/wespeaker/diar/extract_emb.py @@ -37,7 +37,7 @@ def init_session(source, device): opts = ort.SessionOptions() opts.inter_op_num_threads = 1 opts.intra_op_num_threads = 1 - opts.log_severity_level = 0 + opts.log_severity_level = 1 session = ort.InferenceSession(source, sess_options=opts, providers=providers) diff --git a/wespeaker/frontend/s3prl.py b/wespeaker/frontend/s3prl.py index d96168d5..37bf8c8e 100644 --- a/wespeaker/frontend/s3prl.py +++ b/wespeaker/frontend/s3prl.py @@ -57,7 +57,7 @@ def __init__(self, if layer != -1: layer_selections = [layer] - assert not multilayer_feature,\ + assert not multilayer_feature, \ "multilayer_feature must be False if layer is specified" else: layer_selections = None