From a3b49f6255c72542a831414176042549c9b2a611 Mon Sep 17 00:00:00 2001 From: r-sawata Date: Fri, 7 May 2021 12:58:19 +0000 Subject: [PATCH 1/3] [Fix] Update "egs/musdb18/X-UMX/requirements.txt" --- egs/musdb18/X-UMX/requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/musdb18/X-UMX/requirements.txt b/egs/musdb18/X-UMX/requirements.txt index d7288c514..fb12b8c8a 100755 --- a/egs/musdb18/X-UMX/requirements.txt +++ b/egs/musdb18/X-UMX/requirements.txt @@ -1,2 +1,4 @@ scikit-learn>=0.22 musdb>=0.4.0 +museval>=0.4.0 +norbert>=0.2.1 From 3bc3b5dc988a8aa1cbb389fd049da21cbfa7310b Mon Sep 17 00:00:00 2001 From: r-sawata Date: Fri, 25 Jun 2021 21:21:34 +0900 Subject: [PATCH 2/3] [Fix] Bug of X-UMX and README.md --- README.md | 1 + egs/musdb18/X-UMX/README.md | 2 ++ egs/musdb18/X-UMX/local/dataloader.py | 8 ++++--- egs/musdb18/X-UMX/train.py | 33 +++++++++++++++++++++------ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 40efc9168..16ade7660 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,7 @@ More information in [egs/README.md](./egs). * [x] [DPTNet](./asteroid/models/dptnet.py) ([Chen et al.](https://arxiv.org/abs/2007.13975)) * [x] [DCCRNet](./asteroid/models/dccrnet.py) ([Hu et al.](https://arxiv.org/abs/2008.00264)) * [x] [DCUNet](./asteroid/models/dcunet.py) ([Choi et al.](https://arxiv.org/abs/1903.03107)) +* [x] [CrossNet-Open-Unmix](./asteroid/models/x_umx.py) ([Sawata et al.](https://arxiv.org/abs/2010.04228)) * [ ] Open-Unmix (coming) ([Stöter et al.](https://sigsep.github.io/open-unmix/)) * [ ] Wavesplit (coming) ([Zeghidour et al.](https://arxiv.org/abs/2002.08933)) diff --git a/egs/musdb18/X-UMX/README.md b/egs/musdb18/X-UMX/README.md index 80da553ab..e91eb4236 100755 --- a/egs/musdb18/X-UMX/README.md +++ b/egs/musdb18/X-UMX/README.md @@ -2,6 +2,8 @@ This recipe contains __CrossNet-Open-Unmix (X-UMX)__, an improved version of [Open-Unmix (UMX)](https://github.com/sigsep/open-unmix-nnabla) for music source separation. X-UMX achieves an improved performance without additional learnable parameters compared to the original UMX model. Details of X-UMX can be found in [this paper](https://arxiv.org/abs/2010.04228). X-UMX is one of the two official baseline models for the [Music Demixing (MDX) Challenge 2021](https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021). +__Related Projects:__ [umx-pytorch](https://github.com/sigsep/open-unmix-pytorch) | [umx-nnabla](https://github.com/sigsep/open-unmix-nnabla) | x-umx-pytorch | [x-umx-nnabla](https://github.com/sony/ai-research-code/tree/master/x-umx) | [musdb](https://github.com/sigsep/sigsep-mus-db) | [museval](https://github.com/sigsep/sigsep-mus-eval) + ### Source separation with pretrained model Pretrained models on MUSDB18 for X-UMX, which reproduce the results from our paper, are available and can be easily tried out: ``` diff --git a/egs/musdb18/X-UMX/local/dataloader.py b/egs/musdb18/X-UMX/local/dataloader.py index e54961f5e..28a3bb7c5 100755 --- a/egs/musdb18/X-UMX/local/dataloader.py +++ b/egs/musdb18/X-UMX/local/dataloader.py @@ -69,9 +69,11 @@ def filtering_out_valid(input_dataset): Return: input_dataset (w/o validation tracks) """ - for i, tmp in enumerate(input_dataset.tracks): - if str(tmp["path"]).split("/")[-1] in validation_tracks: - del input_dataset.tracks[i] + input_dataset.tracks = [ + tmp + for tmp in input_dataset.tracks + if not (str(tmp["path"]).split("/")[-1] in validation_tracks) + ] return input_dataset diff --git a/egs/musdb18/X-UMX/train.py b/egs/musdb18/X-UMX/train.py index b04ee2872..fb1f4a9bc 100755 --- a/egs/musdb18/X-UMX/train.py +++ b/egs/musdb18/X-UMX/train.py @@ -234,15 +234,25 @@ class MultiDomainLoss(_Loss): https://arxiv.org/abs/2010.04228 (and ICASSP 2021) """ - def __init__(self, args): + def __init__( + self, + window_length, + in_chan, + n_hop, + spec_power, + nb_channels, + loss_combine_sources, + loss_use_multidomain, + mix_coef, + ): super().__init__() self.transform = nn.Sequential( - _STFT(window_length=args.window_length, n_fft=args.in_chan, n_hop=args.nhop), - _Spectrogram(spec_power=args.spec_power, mono=(args.nb_channels == 1)), + _STFT(window_length=window_length, n_fft=in_chan, n_hop=n_hop), + _Spectrogram(spec_power=spec_power, mono=(nb_channels == 1)), ) - self._combi = args.loss_combine_sources - self._multi = args.loss_use_multidomain - self.coef = args.mix_coef + self._combi = loss_combine_sources + self._multi = loss_use_multidomain + self.coef = mix_coef print("Combination Loss: {}".format(self._combi)) if self._multi: print( @@ -413,7 +423,16 @@ def main(conf, args): es = EarlyStopping(monitor="val_loss", mode="min", patience=args.patience, verbose=True) # Define Loss function. - loss_func = MultiDomainLoss(args) + loss_func = MultiDomainLoss( + window_length=args.window_length, + in_chan=args.in_chan, + n_hop=args.nhop, + spec_power=args.spec_power, + nb_channels=args.nb_channels, + loss_combine_sources=args.loss_combine_sources, + loss_use_multidomain=args.loss_use_multidomain, + mix_coef=args.mix_coef, + ) system = XUMXManager( model=x_unmix, loss_func=loss_func, From d66d47b36462342e6527e6235100b290a4eb3fc4 Mon Sep 17 00:00:00 2001 From: r-sawata Date: Thu, 4 Nov 2021 17:59:01 +0000 Subject: [PATCH 3/3] [egs] Fix the issue regarding X-UMX when receiving a mono audio as the input --- egs/musdb18/X-UMX/train.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/egs/musdb18/X-UMX/train.py b/egs/musdb18/X-UMX/train.py index fb1f4a9bc..0dec7bed3 100755 --- a/egs/musdb18/X-UMX/train.py +++ b/egs/musdb18/X-UMX/train.py @@ -15,7 +15,7 @@ from asteroid.engine.system import System from asteroid.engine.optimizers import make_optimizer from asteroid.models import XUMX -from asteroid.models.x_umx import _STFT, _Spectrogram +from asteroid.models.x_umx import _STFT, _Spectrogram, _ISTFT from asteroid.losses import singlesrc_mse from torch.nn.modules.loss import _Loss from torch import nn @@ -79,14 +79,14 @@ def freq_domain_loss(s_hat, gt_spec, combination=True): calculated frequency-domain loss """ - n_src = len(s_hat) + n_src, _, _, n_channel, _ = s_hat.shape idx_list = [i for i in range(n_src)] inferences = [] refrences = [] for i, s in enumerate(s_hat): inferences.append(s) - refrences.append(gt_spec[..., 2 * i : 2 * i + 2, :]) + refrences.append(gt_spec[..., n_channel * i : n_channel * (i + 1), :]) assert inferences[0].shape == refrences[0].shape _loss_mse = 0.0 @@ -143,7 +143,7 @@ def time_domain_loss(mix, time_hat, gt_time, combination=True): # Prepare Data and Fix Shape mix_ref = [mix] - mix_ref.extend([gt_time[..., 2 * i : 2 * i + 2, :] for i in range(n_src)]) + mix_ref.extend([gt_time[..., n_channel * i : n_channel * (i + 1), :] for i in range(n_src)]) mix_ref = torch.stack(mix_ref) mix_ref = mix_ref.view(-1, time_length) time_hat = time_hat.view(n_batch * n_channel * time_hat.shape[0], time_hat.shape[-1]) @@ -250,6 +250,8 @@ def __init__( _STFT(window_length=window_length, n_fft=in_chan, n_hop=n_hop), _Spectrogram(spec_power=spec_power, mono=(nb_channels == 1)), ) + self.istft = _ISTFT(window=self.transform[0].window, n_fft=in_chan, hop_length=n_hop) + self.nb_channels = nb_channels self._combi = loss_combine_sources self._multi = loss_use_multidomain self.coef = mix_coef @@ -277,12 +279,33 @@ def forward(self, est_targets, targets, return_est=False, **kwargs): # Fix shape and apply transformation of targets n_batch, n_src, n_channel, time_length = targets.shape - targets = targets.view(n_batch, n_src * n_channel, time_length) - Y = self.transform(targets)[0] + + # downmix in the frequency domain + if n_channel == 2 and self.nb_channels == 1: + Y = [] + signals = [] + for i in range(n_src): + spec, ang = self.transform(targets[:, i, ...]) + Y.append(spec.clone()) + spec = spec.permute(1, 2, 3, 0) + sig_downmix = self.istft(spec.unsqueeze(0), ang.unsqueeze(0)) + signals.append(sig_downmix.permute(1, 0, 2, 3)) + targets = torch.cat(signals, 1) + mixture_t = torch.sum(targets, 1) + targets = targets.squeeze(2) + Y = torch.cat(Y, dim=2) + else: + targets = targets.view(n_batch, n_src * self.nb_channels, time_length) + Y = self.transform(targets)[0] if self._multi: n_src = spec_hat.shape[0] - mixture_t = sum([targets[:, 2 * i : 2 * i + 2, ...] for i in range(n_src)]) + mixture_t = sum( + [ + targets[:, self.nb_channels * i : self.nb_channels * (i + 1), ...] + for i in range(n_src) + ] + ) loss_f = freq_domain_loss(spec_hat, Y, combination=self._combi) loss_t = time_domain_loss(mixture_t, time_hat, targets, combination=self._combi) loss = float(self.coef) * loss_t + loss_f