Skip to content

Commit

Permalink
[BUG FIX] [Stable Audio Pipeline] Resolve torch.Tensor.new_zeros() Ty…
Browse files Browse the repository at this point in the history
…peError in function prepare_latents caused by audio_vae_length (#10306)

[BUG FIX] [Stable Audio Pipeline] TypeError: new_zeros(): argument 'size' failed to unpack the object at pos 3 with error "type must be tuple of ints,but got float"

torch.Tensor.new_zeros() takes a single argument size (int...) – a list, tuple, or torch.Size of integers defining the shape of the output tensor.

in function prepare_latents:
audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
...
audio = initial_audio_waveforms.new_zeros(audio_shape)

audio_vae_length evaluates to float because self.transformer.config.sample_size returns a float

Co-authored-by: hlky <[email protected]>
  • Loading branch information
syntaxticsugr and hlky authored Dec 20, 2024
1 parent c8ee4af commit 9020086
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def prepare_latents(
f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
)

audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)

# check num_channels
Expand Down

0 comments on commit 9020086

Please sign in to comment.