Skip to content

Commit

Permalink
feat(ml):verify multi batch prompt match x
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Jun 24, 2024
1 parent 0b7d0ac commit db72bb4
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions models/modules/img2img_turbo/img2img_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def __init__(self, in_channels, out_channels, lora_rank_unet, lora_rank_vae):
unet.enable_gradient_checkpointing()

def forward(self, x, prompt):

caption_tokens = self.tokenizer(
prompt,
max_length=self.tokenizer.model_max_length,
Expand All @@ -205,7 +204,7 @@ def forward(self, x, prompt):
batch_size = caption_enc.shape[0]
repeated_encs = [
caption_enc[i].repeat(int(x.shape[0] / batch_size), 1, 1)
for i in range(caption_enc.shape[0])
for i in range(batch_size)
]

# Concatenate the repeated encodings along the batch dimension
Expand Down

0 comments on commit db72bb4

Please sign in to comment.