From 9c1a1278ce03108c1b01fa8fc499b267b88bd565 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:37:10 +0800 Subject: [PATCH] Add SD3LatentFormat and CFGDenoiser --- .../models/stable_diffusion_v3/cfg_denoiser.py | 9 +++++++++ .../stable_diffusion_v3/sd3_latent_format.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 keras_cv/src/models/stable_diffusion_v3/cfg_denoiser.py create mode 100644 keras_cv/src/models/stable_diffusion_v3/sd3_latent_format.py diff --git a/keras_cv/src/models/stable_diffusion_v3/cfg_denoiser.py b/keras_cv/src/models/stable_diffusion_v3/cfg_denoiser.py new file mode 100644 index 0000000000..553f31754b --- /dev/null +++ b/keras_cv/src/models/stable_diffusion_v3/cfg_denoiser.py @@ -0,0 +1,9 @@ +from keras import ops + + +class CFGDenoiser: + def __call__(self, batched, cond_scale): + # `batched` is the outputs from `BaseModel.apply_model` + pos_out, neg_out = ops.split(batched, 2, axis=0) + scaled = neg_out + (pos_out - neg_out) * cond_scale + return scaled diff --git a/keras_cv/src/models/stable_diffusion_v3/sd3_latent_format.py b/keras_cv/src/models/stable_diffusion_v3/sd3_latent_format.py new file mode 100644 index 0000000000..c32294ddb4 --- /dev/null +++ b/keras_cv/src/models/stable_diffusion_v3/sd3_latent_format.py @@ -0,0 +1,15 @@ +class SD3LatentFormat: + """Latents are slightly shifted from center. + + This class must be called after VAE Decode to correct for the shift. + """ + + def __init__(self): + self.scale_factor = 1.5305 + self.shift_factor = 0.0609 + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scale_factor + + def process_out(self, latent): + return (latent / self.scale_factor) + self.shift_factor