From 9b1d90bc238a678a9a63cd243a5d38403f37a081 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Sat, 23 Nov 2024 04:19:27 +0100 Subject: [PATCH] fix: improve clip text_projection support (#397) --- clip.hpp | 8 +++-- conditioner.hpp | 79 +++++++++++++++++++++---------------------------- 2 files changed, 40 insertions(+), 47 deletions(-) diff --git a/clip.hpp b/clip.hpp index f9ac631a..bf2a8c14 100644 --- a/clip.hpp +++ b/clip.hpp @@ -711,8 +711,12 @@ class CLIPTextModel : public GGMLBlock { if (return_pooled) { auto text_projection = params["text_projection"]; ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); - pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled); - return pooled; + if (text_projection != NULL) { + pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); + } else { + LOG_DEBUG("Missing text_projection matrix, assuming identity..."); + } + return pooled; // [hidden_size, 1, 1] } return x; // [N, n_token, hidden_size] diff --git a/conditioner.hpp b/conditioner.hpp index ac2ab7eb..9f9d5ae1 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -798,21 +798,17 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_l->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled_l, - // work_ctx); - - // clip_l.transformer.text_model.text_projection no in file, ignore - // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection - pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); - ggml_set_f32(pooled_l, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + clip_l->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled_l, + work_ctx); + } } @@ -852,21 +848,17 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_g->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled_g, - // work_ctx); - // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too - - // TODO: fix pooled_g - pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280); - ggml_set_f32(pooled_g, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + clip_g->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled_g, + work_ctx); + } } @@ -1104,21 +1096,18 @@ struct FluxCLIPEmbedder : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); size_t max_token_idx = 0; - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_l->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled, - // work_ctx); - - // clip_l.transformer.text_model.text_projection no in file, ignore - // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection - pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); - ggml_set_f32(pooled, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + + clip_l->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled, + work_ctx); + } // t5