sonify.py
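"""Sonify saliency-masked Whisper spectrograms.

Streams samples from the LibriSpeech validation.clean split, keeps the top,
bottom, or a random fraction of spectrogram features at several retention
ratios via EvalWhisper, resynthesizes each masked spectrogram to audio, and
writes the wavs plus reference transcriptions to the output directory.
"""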
import argparse
from pathlib import Path

import soundfile as sf
from datasets import load_dataset
from tqdm import tqdm

from EvalWhisper import EvalWhisper
def get_audio(output_path, whisper_evaluator, inputs, mask_ratios, mode_value, what_to_mask_list, sampling_rate=16000, num_skipped=0):
    """For each sample, write its reference transcription plus one sonified wav per (mask_ratio, where) pair."""
    for idx, sample in tqdm(enumerate(inputs)):
        # Reference transcription for this sample: <index>_original_transcription.txt
        with open(f"{output_path}/{idx + num_skipped}_original_transcription.txt", "w") as op_file:
            op_file.write(sample["text"] + "\n")
        # Uncomment to also dump the unmasked reference audio:
        # original_spectrogram = whisper_evaluator.get_spectrogram(sample)
        # sf.write(f"{output_path}/{idx + num_skipped}_original_all.wav", sample["audio"]["array"], sampling_rate)
        # sf.write(f"{output_path}/{idx + num_skipped}_10_all.wav", whisper_evaluator.sonify(original_spectrogram), sampling_rate)
        for mask_ratio in tqdm(mask_ratios):
            for mask in what_to_mask_list:
                # Keep the top/bottom/random r-fraction of features, then resynthesize audio.
                masked_spectrogram = whisper_evaluator.top_r_features(sample, r=mask_ratio, mode=mode_value, where=mask)
                masked_audio = whisper_evaluator.sonify(masked_spectrogram.detach().numpy())
                # File name: <index>_<10 * mask_ratio>_<t|b|r>.wav
                sf.write(f"{output_path}/{idx + num_skipped}_{int(mask_ratio * 10)}_{mask[0]}.wav", masked_audio, sampling_rate)
def main(args):
    num_skipped = args.num_skipped
    num_samples = args.num_samples
    output_path = Path(args.output_dir)
    model_size = "large"
    # Fine-tuned checkpoint; fall back to f"openai/whisper-{model_size}" to use the stock model.
    model_checkpoint = "/scratch/general/vast/u0403624/cs6966/salASR/models/current/checkpoint-1440"
    processor_checkpoint = f"openai/whisper-{model_size}"
    # Load processor and model.
    print(f"Loading model . . . ({model_checkpoint})")
    whisper_evaluator = EvalWhisper(model_checkpoint, processor_checkpoint)
    print("Loaded model")
    print("Loading data . . . .")
    ds = load_dataset("librispeech_asr", split="validation.clean", streaming=True)
    print("Loaded data")
    output_path.mkdir(exist_ok=True)
    mask_ratios = [0.8, 0.5, 0.2]  # fraction of spectrogram features to retain
    mode_value = "retain"
    what_to_mask_list = ["top", "bottom", "random"]  # which features to keep
    # Skip already-processed samples so reruns can resume where they left off.
    inputs = list(tqdm(ds.skip(num_skipped).take(num_samples)))
    get_audio(output_path, whisper_evaluator, inputs, mask_ratios, mode_value, what_to_mask_list, num_skipped=num_skipped)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--num_skipped", type=int, default=0)
    parser.add_argument("-n", "--num_samples", type=int, default=30)
    parser.add_argument("-o", "--output_dir", type=str, default="./")
    args = parser.parse_args()
    main(args)
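# Example invocation (the output directory name is illustrative):
#   python sonify.py --num_skipped 0 --num_samples 30 --output_dir ./sonified
# For each sample i this writes i_original_transcription.txt plus one wav per
# ratio/mask pair, e.g. i_8_t.wav, i_8_b.wav, i_8_r.wav, i_5_t.wav, ...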