From 58185c21aaa20fe5fc1c9bd5f2d5da9df6a6e38b Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Wed, 27 Nov 2024 11:31:59 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 700761457 --- gemma/common.cc | 6 ++++++ gemma/configs.cc | 20 ++++++++++++++++++++ gemma/configs.h | 4 +++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/gemma/common.cc b/gemma/common.cc index 447deb6..6678c0b 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -39,6 +39,8 @@ constexpr const char* kModelFlags[] = { "9b-pt", "9b-it", // Gemma2 9B "27b-pt", "27b-it", // Gemma2 27B "paligemma-224", // PaliGemma 224 + "paligemma2-3b-224", // PaliGemma2 3B 224 + "paligemma2-10b-224", // PaliGemma2 10B 224 }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B @@ -49,6 +51,8 @@ constexpr Model kModelTypes[] = { Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B Model::PALIGEMMA_224, // PaliGemma 224 + Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224 + Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224 }; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B @@ -59,6 +63,8 @@ constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B ModelTraining::PALIGEMMA, // PaliGemma 224 + ModelTraining::PALIGEMMA, // PaliGemma2 3B 224 + ModelTraining::PALIGEMMA, // PaliGemma2 10B 224 }; constexpr size_t kNumModelFlags = diff --git a/gemma/configs.cc b/gemma/configs.cc index 7a792cf..b219e76 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -246,6 +246,22 @@ ModelConfig VitConfig(const ModelConfig& config) { return vit_config; } +static ModelConfig ConfigPaliGemma2_3B_224() { + ModelConfig config = ConfigGemma2_2B(); + config.model_name = "PaliGemma2_3B_224"; + config.model = Model::PALIGEMMA2_3B_224; + AddVitConfig(config); + return config; +} 
+ +static ModelConfig ConfigPaliGemma2_10B_224() { + ModelConfig config = ConfigGemma2_9B(); + config.model_name = "PaliGemma2_10B_224"; + config.model = Model::PALIGEMMA2_10B_224; + AddVitConfig(config); + return config; +} + ModelConfig ConfigFromModel(Model model) { switch (model) { case Model::GEMMA_2B: @@ -264,6 +280,10 @@ ModelConfig ConfigFromModel(Model model) { return ConfigGemmaTiny(); case Model::PALIGEMMA_224: return ConfigPaliGemma_224(); + case Model::PALIGEMMA2_3B_224: + return ConfigPaliGemma2_3B_224(); + case Model::PALIGEMMA2_10B_224: + return ConfigPaliGemma2_10B_224(); default: HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); } diff --git a/gemma/configs.h b/gemma/configs.h index 6bbbc45..8a63665 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -114,13 +114,15 @@ enum class Model { GEMMA_TINY, GEMMA2_2B, PALIGEMMA_224, + PALIGEMMA2_3B_224, + PALIGEMMA2_10B_224, }; // Allows the Model enum to be iterated over. static constexpr Model kAllModels[] = { Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B, Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B, - Model::PALIGEMMA_224, + Model::PALIGEMMA_224, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_10B_224, }; struct LayerConfig {