Skip to content

Commit

Permalink
Add SD3LatentFormat and CFGDenoiser
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Aug 16, 2024
1 parent 0a78ebd commit 9c1a127
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
9 changes: 9 additions & 0 deletions keras_cv/src/models/stable_diffusion_v3/cfg_denoiser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from keras import ops


class CFGDenoiser:
def __call__(self, batched, cond_scale):
# `batched` is the outputs from `BaseModel.apply_model`
pos_out, neg_out = ops.split(batched, 2, axis=0)
scaled = neg_out + (pos_out - neg_out) * cond_scale
return scaled
15 changes: 15 additions & 0 deletions keras_cv/src/models/stable_diffusion_v3/sd3_latent_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class SD3LatentFormat:
"""Latents are slightly shifted from center.
This class must be called after VAE Decode to correct for the shift.
"""

def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609

def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor

def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor

0 comments on commit 9c1a127

Please sign in to comment.