TypeError When Using TensorAudio Batch Transforms #23

Closed
vishalbakshi opened this issue May 5, 2024 · 3 comments · Fixed by #24

vishalbakshi commented May 5, 2024

First off, thanks for such an awesome library, I'm using it for the first time and really enjoying it.

I am running into a TypeError when calling auds.summary (where auds is the DataBlock) with one of the TensorAudio batch transforms (VolumeBatch, PitchShift, TimeStretch, PitchShiftOrTimeStretch).

I am using fastxtend version 0.1.7 and fastai version 2.7.14 on Kaggle.

I have created a reproducible example in this Kaggle Notebook. I chose to use Kaggle so that I can easily incorporate an audio classification dataset in the example.

In short, when I run the following code in that notebook:

auds = DataBlock(blocks = (AudioBlock, CategoryBlock),  
                 get_x = ColReader("filename", pref=path/"audio"/"audio"), 
                 splitter = ColSplitter(),
                 get_y = ColReader("category"),
                 batch_tfms = PitchShift())

auds.summary(df)

I get the error:

TypeError: BatchRandTransform.__call__() missing 1 required positional argument: 'split_idx'

I don't get this error when using the different item transforms (Flip, Roll, Volume, etc.). I also tried using fastai's built-in aug_transforms and RandomResizedCropGPU as batch_tfms; while they obviously don't do anything to the TensorAudio, auds.summary runs without error.

I looked at the source code for PitchShift and split_idx is defined at the top.

I also saw that there is an audio_fixes branch open, so perhaps this is already a known issue. I'm not proficient with the internals of fastai, but I'm happy to help troubleshoot this issue if you can point me in a general direction of where to start looking. Or perhaps I'm using it incorrectly and that's why I'm getting this error.

@warner-benjamin
Owner

I think the issue is probably due to BatchRandTransform.__call__ not defaulting split_idx to None.
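To illustrate that diagnosis in plain Python, here is a minimal sketch that does not use fastai or fastxtend. The class names and the `apply_without_split_idx` helper are hypothetical stand-ins, and the sketch assumes (this is not confirmed in the thread) that some caller in the pipeline invokes the transform with only the batch.

```python
class BaseTransform:
    """Hypothetical stand-in for a base transform that tolerates a missing split_idx."""
    def __call__(self, b, split_idx=None):
        return b


class StrictRandTransform(BaseTransform):
    """Overrides __call__ and makes split_idx required (no default)."""
    def __call__(self, b, split_idx):
        return super().__call__(b, split_idx=split_idx)


class LenientRandTransform(BaseTransform):
    """Same override, but split_idx defaults to None."""
    def __call__(self, b, split_idx=None):
        return super().__call__(b, split_idx=split_idx)


def apply_without_split_idx(tfm, b):
    """Hypothetical caller that passes only the batch (an assumption, see above)."""
    return tfm(b)


batch = [1, 2, 3]

# The strict override raises the same kind of TypeError reported in this issue.
strict_error = None
try:
    apply_without_split_idx(StrictRandTransform(), batch)
except TypeError as e:
    strict_error = str(e)

# The lenient override accepts the call and returns the batch unchanged.
lenient_result = apply_without_split_idx(LenientRandTransform(), batch)
```

If the assumption holds, giving `split_idx` a default of `None` is enough for such a caller to succeed, while callers that do pass `split_idx` are unaffected.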

Can you add this code right after the imports and see if it resolves the issue?

from fastxtend.transform import BatchRandTransform

def call_fix(self,
    b:Tensor|tuple[Tensor,...], # Batch item(s)
    split_idx:int|None = None, # Train (0) or valid (1) index
    **kwargs
) -> Tensor|tuple[Tensor,...]:
    "Call `super().__call__` if `self.do`"
    self.before_call(b, split_idx=split_idx)
    return super().__call__(b, split_idx=split_idx, **kwargs) if self.do else b
    
BatchRandTransform.__call__ = call_fix

The audio_fixes branch was merged in #18. I forgot to delete it.

@vishalbakshi
Author

vishalbakshi commented May 6, 2024

Thanks, that works!

This code:

auds = DataBlock(blocks = (AudioBlock, CategoryBlock),  
                 get_x = ColReader("filename", pref=path/"audio"/"audio"), 
                 splitter = ColSplitter(),
                 get_y = ColReader("category"),
                 batch_tfms = PitchShift())

auds.summary(df)

Now outputs the expected result:

[image: auds.summary(df) output]

Note that I was running into the following error when running auds.summary(df) with your fix:

[images: error output]

So I modified the return statement as follows (ChatGPT assist) and it worked:

def call_fix(self,
    b:Tensor|tuple[Tensor,...], # Batch item(s)
    split_idx:int|None = None, # Train (0) or valid (1) index
    **kwargs
) -> Tensor|tuple[Tensor,...]:
    "Call `super().__call__` if `self.do`"
    self.before_call(b, split_idx=split_idx)
    return super(BatchRandTransform, self).__call__(b, split_idx=split_idx, **kwargs) if self.do else b
    
BatchRandTransform.__call__ = call_fix
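The error screenshots are not preserved here, so the exact message is unknown. One plausible explanation for why the two-argument form helps is that zero-argument `super()` only works in functions defined inside a class body, because it relies on an implicit `__class__` cell. A function defined at module level and attached to a class afterwards has no such cell. The sketch below shows this with hypothetical stand-in classes (`Base`, `Child`) rather than fastxtend itself.

```python
class Base:
    def __call__(self, b, split_idx=None):
        return ("base", b, split_idx)


class Child(Base):
    def __call__(self, b, split_idx):
        return super().__call__(b, split_idx=split_idx)


def patched_zero_arg(self, b, split_idx=None):
    # Defined outside any class body, so there is no __class__ cell for super() to use.
    return super().__call__(b, split_idx=split_idx)


def patched_two_arg(self, b, split_idx=None):
    # Naming the class and instance explicitly avoids the need for the __class__ cell.
    return super(Child, self).__call__(b, split_idx=split_idx)


# Monkey-patch with the zero-argument version: the call fails at runtime.
Child.__call__ = patched_zero_arg
zero_arg_failed = False
try:
    Child()("batch")
except RuntimeError:
    zero_arg_failed = True

# Monkey-patch with the two-argument version: the call reaches Base.__call__.
Child.__call__ = patched_two_arg
two_arg_result = Child()("batch")
```

Defining the method inside the class body, as a library fix normally would, lets zero-argument `super()` work again; the explicit form is only needed for a patch applied from outside the class.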

@warner-benjamin
Owner

@vishalbakshi can you check that #24 resolves this issue? If so, I will merge it and cut a release.
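One quick way to check a fix like this, besides re-running the notebook, is to inspect the `__call__` signature and confirm that `split_idx` has a default. The following is a sketch using only the standard library; `split_idx_has_default` and the two example classes are hypothetical names introduced for illustration.

```python
import inspect


def split_idx_has_default(cls) -> bool:
    """Return True if cls.__call__ declares split_idx with a default value."""
    params = inspect.signature(cls.__call__).parameters
    param = params.get("split_idx")
    return param is not None and param.default is not inspect.Parameter.empty


class Unpatched:
    # Hypothetical stand-in: split_idx is required.
    def __call__(self, b, split_idx):
        return b


class Patched:
    # Hypothetical stand-in: split_idx defaults to None.
    def __call__(self, b, split_idx=None):
        return b


unpatched_ok = split_idx_has_default(Unpatched)
patched_ok = split_idx_has_default(Patched)
```

In the notebook, the same helper could be pointed at `BatchRandTransform` (imported from `fastxtend.transform`, as in the workaround above) after installing the version under test.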
