Skip to content

Commit

Permalink
Fix AUTO device (#2608)
Browse files Browse the repository at this point in the history
CVS-158945
  • Loading branch information
aleksandr-mokrov authored Dec 19, 2024
1 parent 5e808a1 commit 2177b6e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion notebooks/catvton/gradio_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def image_grid(imgs, rows, cols):

def make_demo(pipeline, mask_processor, automasker, output_dir):
def submit_function(person_image, cloth_image, cloth_type, num_inference_steps, guidance_scale, seed, show_type):
width = 1024
width = 768
height = 1024
person_image, mask = person_image["background"], person_image["layers"][0]
mask = Image.open(mask).convert("L")
Expand Down
4 changes: 2 additions & 2 deletions notebooks/catvton/ov_catvton_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def convert_pipeline_models(pipeline):
convert(VaeDecoder(pipeline.vae), VAE_DECODER_PATH, torch.zeros(1, 4, 128, 96))
del pipeline.vae

inpainting_latent_model_input = torch.zeros(2, 9, 256, 96)
inpainting_latent_model_input = torch.rand(2, 9, 256, 96)
timestep = torch.tensor(0)
encoder_hidden_states = torch.zeros(2, 1, 768)
encoder_hidden_states = torch.Tensor(0)
example_input = (inpainting_latent_model_input, timestep, encoder_hidden_states)

convert(UNetWrapper(pipeline.unet), UNET_PATH, example_input)
Expand Down

0 comments on commit 2177b6e

Please sign in to comment.