-
Notifications
You must be signed in to change notification settings - Fork 2
/
augment_binaural_speech.py
166 lines (144 loc) · 6.35 KB
/
augment_binaural_speech.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import numpy as np
import argparse
from multiprocessing import Pool
import random
import soundfile as sf
import scipy.signal as ssi
from tqdm import tqdm
import utility
import pickle
def _fit_length(signal, length):
    """Return *signal* truncated or zero-padded along axis 0 to *length* samples.

    The zero block copies the trailing dimensions of *signal*, so both 1-D
    arrays and 2-D (samples, channels) arrays are handled.
    """
    n_samples = signal.shape[0]
    if n_samples > length:
        return signal[:length]
    padding = np.zeros((length - n_samples,) + signal.shape[1:])
    return np.concatenate([signal, padding])


def augment_data(speech_path, output_path, irfile_path, target_length=96000):
    """Convolve one speech file with a two-channel impulse response and save it.

    The speech read from ``speech_path`` is cut or zero-padded to
    ``target_length`` samples.  When ``irfile_path`` is given, the impulse
    response is cut or zero-padded to ``fs_s`` samples (one second at the
    speech sample rate), the speech is convolved with IR channels 0 and 1
    separately, and the two results become the columns of the output.
    The result is scaled down to a 0.99 peak if its peak is >= 1 and written
    to ``output_path`` at the speech sample rate.

    Args:
        speech_path: Path of the input speech file.
        output_path: Path of the audio file to write.
        irfile_path: Path of the impulse-response file, or a falsy value to
            skip the convolution.
        target_length: Number of speech samples to keep (default 96000).

    Returns:
        -1 if the signal is all zeros (nothing is written), otherwise None.

    Raises:
        ValueError: If the impulse response has fewer than two channels.
    """
    speech, fs_s = sf.read(speech_path)
    speech = _fit_length(speech, target_length)
    if np.issubdtype(speech.dtype, np.integer):
        speech = utility.pcm2float(speech, 'float32')

    if irfile_path:
        # NOTE(review): the IR sample rate is ignored and no resampling is
        # done -- presumably IR and speech share a sample rate; confirm.
        IR, _ = sf.read(irfile_path)
        if IR.ndim != 2 or IR.shape[1] < 2:
            raise ValueError(
                "IR file {} must have at least two channels".format(irfile_path))
        IR = _fit_length(IR, fs_s)
        if np.issubdtype(IR.dtype, np.integer):
            IR = utility.pcm2float(IR, 'float32')

        # The full IR is used: the direct-path delay is not trimmed.
        # assumes smart_convolve returns a 1-D array per channel -- TODO confirm
        left = utility.smart_convolve(speech, IR[:, 0])
        right = utility.smart_convolve(speech, IR[:, 1])
        speech = np.column_stack((left, right))

    # NOTE: additive-noise mixing is currently disabled in this script.

    maxval = np.max(np.fabs(speech))
    if maxval == 0:
        print("file {} not saved due to zero strength".format(speech_path))
        return -1
    if maxval >= 1:
        # Leave a little headroom so the written samples do not clip.
        speech = speech * (0.99 / maxval)

    sf.write(output_path, speech, fs_s)
def main():
    """Augment every ``.wav`` file under the speech folder.

    Each speech file is paired with a randomly chosen impulse response (when
    ``--ir`` is given) and handed to ``augment_data`` in a worker pool.  The
    output tree mirrors the speech tree under ``--out``.  A pickle named
    ``dictionary.pickle`` mapping each output file name to the IR file name
    used for it (``None`` without ``--ir``) is written to the output folder.
    """
    parser = argparse.ArgumentParser(prog='augment',
                                     description="""Script to augment dataset""")
    parser.add_argument("--ir", "-i", default=None, help="Directory of IR files", type=str)
    parser.add_argument("--speech", "-sp", required=True, help="Directory of speech files", type=str)
    parser.add_argument("--out", "-o", required=True, help="Output folder path", type=str)
    parser.add_argument("--seed", "-s", default=0, help="Random seed", type=int)
    parser.add_argument("--nthreads", "-n", type=int, default=16, help="Number of threads to use")
    args = parser.parse_args()

    # Force input and output folders to end with a path separator.
    speech_folder = os.path.join(args.speech, '')
    output_folder = os.path.join(args.out, '')
    ir_folder = args.ir
    add_reverb = bool(ir_folder)

    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if not os.path.exists(speech_folder):
        parser.error("speech folder does not exist: {}".format(speech_folder))
    if add_reverb and not os.path.exists(ir_folder):
        parser.error("IR folder does not exist: {}".format(ir_folder))
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Sorted so that, together with the seed, the speech/IR pairing is
    # reproducible regardless of the order os.walk yields files in.
    speechlist = sorted(os.path.join(root, name)
                        for root, _dirs, files in os.walk(speech_folder)
                        for name in files if name.endswith(".wav"))
    irlist = sorted(os.path.join(root, name)
                    for root, _dirs, files in os.walk(ir_folder)
                    for name in files if name.endswith(".wav")) if add_reverb else []
    if add_reverb and not irlist:
        parser.error("no .wav files found in IR folder: {}".format(ir_folder))

    # Apply the --seed option so the random IR choice is repeatable.
    random.seed(args.seed)

    embedding_list = {}
    pending = []
    pbar = tqdm(total=len(speechlist))

    def update(*_):
        # apply_async success callback: advance the progress bar.
        pbar.update()

    pool = Pool(processes=args.nthreads)
    try:
        for speech_path in speechlist:
            ir_sample = random.choice(irlist) if add_reverb else None
            output_path = os.path.join(
                output_folder, os.path.relpath(speech_path, speech_folder))
            # Without --ir there is no IR to record; store None rather than
            # failing on a None path.
            embedding_list[os.path.basename(output_path)] = (
                os.path.basename(ir_sample) if ir_sample else None)
            out_dir = os.path.dirname(output_path)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            result = pool.apply_async(augment_data,
                                      args=(speech_path, output_path, ir_sample),
                                      callback=update)
            pending.append((speech_path, result))
    finally:
        # Always let already-submitted work finish and release the workers.
        pool.close()
        pool.join()
        pbar.close()

    # Worker exceptions are only raised when the result is fetched; report
    # them here so failed files do not go unnoticed.
    for speech_path, result in pending:
        try:
            result.get()
        except Exception as err:
            print("file {} failed: {}".format(speech_path, err))

    embeddings_pickle = os.path.join(output_folder, "dictionary.pickle")
    with open(embeddings_pickle, 'wb') as f:
        pickle.dump(embedding_list, f, protocol=2)


if __name__ == "__main__":
    main()