diff --git a/fabric/configs/preference_model_feedback.yaml b/fabric/configs/preference_model_feedback.yaml index 98c7fa6..59d2673 100644 --- a/fabric/configs/preference_model_feedback.yaml +++ b/fabric/configs/preference_model_feedback.yaml @@ -5,6 +5,7 @@ hydra: job: chdir: true +size: 512 model_version: "1.5" model_name: dreamlike-art/dreamlike-photoreal-2.0 negative_prompt: lowres, bad anatomy, bad hands, cropped, worst quality @@ -23,6 +24,7 @@ feedback: min_weight: 0.0 max_weight: 0.3 neg_scale: 0.5 + warmup_power: 1 global_seed: 0 image_seed: null liked_images: [] @@ -30,4 +32,5 @@ disliked_images: [] output_path: images # For loading the Human Preference Score LoRA weights, set null to not use them -lora_weights: null # resources/hps_lora/adapted_model.bin \ No newline at end of file +lora_weights: null # resources/hps_lora/adapted_model.bin +topk: 0 \ No newline at end of file diff --git a/fabric/configs/single_round.yaml b/fabric/configs/single_round.yaml index 1773f5e..2f03f80 100644 --- a/fabric/configs/single_round.yaml +++ b/fabric/configs/single_round.yaml @@ -1,17 +1,20 @@ model_version: "1.5" model_ckpt: null model_name: dreamlike-art/dreamlike-photoreal-2.0 -prompt: photo of a dog running on grassland, masterpiece, best quality, fine details +prompt: A robot holding a sign that says "The future is now", masterpiece, best quality negative_prompt: lowres, bad anatomy, bad hands, cropped, worst quality liked: [] disliked: [] -n_images: 4 -denoising_steps: 20 -guidance_scale: 7 +n_images: 1 +denoising_steps: 100 +guidance_scale: 6 seed: 37 feedback: - start: 0.33 - end: 0.66 + start: 0.0 + end: 0.8 min_weight: 0.05 - max_weight: 0.8 + max_weight: 0.9 neg_scale: 0.5 + warmup_power: 1 +topk: 4 +size: 512 \ No newline at end of file diff --git a/fabric/configs/target_image_feedback.yaml b/fabric/configs/target_image_feedback.yaml index f064b24..c30e129 100644 --- a/fabric/configs/target_image_feedback.yaml +++ b/fabric/configs/target_image_feedback.yaml @@ -5,13 +5,14 @@ hydra: job: chdir: true +size: 512 model_version: "1.5" model_name: dreamlike-art/dreamlike-photoreal-2.0 # runwayml/stable-diffusion-v1-5 negative_prompt: lowres, bad anatomy, bad hands, cropped, worst quality n_images: 4 n_rounds: 3 num_prompts: 1000 -denoising_steps: 20 +denoising_steps: 10 guidance_scale: 6 prompt_dropout: 0.0 use_pos_feedback: yes @@ -22,6 +23,7 @@ feedback: min_weight: 0.0 max_weight: 0.3 neg_scale: 0.5 + warmup_power: 1 global_seed: 0 image_seed: null liked_images: [] @@ -30,4 +32,5 @@ prompthero_path: resources/prompthero output_path: images # For loading the Human Preference Score LoRA weights, set null to not use them -lora_weights: null # resources/hps_lora/adapted_model.bin \ No newline at end of file +lora_weights: null # resources/hps_lora/adapted_model.bin +topk: 0 \ No newline at end of file diff --git a/fabric/generator.py b/fabric/generator.py index e464c66..336642f 100644 --- a/fabric/generator.py +++ b/fabric/generator.py @@ -1,42 +1,28 @@ import warnings from typing import List, Optional, Union +import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -import numpy as np +import xformers from PIL import Image -from tqdm import tqdm -from diffusers import ( - StableDiffusionPipeline, - EulerAncestralDiscreteScheduler, -) +from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler as Scheduler from diffusers.models.attention import BasicTransformerBlock from diffusers.models.cross_attention import LoRACrossAttnProcessor +from 
tqdm import tqdm -try: - import xformers - has_xformers = True -except ImportError: - print("WARNING: xformers is not installed. Please install it using `pip install xformers`") - has_xformers = False +has_xformers = True def apply_unet_lora_weights(pipeline, unet_path): model_weight = torch.load(unet_path, map_location="cpu") unet = pipeline.unet lora_attn_procs = {} - lora_rank = list( - set([v.size(0) for k, v in model_weight.items() if k.endswith("down.weight")]) - ) + lora_rank = list(set([v.size(0) for k, v in model_weight.items() if k.endswith("down.weight")])) assert len(lora_rank) == 1 lora_rank = lora_rank[0] for name in unet.attn_processors.keys(): - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else unet.config.cross_attention_dim - ) + cross_attention_dim = (None if name.endswith("attn1.processor") else unet.config.cross_attention_dim) if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): @@ -46,36 +32,21 @@ def apply_unet_lora_weights(pipeline, unet_path): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=lora_rank, - ).to(pipeline.device) + lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, + rank=lora_rank, ).to(pipeline.device) unet.set_attn_processor(lora_attn_procs) unet.load_state_dict(model_weight, strict=False) -def attn_with_weights( - attn: nn.Module, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - weights=None, # shape: (batch_size, sequence_length) - lora_scale=1.0, -): +def attn_with_weights(attn: nn.Module, hidden_states, encoder_hidden_states=None, attention_mask=None, weights=None, + # shape: (batch_size, sequence_length) + lora_scale=1.0, batch: Optional[int] = None): batch_size, sequence_length, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if isinstance(attn.processor, LoRACrossAttnProcessor): - query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora( - hidden_states - ) + query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states) else: query = attn.to_q(hidden_states) @@ -85,49 +56,53 @@ def attn_with_weights( encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if isinstance(attn.processor, LoRACrossAttnProcessor): - key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora( - encoder_hidden_states - ) - value = attn.to_v( - encoder_hidden_states - ) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states) else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() + if batch is None: + query = 
attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() - if not has_xformers: - attention_probs = attn.get_attention_scores(query, key, attention_mask) - if weights is not None: - if weights.shape[0] != 1: - weights = weights.repeat_interleave(attn.heads, dim=0) - attention_probs = attention_probs * weights[:, None] - attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True) - hidden_states = torch.bmm(attention_probs, value) + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=None, + scale=attn.scale) else: + ref_q, ref_k, ref_v = query[batch:], key[batch:], value[batch:] + q, k, v = query[:batch], key[:batch], value[:batch] + q_pos, q_neg = q.chunk(2, 0) + k_pos, k_neg = k.chunk(2, 0) + v_pos, v_neg = v.chunk(2, 0) + + ref_v2 = ref_v.view(1, -1, *ref_v.shape[2:]).expand(q_pos.size(0), -1, -1) + ref_k2 = ref_k.view(1, -1, *ref_k.shape[2:]).expand(q_pos.size(0), -1, -1) + k, v = torch.cat([k_pos, ref_k2], dim=1), torch.cat([v_pos, ref_v2], dim=1) + if weights is not None: - bias = weights.repeat_interleave(attn.heads, dim=0).unsqueeze(1).expand(-1, query.size(1), -1).log() + bias = torch.ones(k.size(1), device=k.device, dtype=k.dtype) + bias[k_pos.size(1):] = weights + bias = bias.log().view(1, 1, -1).expand(attn.heads, query.size(1), -1) if attention_mask is None: attention_mask = bias else: attention_mask += bias - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=None, scale=attn.scale - ) - hidden_states = hidden_states.to(query.dtype) - - + ref_q, ref_k, ref_v, q_pos, k, v, q_neg, k_neg, v_neg = [attn.head_to_batch_dim(x).contiguous() for x in + [ref_q, ref_k, ref_v, q_pos, k, v, q_neg, k_neg, + v_neg]] + ref_hidden_states = xformers.ops.memory_efficient_attention(ref_q, ref_k, ref_v, op=None, scale=attn.scale) + pos_hidden_states = xformers.ops.memory_efficient_attention(q_pos, k, v, attn_bias=attention_mask, op=None, + scale=attn.scale) + neg_hidden_states = xformers.ops.memory_efficient_attention(q_neg, k_neg, v_neg, op=None, scale=attn.scale) + hidden_states = torch.cat([pos_hidden_states, neg_hidden_states, ref_hidden_states], dim=0) + hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj if isinstance(attn.processor, LoRACrossAttnProcessor): - hidden_states = attn.to_out[0]( - hidden_states - ) + lora_scale * attn.processor.to_out_lora(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states) else: hidden_states = attn.to_out[0](hidden_states) # dropout @@ -137,14 +112,8 @@ def attn_with_weights( class AttentionBasedGenerator(nn.Module): - def __init__( - self, - model_name: Optional[str] = None, - model_ckpt: Optional[str] = None, - stable_diffusion_version: str = "1.5", - lora_weights: Optional[str] = None, - torch_dtype=torch.float32 - ): + def __init__(self, model_name: Optional[str] = None, model_ckpt: Optional[str] = None, + stable_diffusion_version: str = "1.5", lora_weights: Optional[str] = None, torch_dtype=torch.float32): super().__init__() if stable_diffusion_version == "2.1": @@ -157,39 +126,31 @@ def __init__( model_name = "stabilityai/stable-diffusion-2-1" else: raise ValueError( - f"Unknown stable diffusion version: {stable_diffusion_version}. 
Version must be either '1.5' or '2.1'" - ) + f"Unknown stable diffusion version: {stable_diffusion_version}. Version must be either '1.5' or '2.1'") - scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler") + scheduler = Scheduler.from_pretrained(model_name, subfolder="scheduler") if model_ckpt is not None: - pipe = StableDiffusionPipeline.from_ckpt( - model_ckpt, - scheduler=scheduler, - torch_dtype=torch_dtype, - safety_checker=None, - ) + pipe = StableDiffusionPipeline.from_ckpt(model_ckpt, scheduler=scheduler, torch_dtype=torch_dtype, + safety_checker=None, ) pipe.scheduler = scheduler else: - pipe = StableDiffusionPipeline.from_pretrained( - model_name, - scheduler=scheduler, - torch_dtype=torch_dtype, - safety_checker=None, - ) + pipe = StableDiffusionPipeline.from_pretrained(model_name, scheduler=scheduler, torch_dtype=torch_dtype, + safety_checker=None, ) + + pipe.enable_vae_slicing() + pipe.enable_vae_tiling() if lora_weights: print(f"Applying LoRA weights from {lora_weights}") - apply_unet_lora_weights( - pipeline=pipe, unet_path=lora_weights - ) + apply_unet_lora_weights(pipeline=pipe, unet_path=lora_weights) self.pipeline = pipe self.unet = pipe.unet self.vae = pipe.vae self.text_encoder = pipe.text_encoder self.tokenizer = pipe.tokenizer - self.scheduler = scheduler + self.scheduler: Scheduler = scheduler self.dtype = torch_dtype @property @@ -201,26 +162,16 @@ def to(self, device): return super().to(device) def initialize_prompts(self, prompts: List[str]): - prompt_tokens = self.tokenizer( - prompts, - return_tensors="pt", - max_length=self.tokenizer.model_max_length, - padding="max_length", - truncation=True, - ) - - if ( - hasattr(self.text_encoder.config, "use_attention_mask") - and self.text_encoder.config.use_attention_mask - ): + prompt_tokens = self.tokenizer(prompts, return_tensors="pt", max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=True, ) + + if (hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask): attention_mask = prompt_tokens.attention_mask.to(self.device) else: attention_mask = None - prompt_embd = self.text_encoder( - input_ids=prompt_tokens.input_ids.to(self.device), - attention_mask=attention_mask, - ).last_hidden_state + prompt_embd = self.text_encoder(input_ids=prompt_tokens.input_ids.to(self.device), + attention_mask=attention_mask, ).last_hidden_state return prompt_embd @@ -228,7 +179,6 @@ def get_unet_hidden_states(self, z_all, t, prompt_embd): cached_hidden_states = [] for module in self.unet.modules(): if isinstance(module, BasicTransformerBlock): - def new_forward(self, hidden_states, *args, **kwargs): cached_hidden_states.append(hidden_states.clone().detach().cpu()) return self.old_forward(hidden_states, *args, **kwargs) @@ -247,91 +197,24 @@ def new_forward(self, hidden_states, *args, **kwargs): return cached_hidden_states - def unet_forward_with_cached_hidden_states( - self, - z_all, - t, - prompt_embd, - cached_pos_hiddens: Optional[List[torch.Tensor]] = None, - cached_neg_hiddens: Optional[List[torch.Tensor]] = None, - pos_weights=(0.8, 0.8), - neg_weights=(0.5, 0.5), - ): - if cached_pos_hiddens is None and cached_neg_hiddens is None: + def unet_forward_with_cached_hidden_states(self, z_all, t, prompt_embd, + cached_pos_hiddens: Optional[List[torch.Tensor]] = None, + cached_neg_hiddens: Optional[List[torch.Tensor]] = None, + pos_weights=(0.8, 0.8), neg_weights=(0.5, 0.5), + batch: Optional[int] = None): + if batch is None: 
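+            # batch=None means no cached feedback latents were appended, so run a plain
+            # UNet forward on the incoming batch instead of the reference-attention path.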
return self.unet(z_all, t, encoder_hidden_states=prompt_embd) - local_pos_weights = torch.linspace( - *pos_weights, steps=len(self.unet.down_blocks) + 1 - )[:-1].tolist() - local_neg_weights = torch.linspace( - *neg_weights, steps=len(self.unet.down_blocks) + 1 - )[:-1].tolist() - - for block, pos_weight, neg_weight in zip( - self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks, - local_pos_weights + [pos_weights[1]] + local_pos_weights[::-1], - local_neg_weights + [neg_weights[1]] + local_neg_weights[::-1], - ): + local_pos_weights = torch.linspace(*pos_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist() + local_neg_weights = torch.linspace(*neg_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist() + + for block, pos_weight, neg_weight in zip(self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks, + local_pos_weights + [pos_weights[1]] + local_pos_weights[::-1], + local_neg_weights + [neg_weights[1]] + local_neg_weights[::-1], ): for module in block.modules(): if isinstance(module, BasicTransformerBlock): - - def new_forward( - self, - hidden_states, - pos_weight=pos_weight, - neg_weight=neg_weight, - **kwargs, - ): - cond_hiddens, uncond_hiddens = hidden_states.chunk(2, dim=0) - batch_size, d_model = cond_hiddens.shape[:2] - device, dtype = hidden_states.device, hidden_states.dtype - - weights = torch.ones( - batch_size, d_model, device=device, dtype=dtype - ) - - if cached_pos_hiddens is not None: - cached_pos_hs = cached_pos_hiddens.pop(0).to( - hidden_states.device - ) - cond_pos_hs = torch.cat( - [cond_hiddens, cached_pos_hs], dim=1 - ) - pos_weights = weights.clone().repeat( - 1, 1 + cached_pos_hs.shape[1] // d_model - ) - pos_weights[:, d_model:] = pos_weight - out_pos = attn_with_weights( - self, - cond_hiddens, - encoder_hidden_states=cond_pos_hs, - weights=pos_weights, - ) - else: - out_pos = self.old_forward(cond_hiddens) - - if cached_neg_hiddens is not None: - cached_neg_hs = cached_neg_hiddens.pop(0).to( - hidden_states.device - ) - uncond_neg_hs = torch.cat( - [uncond_hiddens, cached_neg_hs], dim=1 - ) - neg_weights = weights.clone().repeat( - 1, 1 + cached_neg_hs.shape[1] // d_model - ) - neg_weights[:, d_model:] = neg_weight - out_neg = attn_with_weights( - self, - uncond_hiddens, - encoder_hidden_states=uncond_neg_hs, - weights=neg_weights, - ) - else: - out_neg = self.old_forward(uncond_hiddens) - - out = torch.cat([out_pos, out_neg], dim=0) - return out + def new_forward(self, hidden_states, pos_weight=pos_weight, neg_weight=neg_weight, **kwargs, ): + return attn_with_weights(self, hidden_states, weights=pos_weight, batch=batch) module.attn1.old_forward = module.attn1.forward module.attn1.forward = new_forward.__get__(module.attn1) @@ -347,24 +230,13 @@ def new_forward( return out @torch.no_grad() - def generate( - self, - prompt: Union[str, List[str]] = "a photo of an astronaut riding a horse on mars", - negative_prompt: Union[str, List[str]] = "", - liked: List[Image.Image] = [], - disliked: List[Image.Image] = [], - seed: int = 42, - n_images: int = 1, - guidance_scale: float = 8.0, - denoising_steps: int = 20, - feedback_start: float = 0.33, - feedback_end: float = 0.66, - min_weight: float = 0.1, - max_weight: float = 1.0, - neg_scale: float = 0.5, - pos_bottleneck_scale: float = 1.0, - neg_bottleneck_scale: float = 1.0, - ): + def generate(self, prompt: Union[str, List[str]] = "a photo of an astronaut riding a horse on mars", + negative_prompt: Union[str, List[str]] = "", liked: List[Image.Image] = [], + disliked: 
List[Image.Image] = [], seed: int = 42, n_images: int = 1, guidance_scale: float = 8.0, + denoising_steps: int = 20, feedback_start: float = 0.33, feedback_end: float = 0.66, + min_weight: float = 0.1, max_weight: float = 1.0, neg_scale: float = 0.5, + pos_bottleneck_scale: float = 1.0, neg_bottleneck_scale: float = 1.0, size: int = 2048, + warmup_power: float = 1): """ Generate a trajectory of images with binary feedback. The feedback can be given as a list of liked and disliked images. @@ -372,25 +244,19 @@ def generate( if seed is not None: torch.manual_seed(seed) - z = torch.randn(n_images, 4, 64, 64, device=self.device, dtype=self.dtype) + z = torch.randn(n_images, 4, size // 8, size // 8, device=self.device, dtype=self.dtype) if liked and len(liked) > 0: - pos_images = [self.image_to_tensor(img) for img in liked] + pos_images = [self.image_to_tensor(img, size) for img in liked] pos_images = torch.stack(pos_images).to(self.device, dtype=self.dtype) - pos_latents = ( - self.vae.config.scaling_factor - * self.vae.encode(pos_images).latent_dist.sample() - ) + pos_latents = (self.vae.config.scaling_factor * self.vae.encode(pos_images).latent_dist.sample()) else: pos_latents = torch.tensor([], device=self.device, dtype=self.dtype) if disliked and len(disliked) > 0: - neg_images = [self.image_to_tensor(img) for img in disliked] + neg_images = [self.image_to_tensor(img, size) for img in disliked] neg_images = torch.stack(neg_images).to(self.device, dtype=self.dtype) - neg_latents = ( - self.vae.config.scaling_factor - * self.vae.encode(neg_images).latent_dist.sample() - ) + neg_latents = (self.vae.config.scaling_factor * self.vae.encode(neg_images).latent_dist.sample()) else: neg_latents = torch.tensor([], device=self.device, dtype=self.dtype) @@ -403,11 +269,8 @@ def generate( else: assert len(negative_prompt) == n_images - ( - cond_prompt_embs, - uncond_prompt_embs, - null_prompt_emb, - ) = self.initialize_prompts(prompt + negative_prompt + [""]).split([n_images, n_images, 1]) + (cond_prompt_embs, uncond_prompt_embs, null_prompt_emb,) = self.initialize_prompts( + prompt + negative_prompt + [""]).split([n_images, n_images, 1]) batched_prompt_embd = torch.cat([cond_prompt_embs, uncond_prompt_embs], dim=0) self.scheduler.set_timesteps(denoising_steps, device=self.device) @@ -428,72 +291,53 @@ def generate( sigma = self.scheduler.sigmas[i] else: sigma = 0 - alpha_hat = 1 / (sigma**2 + 1) + alpha_hat = 1 / (sigma ** 2 + 1) z_single = self.scheduler.scale_model_input(z, t) - z_all = torch.cat([z_single] * 2, dim=0) - z_ref = torch.cat([pos_latents, neg_latents], dim=0) - if i >= ref_start_idx and i <= ref_end_idx: - weight = max_weight + do_cfg = i >= ref_start_idx and i <= ref_end_idx + if do_cfg: + z_all = torch.cat([z_single] * 2, dim=0) + z_ref = torch.cat([pos_latents, neg_latents], dim=0) + scale = (ref_end_idx - i) / (ref_end_idx - ref_start_idx) + weight = (max_weight - min_weight) * scale ** warmup_power + min_weight + prompt_embd = batched_prompt_embd + pos_ws = (weight, weight * pos_bottleneck_scale) + neg_ws = (weight * neg_scale, weight * neg_scale * neg_bottleneck_scale) else: - weight = min_weight - pos_ws = (weight, weight * pos_bottleneck_scale) - neg_ws = (weight * neg_scale, weight * neg_scale * neg_bottleneck_scale) - - if z_ref.size(0) > 0 and weight > 0: + z_all = z_single + weight = 0 + prompt_embd = cond_prompt_embs + pos_ws = None + neg_ws = None + + do_fabric = do_cfg and z_ref.size(0) > 0 and weight > 0 + if do_fabric: + batch = z_all.size(0) noise = 
torch.randn_like(z_ref) - if isinstance(self.scheduler, EulerAncestralDiscreteScheduler): - z_ref_noised = ( - alpha_hat**0.5 * z_ref + (1 - alpha_hat) ** 0.5 * noise - ) - else: - z_ref_noised = self.scheduler.add_noise(z_ref, noise, t) - - ref_prompt_embd = torch.cat([null_prompt_emb] * (pos_latents.size(0) + neg_latents.size(0)), dim=0) - - cached_hidden_states = self.get_unet_hidden_states( - z_ref_noised, t, ref_prompt_embd - ) - - n_pos, n_neg = pos_latents.shape[0], neg_latents.shape[0] - cached_pos_hs, cached_neg_hs = [], [] - for hs in cached_hidden_states: - cached_pos, cached_neg = hs.split([n_pos, n_neg], dim=0) - cached_pos = cached_pos.view( - 1, -1, *cached_pos.shape[2:] - ).expand(n_images, -1, -1) - cached_neg = cached_neg.view( - 1, -1, *cached_neg.shape[2:] - ).expand(n_images, -1, -1) - cached_pos_hs.append(cached_pos) - cached_neg_hs.append(cached_neg) - - if n_pos == 0: - cached_pos_hs = None - if n_neg == 0: - cached_neg_hs = None + z_ref_noised = (alpha_hat ** 0.5 * z_ref + (1 - alpha_hat) ** 0.5 * noise) + + ref_prompt_embd = torch.cat([cond_prompt_embs] * (pos_latents.size(0) + neg_latents.size(0)), dim=0) + + z_all = torch.cat([z_all, z_ref_noised], dim=0) + prompt_embd = torch.cat([prompt_embd, ref_prompt_embd], dim=0) + else: + batch = None + + unet_out = self.unet_forward_with_cached_hidden_states(z_all, t, prompt_embd=prompt_embd, # + pos_weights=pos_ws, neg_weights=neg_ws, + batch=batch).sample + if do_fabric: + unet_out = unet_out[:batch] + if do_cfg: + noise_cond, noise_uncond = unet_out.chunk(2) + guidance = noise_cond - noise_uncond + noise_pred = noise_uncond + guidance_scale * guidance else: - cached_pos_hs, cached_neg_hs = None, None - - unet_out = self.unet_forward_with_cached_hidden_states( - z_all, - t, - prompt_embd=batched_prompt_embd, - cached_pos_hiddens=cached_pos_hs, - cached_neg_hiddens=cached_neg_hs, - pos_weights=pos_ws, - neg_weights=neg_ws, - ).sample - - noise_cond, noise_uncond = unet_out.chunk(2) - guidance = noise_cond - noise_uncond - noise_pred = noise_uncond + guidance_scale * guidance + noise_pred = unet_out z = self.scheduler.step(noise_pred, t, z).prev_sample - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): pbar.update() y = self.pipeline.decode_latents(z) @@ -502,7 +346,7 @@ def generate( return imgs @staticmethod - def image_to_tensor(image: Union[str, Image.Image]): + def image_to_tensor(image: Union[str, Image.Image], size: int): """ Convert a PIL image to a torch tensor. 
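+        The image is resized to (size, size) and pixel values are scaled to [-1, 1].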
""" @@ -510,7 +354,7 @@ def image_to_tensor(image: Union[str, Image.Image]): image = Image.open(image) if not image.mode == "RGB": image = image.convert("RGB") - image = image.resize((512, 512)) + image = image.resize((size, size)) image = np.array(image).astype(np.uint8) image = (image / 127.5 - 1.0).astype(np.float32) return torch.from_numpy(image).permute(2, 0, 1) diff --git a/fabric/run_single.py b/fabric/run_single.py index 200a29d..c78a590 100644 --- a/fabric/run_single.py +++ b/fabric/run_single.py @@ -4,60 +4,47 @@ import hydra import torch +from PIL import Image from omegaconf import DictConfig from fabric.generator import AttentionBasedGenerator -from fabric.utils import get_free_gpu, tile_images +from fabric.utils import get_free_gpu @hydra.main(config_path="configs", config_name="single_round", version_base=None) def main(ctx: DictConfig): + print("main") device = "cpu" # "mps" if torch.backends.mps.is_available() else "cpu" device = get_free_gpu() if torch.cuda.is_available() else device print(f"Using device: {device}") - dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + dtype = torch.float16 print(f"Using dtype: {dtype}") - generator = AttentionBasedGenerator( - model_ckpt=ctx.model_ckpt if hasattr(ctx, "model_ckpt") else None, - model_name=ctx.model_name if hasattr(ctx, "model_name") else None, - stable_diffusion_version=ctx.model_version, - torch_dtype=dtype, - ).to(device) + generator = AttentionBasedGenerator(model_ckpt=ctx.model_ckpt if hasattr(ctx, "model_ckpt") else None, + model_name=ctx.model_name if hasattr(ctx, "model_name") else None, stable_diffusion_version=ctx.model_version, + torch_dtype=dtype, ).to(device) + + liked = [Image.open("square.jpg")][:ctx.topk] + + imgs = generator.generate(prompt=ctx.prompt, negative_prompt=ctx.negative_prompt, liked=liked, disliked=[], + seed=ctx.seed, n_images=ctx.n_images, guidance_scale=ctx.guidance_scale, denoising_steps=ctx.denoising_steps, + feedback_start=ctx.feedback.start, feedback_end=ctx.feedback.end, min_weight=ctx.feedback.min_weight, + max_weight=ctx.feedback.max_weight, neg_scale=ctx.feedback.neg_scale, size=ctx.size, + warmup_power=ctx.feedback.warmup_power) - imgs = generator.generate( - prompt=ctx.prompt, - negative_prompt=ctx.negative_prompt, - liked=list(ctx.liked) if ctx.liked else [], - disliked=list(ctx.disliked) if ctx.disliked else [], - seed=ctx.seed, - n_images=ctx.n_images, - guidance_scale=ctx.guidance_scale, - denoising_steps=ctx.denoising_steps, - feedback_start=ctx.feedback.start, - feedback_end=ctx.feedback.end, - min_weight=ctx.feedback.min_weight, - max_weight=ctx.feedback.max_weight, - neg_scale=ctx.feedback.neg_scale, - ) - date_str = date.today().strftime("%Y-%m-%d") out_folder = os.path.join("outputs", "images", date_str) os.makedirs(out_folder, exist_ok=True) - - n_files = max([int(f.split(".")[0].split("_")[1]) for f in os.listdir(out_folder) if re.match(r"example_[0-9_]+\.png", f)], default=0) + 1 + + n_files = max( + [int(f.split(".")[0].split("_")[1]) for f in os.listdir(out_folder) if re.match(r"example_[0-9_]+\.png", f)], + default=0) + 1 for i, img in enumerate(imgs): # each image is of the form example_ID.png. 
Extract the max id out_path = os.path.join(out_folder, f"example_{n_files}_{i}.png") - img.save(out_path) - print(f"Saved image to {out_path}") - - if len(imgs) > 1: - tiled = tile_images(imgs) - tiled_path = os.path.join(out_folder, f"tiled_{n_files}.png") - tiled.save(tiled_path) - print(f"Saved tile to {tiled_path}") + img.save(out_path) + print(f"Saved image to {out_path}") if __name__ == "__main__":
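+    # Hypothetical invocation (standard Hydra CLI overrides; key names come from
+    # configs/single_round.yaml). Note that main() opens "square.jpg" unconditionally,
+    # so that file must exist in the working directory even when topk=0:
+    #   python fabric/run_single.py prompt="photo of a corgi" denoising_steps=30 size=512 topk=0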