Skip to content

Commit

Permalink
Fix BatchRandTransform & PitchShift
Browse files (browse the repository at this point in the history)
  • Loading branch information
warner-benjamin committed May 6, 2024
1 parent e43c005 commit 5f7e5de
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
12 changes: 9 additions & 3 deletions fastxtend/audio/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from fractions import Fraction
from functools import reduce
from itertools import chain, count, islice, repeat
from packaging.version import parse

import colorednoise
from primePy import primes

import torchaudio
from torch import _VF
from torch.distributions import Bernoulli
import torchaudio.transforms as tatfms
Expand Down Expand Up @@ -408,11 +410,13 @@ def _get_fast_stretches(
def pitch_shift(x:TensorAudio, n_fft, hop_length, shift, sr, new_sr, gcd, kernel, width, padmode, constant, device):
shape = x.shape
x = x.reshape(shape[0] * shape[1], shape[2])
x = torch.stft(x, n_fft, hop_length, return_complex=True)
window = torch.hann_window(n_fft, device=device)
x = torch.stft(x, n_fft, hop_length, window=window, return_complex=True)
phase_advance = torch.linspace(0, math.pi * hop_length, x.shape[1], device=device)[..., None]
x = TAF.phase_vocoder(x, float(1 / shift), phase_advance)
phase_advance = None
x = torch.istft(x, n_fft, hop_length)
x = torch.istft(x, n_fft, hop_length, window=window)
window = None
x = retain_type(_apply_sinc_resample_kernel(x, sr, new_sr, gcd, kernel, width), typ=TensorAudio)
crop_start = torch.randint(0, x.shape[-1]-shape[-1], (1,)) if shape[-1] < x.shape[-1] else None
pad_len = (shape[-1]-x.shape[-1]) if shape[-1] > x.shape[-1] else 0
Expand All @@ -430,6 +434,8 @@ def __init__(self,
constant:Numeric=0, # Value for `AudioPadMode.Constant`
split:int|None=None # Apply transform to `split` items at a time. Use to prevent GPU OOM.
):
if parse(torch.__version__) < parse('2.1.0'):
raise ImportError(f"`PitchShift` requires a minimum of PyTorch 2.1. Current version: {torch.__version__}")
super().__init__(p=p)
store_attr(but='p')
self.sr = 0
Expand All @@ -455,7 +461,7 @@ def before_call(self,
self.new_sr = int(self.sr/self.shift)
self.gcd = math.gcd(self.sr, self.new_sr)
self.kernel, self.width = _get_sinc_resample_kernel(self.sr, self.new_sr, self.gcd,
6, 0.99, 'sinc_interpolation',
6, 0.99, 'sinc_interp_hann',
None, self.device, self.type)

def encodes(self, x:TensorAudio) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion fastxtend/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def before_call(self,

def __call__(self,
b:Tensor|tuple[Tensor,...], # Batch item(s)
split_idx:int, # Train (0) or valid (1) index
split_idx:int|None=None, # Train (0) or valid (1) index
**kwargs
) -> Tensor|tuple[Tensor,...]:
"Call `super().__call__` if `self.do`"
Expand Down
12 changes: 9 additions & 3 deletions nbs/audio.03_augment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@
"from fractions import Fraction\n",
"from functools import reduce\n",
"from itertools import chain, count, islice, repeat\n",
"from packaging.version import parse\n",
"\n",
"import colorednoise\n",
"from primePy import primes\n",
"\n",
"import torchaudio\n",
"from torch import _VF\n",
"from torch.distributions import Bernoulli\n",
"import torchaudio.transforms as tatfms\n",
Expand Down Expand Up @@ -913,11 +915,13 @@
"def pitch_shift(x:TensorAudio, n_fft, hop_length, shift, sr, new_sr, gcd, kernel, width, padmode, constant, device):\n",
" shape = x.shape\n",
" x = x.reshape(shape[0] * shape[1], shape[2])\n",
" x = torch.stft(x, n_fft, hop_length, return_complex=True)\n",
" window = torch.hann_window(n_fft, device=device)\n",
" x = torch.stft(x, n_fft, hop_length, window=window, return_complex=True)\n",
" phase_advance = torch.linspace(0, math.pi * hop_length, x.shape[1], device=device)[..., None]\n",
" x = TAF.phase_vocoder(x, float(1 / shift), phase_advance)\n",
" phase_advance = None\n",
" x = torch.istft(x, n_fft, hop_length)\n",
" x = torch.istft(x, n_fft, hop_length, window=window)\n",
" window = None\n",
" x = retain_type(_apply_sinc_resample_kernel(x, sr, new_sr, gcd, kernel, width), typ=TensorAudio)\n",
" crop_start = torch.randint(0, x.shape[-1]-shape[-1], (1,)) if shape[-1] < x.shape[-1] else None\n",
" pad_len = (shape[-1]-x.shape[-1]) if shape[-1] > x.shape[-1] else 0\n",
Expand All @@ -942,6 +946,8 @@
" constant:Numeric=0, # Value for `AudioPadMode.Constant`\n",
" split:int|None=None # Apply transform to `split` items at a time. Use to prevent GPU OOM.\n",
" ):\n",
" if parse(torch.__version__) < parse('2.1.0'):\n",
" raise ImportError(f\"`PitchShift` requires a minimum of PyTorch 2.1. Current version: {torch.__version__}\")\n",
" super().__init__(p=p)\n",
" store_attr(but='p')\n",
" self.sr = 0\n",
Expand All @@ -967,7 +973,7 @@
" self.new_sr = int(self.sr/self.shift)\n",
" self.gcd = math.gcd(self.sr, self.new_sr)\n",
" self.kernel, self.width = _get_sinc_resample_kernel(self.sr, self.new_sr, self.gcd,\n",
" 6, 0.99, 'sinc_interpolation',\n",
" 6, 0.99, 'sinc_interp_hann',\n",
" None, self.device, self.type)\n",
"\n",
" def encodes(self, x:TensorAudio) -> Tensor:\n",
Expand Down
2 changes: 1 addition & 1 deletion nbs/transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
"\n",
" def __call__(self,\n",
" b:Tensor|tuple[Tensor,...], # Batch item(s)\n",
" split_idx:int, # Train (0) or valid (1) index\n",
" split_idx:int|None=None, # Train (0) or valid (1) index\n",
" **kwargs\n",
" ) -> Tensor|tuple[Tensor,...]:\n",
" \"Call `super().__call__` if `self.do`\"\n",
Expand Down

0 comments on commit 5f7e5de

Please sign in to comment.