Skip to content

Commit

Permalink
refactor: add some sd vesion helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
leejet committed Nov 23, 2024
1 parent 1c168d9 commit b5f4932
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 38 deletions.
8 changes: 4 additions & 4 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct Conditioner {
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
SDVersion version = VERSION_SD1;
PMVersion pm_version = VERSION_1;
PMVersion pm_version = PM_VERSION_1;
CLIPTokenizer tokenizer;
ggml_type wtype;
std::shared_ptr<CLIPTextModelRunner> text_model;
Expand All @@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
ggml_type wtype,
const std::string& embd_dir,
SDVersion version = VERSION_SD1,
PMVersion pv = VERSION_1,
PMVersion pv = PM_VERSION_1,
int clip_skip = -1)
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
if (clip_skip <= 0) {
Expand Down Expand Up @@ -270,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
std::vector<int> clean_input_ids_tmp;
for (uint32_t i = 0; i < class_token_index[0]; i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]);
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
clean_input_ids_tmp.push_back(class_token);
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]);
Expand All @@ -286,7 +286,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
// weights.insert(weights.begin(), 1.0);

tokenizer.pad_tokens(tokens, weights, max_length, padding);
int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
for (uint32_t i = 0; i < tokens.size(); i++) {
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs
Expand Down
25 changes: 23 additions & 2 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,30 @@ enum SDVersion {
VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
return true;
}
return false;
}

static inline bool sd_version_is_sd3(SDVersion version) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
return true;
}
return false;
}

static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
return true;
}
return false;
}

enum PMVersion {
VERSION_1,
VERSION_2,
PM_VERSION_1,
PM_VERSION_2,
};

struct TensorStorage {
Expand Down
16 changes: 8 additions & 8 deletions pmid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
struct PhotoMakerIDEncoder : public GGMLRunner {
public:
SDVersion version = VERSION_SDXL;
PMVersion pm_version = VERSION_1;
PMVersion pm_version = PM_VERSION_1;
PhotoMakerIDEncoderBlock id_encoder;
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
float style_strength;
Expand All @@ -623,14 +623,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
std::vector<float> zeros_right;

public:
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f)
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f)
: GGMLRunner(backend, wtype),
version(version),
pm_version(pm_v),
style_strength(sty) {
if (pm_version == VERSION_1) {
if (pm_version == PM_VERSION_1) {
id_encoder.init(params_ctx, wtype);
} else if (pm_version == VERSION_2) {
} else if (pm_version == PM_VERSION_2) {
id_encoder2.init(params_ctx, wtype);
}
}
Expand All @@ -644,9 +644,9 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
if (pm_version == VERSION_1)
if (pm_version == PM_VERSION_1)
id_encoder.get_param_tensors(tensors, prefix);
else if (pm_version == VERSION_2)
else if (pm_version == PM_VERSION_2)
id_encoder2.get_param_tensors(tensors, prefix);
}

Expand Down Expand Up @@ -734,14 +734,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
}
}
struct ggml_tensor* updated_prompt_embeds = NULL;
if (pm_version == VERSION_1)
if (pm_version == PM_VERSION_1)
updated_prompt_embeds = id_encoder.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
else if (pm_version == VERSION_2)
else if (pm_version == PM_VERSION_2)
updated_prompt_embeds = id_encoder2.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
Expand Down
46 changes: 23 additions & 23 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ class StableDiffusionGGML {
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
}
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
} else if (sd_version_is_sd3(version)) {
scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(version)) {
scale_factor = 0.3611;
// TODO: shift_factor
}
Expand All @@ -309,7 +309,7 @@ class StableDiffusionGGML {
} else {
clip_backend = backend;
bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
if (sd_version_is_dit(version)) {
use_t5xxl = true;
}
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
Expand All @@ -323,18 +323,18 @@ class StableDiffusionGGML {
if (diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
}
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(version)) {
if (diffusion_flash_attn) {
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
}
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(version)) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
} else {
if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2);
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2);
} else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
}
Expand Down Expand Up @@ -373,7 +373,7 @@ class StableDiffusionGGML {
}

if (id_embeddings_path.find("v2") != std::string::npos) {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, VERSION_2);
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, PM_VERSION_2);
LOG_INFO("using PhotoMaker Version 2");
} else {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version);
Expand Down Expand Up @@ -527,10 +527,10 @@ class StableDiffusionGGML {
is_using_v_parameterization = true;
}

if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(version)) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(version)) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
Expand Down Expand Up @@ -804,7 +804,7 @@ class StableDiffusionGGML {
out_uncond = ggml_dup_tensor(work_ctx, x);
}
if (has_skiplayer) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (sd_version_is_dit(version)) {
out_skip = ggml_dup_tensor(work_ctx, x);
} else {
has_skiplayer = false;
Expand Down Expand Up @@ -995,9 +995,9 @@ class StableDiffusionGGML {
if (use_tiny_autoencoder) {
C = 4;
} else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(version)) {
C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(version)) {
C = 32;
}
}
Expand Down Expand Up @@ -1214,7 +1214,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
}
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == VERSION_2;
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
Expand Down Expand Up @@ -1343,9 +1343,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
// Sample
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
Expand Down Expand Up @@ -1464,10 +1464,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(sd_ctx->sd->version)) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
if (sd_version_is_flux(sd_ctx->sd->version)) {
params.mem_size *= 4;
}
if (sd_ctx->sd->stacked_id) {
Expand All @@ -1490,17 +1490,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);

int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(sd_ctx->sd->version)) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
Expand Down Expand Up @@ -1567,10 +1567,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
if (sd_version_is_sd3(sd_ctx->sd->version)) {
params.mem_size *= 2;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
if (sd_version_is_flux(sd_ctx->sd->version)) {
params.mem_size *= 3;
}
if (sd_ctx->sd->stacked_id) {
Expand Down
2 changes: 1 addition & 1 deletion vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ class AutoencodingEngine : public GGMLBlock {
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
if (sd_version_is_dit(version)) {
dd_config.z_channels = 16;
use_quant = false;
}
Expand Down

0 comments on commit b5f4932

Please sign in to comment.