From 2d6c545fcc22bc4c3e223c16c85e4cfd537658bc Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Wed, 4 Dec 2024 09:10:33 -0800 Subject: [PATCH] Add 448 resolution to PaliGemma and PaliGemma2. PiperOrigin-RevId: 702747814 --- gemma/common.cc | 12 +++++-- gemma/configs.cc | 18 ++++++++-- gemma/configs.h | 7 +++- gemma/run.cc | 3 +- gemma/tensor_index.cc | 2 +- gemma/weights.h | 3 +- paligemma/image.cc | 28 +++++++-------- paligemma/image.h | 17 +++++----- paligemma/image_test.cc | 68 ++++++++++++++++++++++++++++++++++--- paligemma/paligemma_test.cc | 33 +++++++++++++++--- 10 files changed, 150 insertions(+), 41 deletions(-) diff --git a/gemma/common.cc b/gemma/common.cc index 6678c0bd..ba073df1 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -39,8 +39,11 @@ constexpr const char* kModelFlags[] = { "9b-pt", "9b-it", // Gemma2 9B "27b-pt", "27b-it", // Gemma2 27B "paligemma-224", // PaliGemma 224 + "paligemma-448", // PaliGemma 448 "paligemma2-3b-224", // PaliGemma2 3B 224 + "paligemma2-3b-448", // PaliGemma2 3B 448 "paligemma2-10b-224", // PaliGemma2 10B 224 + "paligemma2-10b-448", // PaliGemma2 10B 448 }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B @@ -51,8 +54,11 @@ constexpr Model kModelTypes[] = { Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B Model::PALIGEMMA_224, // PaliGemma 224 + Model::PALIGEMMA_448, // PaliGemma 448 Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224 + Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448 Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224 + Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448 }; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B @@ -62,9 +68,9 @@ constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B - ModelTraining::PALIGEMMA, // PaliGemma 224 - ModelTraining::PALIGEMMA, // PaliGemma2 3B 224 - ModelTraining::PALIGEMMA, // PaliGemma2 10B 224 + ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448 + ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448 + ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448 }; constexpr size_t kNumModelFlags = diff --git a/gemma/configs.cc b/gemma/configs.cc index b219e76c..b6edc8c8 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -205,10 +205,10 @@ static ModelConfig ConfigGriffin2B() { } // Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config. -static void AddVitConfig(ModelConfig& config) { +static void AddVitConfig(ModelConfig& config, size_t image_size = 224) { config.vit_model_dim = 1152; config.vocab_size = 256000 + 1024 + 128; // = 257152 - config.image_size = 224; + config.image_size = image_size; config.patch_width = 14; for (auto& layer_config : config.layer_configs) { layer_config.optimized_gating = false; @@ -236,6 +236,14 @@ static ModelConfig ConfigPaliGemma_224() { return config; } +static ModelConfig ConfigPaliGemma_448() { + ModelConfig config = ConfigGemma2B(); + config.model_name = "PaliGemma_448"; + config.model = Model::PALIGEMMA_448; + AddVitConfig(config, /*image_size=*/448); + return config; +} + ModelConfig VitConfig(const ModelConfig& config) { ModelConfig vit_config = ConfigNoSSM(); vit_config.model_dim = config.vit_model_dim; @@ -280,10 +288,16 @@ ModelConfig ConfigFromModel(Model model) { return ConfigGemmaTiny(); case Model::PALIGEMMA_224: return ConfigPaliGemma_224(); + case Model::PALIGEMMA_448: + return ConfigPaliGemma_448(); case Model::PALIGEMMA2_3B_224: return ConfigPaliGemma2_3B_224(); + case Model::PALIGEMMA2_3B_448: + return ConfigPaliGemma2_3B_448(); case Model::PALIGEMMA2_10B_224: return ConfigPaliGemma2_10B_224(); + case Model::PALIGEMMA2_10B_448: + return ConfigPaliGemma2_10B_448(); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } diff --git a/gemma/configs.h b/gemma/configs.h index 8a636653..52bf8b9c 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -114,15 +114,20 @@ enum class Model { GEMMA_TINY, GEMMA2_2B, PALIGEMMA_224, + PALIGEMMA_448, PALIGEMMA2_3B_224, + PALIGEMMA2_3B_448, PALIGEMMA2_10B_224, + PALIGEMMA2_10B_448, }; // Allows the Model enum to be iterated over. static constexpr Model kAllModels[] = { Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B, Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B, - Model::PALIGEMMA_224, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_10B_224, + Model::PALIGEMMA_224, Model::PALIGEMMA_448, + Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_3B_448, + Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448, }; struct LayerConfig { diff --git a/gemma/run.cc b/gemma/run.cc index 87c7c9dd..d21c5d65 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -97,7 +97,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, model.GetModelConfig().model_dim)); HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(image.ReadPPM(args.image_file.path)); - image.Resize(); + const size_t image_size = model.GetModelConfig().image_size; + image.Resize(image_size, image_size); RuntimeConfig runtime_config = { .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin}; double image_tokens_start = hwy::platform::Now(); diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc index 9f44a9c1..68d9e490 100644 --- a/gemma/tensor_index.cc +++ b/gemma/tensor_index.cc @@ -80,7 +80,7 @@ std::vector ModelTensors(const ModelConfig& config) { .name = "img_pos_emb", .source_names = {"img/pos_embedding"}, .axes = {0, 1}, - .shape = {/*1,*/ 256, config.vit_model_dim}, + .shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim}, .min_size = Type::kF32, }, }; diff --git a/gemma/weights.h b/gemma/weights.h index b9acf899..ecd917b4 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -351,7 +351,8 @@ struct ModelWeightsPtrs { vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim), vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim, config.patch_width * config.patch_width * 3), - vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim), + vit_img_pos_embedding("img_pos_emb", config.vit_seq_len, + config.vit_model_dim), vit_img_head_bias("img_head_bias", 1, config.model_dim), vit_img_head_kernel("img_head_kernel", config.model_dim, config.vit_model_dim), diff --git a/paligemma/image.cc b/paligemma/image.cc index b642a4fb..93e12532 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -36,10 +36,8 @@ namespace gcpp { namespace { -// Hardcoded for PaliGemma-224 ViT input. +// Hardcoded for PaliGemma ViT input. constexpr size_t kPatchSize = 14; -constexpr size_t kImageSize = 224; -constexpr size_t kNumPatches = kImageSize / kPatchSize; // 16 // Returns the linearly scaled index in [0, to_size) closest to the // value in [0, from_size). @@ -174,9 +172,7 @@ void Image::Set(int width, int height, const float* data) { } } -void Image::Resize() { - int new_width = 224; - int new_height = kImageSize; +void Image::Resize(int new_width, int new_height) { std::vector new_data(new_width * new_height * 3); // TODO: go to bilinear interpolation, or antialias. // E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from @@ -211,18 +207,20 @@ bool Image::WriteBinary(const std::string& filename) const { return true; } -// Image.data() is kImageSize x kImageSize x 3, H x W x C. -// We want the N-th patch (of 256) of size kPatchSize x kPatchSize x 3. -// Patches are numbered in usual "pixel-order". +// Image.data() is H x W x 3. +// We want the N-th patch of size kPatchSize x kPatchSize x 3. void Image::GetPatch(size_t patch_num, float* patch) const { PROFILER_FUNC; - constexpr size_t kDataSize = kImageSize * kImageSize * 3; + const size_t kDataSize = width_ * height_ * 3; HWY_ASSERT(size() == kDataSize); constexpr size_t kPatchDataSize = kPatchSize * kPatchSize * 3; - size_t i_offs = patch_num / kNumPatches; - size_t j_offs = patch_num % kNumPatches; - HWY_ASSERT(0 <= i_offs && i_offs < kNumPatches); - HWY_ASSERT(0 <= j_offs && j_offs < kNumPatches); + HWY_ASSERT(width_ % kPatchSize == 0); + HWY_ASSERT(height_ % kPatchSize == 0); + const size_t kNumPatchesPerRow = width_ / kPatchSize; + size_t i_offs = patch_num / kNumPatchesPerRow; + size_t j_offs = patch_num % kNumPatchesPerRow; + HWY_ASSERT(0 <= i_offs && i_offs < height_ / kPatchSize); + HWY_ASSERT(0 <= j_offs && j_offs < kNumPatchesPerRow); i_offs *= kPatchSize; j_offs *= kPatchSize; // This can be made faster, but let's first see whether it matters. @@ -233,7 +231,7 @@ void Image::GetPatch(size_t patch_num, float* patch) const { const size_t patch_index = (i * kPatchSize + j) * 3 + k; HWY_ASSERT(patch_index < kPatchDataSize); const size_t image_index = - ((i + i_offs) * kImageSize + (j + j_offs)) * 3 + k; + ((i + i_offs) * width_ + (j + j_offs)) * 3 + k; HWY_ASSERT(image_index < kDataSize); patch[patch_index] = image_data[image_index]; } diff --git a/paligemma/image.h b/paligemma/image.h index 0ed88ab3..e0b15308 100644 --- a/paligemma/image.h +++ b/paligemma/image.h @@ -24,8 +24,7 @@ namespace gcpp { -// Very basic image loading and processing for PaliGemma-224. Does not try to be -// generic at the moment, e.g. the size to normalize to is hardcoded. +// Very basic image loading and processing for PaliGemma. class Image { public: Image() = default; @@ -38,15 +37,17 @@ class Image { // Sets the image content to the given data. The data is copied and normalized // to [-1, 1]. The data is expected to be of size width * height * 3. void Set(int width, int height, const float* data); - // Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would - // be better). - void Resize(); + // Resizes to width x height (nearest-neighbor for now, bilinear or antialias + // would be better). + void Resize(int width, int height); // Writes the file as plain floats in binary. Useful to e.g. load in a colab. bool WriteBinary(const std::string& filename) const; - // Stores the patch for the given patch number [0, 256) in `patch`. - // As sizes are hardcoded, the patch number is sufficient here. + // Stores the patch for the given patch number in `patch`. + // Patches are numbered in usual raster-order. E.g. for an image of size + // 224 x 224, there are 16 x 16 = 256 patches. // `patch` should have space for at least 14 * 14 * 3 = 588 floats. - // Requires that Normalize() has been called. + // Requires that Normalize() has been called and that the image width and + // height are multiples of 14. void GetPatch(size_t patch_num, float* patch) const; float *data() { return data_.data(); } diff --git a/paligemma/image_test.cc b/paligemma/image_test.cc index 8c7a46a8..f114fe51 100644 --- a/paligemma/image_test.cc +++ b/paligemma/image_test.cc @@ -15,16 +15,20 @@ #include "paligemma/image.h" +#include #include +#include #include "gtest/gtest.h" namespace gcpp { namespace { -float Normalize(int value) { return 2.0f * (value / 255.0f) - 1.0f; } +float Normalize(float value, float max_value = 255.0f) { + return 2.0f * (value / max_value) - 1.0f; +} -TEST(ImageTest, BasicFunctionality) { +TEST(ImageTest, LoadResize224GetPatch) { return; // Need to figure out how to get the external path for the test file. std::string path; Image image; @@ -48,7 +52,7 @@ TEST(ImageTest, BasicFunctionality) { EXPECT_EQ(image.data()[33], Normalize(164)); EXPECT_EQ(image.data()[34], Normalize(185)); EXPECT_EQ(image.data()[35], Normalize(191)); - image.Resize(); + image.Resize(224, 224); // Check first and last pixel. EXPECT_EQ(image.data()[0], Normalize(160)); EXPECT_EQ(image.data()[1], Normalize(184)); @@ -63,10 +67,66 @@ TEST(ImageTest, BasicFunctionality) { EXPECT_EQ(patch[1], Normalize(184)); EXPECT_EQ(patch[2], Normalize(188)); image.GetPatch(18, patch); - for (int i = 0; i < 10; ++i) { + // Check the first row of the patch. + for (size_t i = 0; i < 14 * 3; ++i) { EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]); } } +TEST(ImageTest, Non224) { + std::vector data(28 * 42 * 3); + for (int i = 0; i < data.size(); ++i) { + data[i] = static_cast(i); + } + float max_value = data.back(); + Image image; + image.Set(28, 42, data.data()); + EXPECT_EQ(image.width(), 28); + EXPECT_EQ(image.height(), 42); + EXPECT_EQ(image.size(), data.size()); + // Resize 28 x 42 -> 56 x 42, "double" each pixel horizontally. + image.Resize(/*new_width=*/56, /*new_height=*/42); + // Check a few pixels. + EXPECT_NEAR(image.data()[0], Normalize(0.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[1], Normalize(1.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[2], Normalize(2.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[3], Normalize(0.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[4], Normalize(1.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[5], Normalize(2.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[6], Normalize(3.0f, max_value), 1e-6); + EXPECT_NEAR(image.data()[image.size() - 9], + Normalize(data.size() - 6, max_value), 1e-6); + EXPECT_NEAR(image.data()[image.size() - 8], + Normalize(data.size() - 5, max_value), 1e-6); + EXPECT_NEAR(image.data()[image.size() - 7], + Normalize(data.size() - 4, max_value), 1e-6); + EXPECT_NEAR(image.data()[image.size() - 3], + Normalize(data.size() - 3, max_value), 1e-6); + EXPECT_NEAR(image.data()[image.size() - 2], + Normalize(data.size() - 2, max_value), 1e-6); + EXPECT_NEAR(image.data()[image.size() - 1], + Normalize(data.size() - 1, max_value), 1e-6); + // Extract two patches. + const size_t kPatchValues = 14 * 14 * 3; // = 588 + float patch[kPatchValues]; + // Patch 0 is just the "start" of the image. + image.GetPatch(0, patch); + EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6); + EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6); + EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6); + // The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its + // pixel coordinates are offset by (14, 28). + image.GetPatch(6, patch); + for (size_t n = 0; n < kPatchValues; ++n) { + size_t k = n % 3; + size_t j = ((n - k) / 3) % 14; + size_t i = (n - k - j * 3) / (14 * 3); + EXPECT_EQ(n, (i * 14 + j) * 3 + k); + i += 14; + j += 28; + EXPECT_EQ(patch[n], image.data()[(i * 56 + j) * 3 + k]); + } +} + } // namespace } // namespace gcpp diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index 64c0ee82..53c3b79b 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -28,7 +28,7 @@ // To run the test, pass the following flags: // --model paligemma-224 --tokenizer --weights // It should pass for the following models: -// paligemma-3b-mix-224 +// paligemma-3b-mix-224, paligemma2-3b-pt-448 namespace gcpp { namespace { @@ -55,7 +55,8 @@ void PaliGemmaTest::InitVit(const std::string& path) { Image image; HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(image.ReadPPM(path)); - image.Resize(); + const size_t image_size = model.GetModelConfig().image_size; + image.Resize(image_size, image_size); RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0}; model.GenerateImageTokens(runtime_config, image, image_tokens_); } @@ -102,7 +103,8 @@ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) { } TEST_F(PaliGemmaTest, General) { - static const char* kQA[][2] = { + ASSERT_NE(s_env->GetModel(), nullptr); + static const char* kQA_3B_mix_224[][2] = { {"describe this image", "A large building with two towers stands tall on the water's edge."}, {"describe image briefly", @@ -113,8 +115,29 @@ TEST_F(PaliGemmaTest, General) { {"segment water", " water"}, {"Which city is this more likely? Tokio or Zurich?", "zurich"}, }; - static const size_t kNum = sizeof(kQA) / sizeof(kQA[0]); - TestQuestions(kQA, kNum); + static const char* kQA_2_3B_pt_448[][2] = { + {"describe this image", "The Grossmünster in Zürich"}, + {"describe image briefly", "The Grossmünster"}, + {"answer en What objects are in the image?", "Building, Tower"}, + {"segment water", " water"}, + }; + const char* (*qa)[2]; + size_t num; + switch (s_env->GetModel()->Info().model) { + case Model::PALIGEMMA_224: + qa = kQA_3B_mix_224; + num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]); + break; + case Model::PALIGEMMA2_3B_448: + qa = kQA_2_3B_pt_448; + num = sizeof(kQA_2_3B_pt_448) / sizeof(kQA_2_3B_pt_448[0]); + break; + default: + FAIL() << "Unsupported model: " + << s_env->GetModel()->GetModelConfig().model_name; + break; + } + TestQuestions(qa, num); } } // namespace