diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 6d3522f..cc7cef7 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -49,7 +49,7 @@ TEST(OptimizeTest, GradientDescent) { const ModelInfo info = { .model = Model::GEMMA_TINY, - .training = ModelTraining::GEMMA_IT, + .wrapping = PromptWrapping::GEMMA_IT, .weight = Type::kF32, }; ModelConfig config = ConfigFromModel(info.model); diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index ff607e7..66e95ec 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -40,8 +40,8 @@ #include #include "compression/compress.h" -#include "compression/shared.h" // ModelTraining #include "compression/io.h" // Path +#include "compression/shared.h" // PromptWrapping #include "gemma/common.h" // Model #include "gemma/weights.h" #include "util/allocator.h" @@ -74,8 +74,8 @@ struct Args : public ArgsBase { // Returns error string or nullptr if OK. const char* Validate() { - if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_, - model_training_)) { + if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_, + prompt_wrapping_)) { return err; } if (const char* err = ParseType(weight_type_str, weight_type_)) { @@ -127,12 +127,12 @@ struct Args : public ArgsBase { // Uninitialized before Validate, must call after that. 
gcpp::Model ModelType() const { return model_type_; } - gcpp::ModelTraining ModelTrainingType() const { return model_training_; } + gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; } gcpp::Type WeightType() const { return weight_type_; } private: Model model_type_; - ModelTraining model_training_; + PromptWrapping prompt_wrapping_; Type weight_type_; }; @@ -212,7 +212,7 @@ namespace gcpp { void Run(Args& args) { hwy::ThreadPool pool(args.num_threads); - if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) { + if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) { HWY_ABORT("PaliGemma is not supported in compress_weights."); } const Model model_type = args.ModelType(); diff --git a/compression/shared.h b/compression/shared.h index 40c8f1c..8a70873 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -196,7 +196,7 @@ constexpr bool IsNuqStream() { } // Instruction-tuned models require extra 'turn structure' tokens in prompts. -enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA }; +enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA }; // Tensor types for loading weights. 
Note that not all types are supported as // weights for a model, but can be used for other purposes, such as types for diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 1ea4f65..8682189 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -187,7 +187,7 @@ int main(int argc, char** argv) { const std::string golden_path = benchmark_args.goldens.path + "/" + gcpp::ModelString(env.GetModel()->Info().model, - env.GetModel()->Info().training) + + env.GetModel()->Info().wrapping) + ".txt"; return BenchmarkGoldens(env, golden_path); } else if (!benchmark_args.summarize_text.Empty()) { diff --git a/gemma/common.cc b/gemma/common.cc index f1e6ff2..dc37e3e 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -60,25 +60,25 @@ constexpr Model kModelTypes[] = { Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224 Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448 }; -constexpr ModelTraining kModelTraining[] = { - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma - ModelTraining::GEMMA_IT, // Gemma Tiny - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B - ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448 - ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448 - ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448 +constexpr PromptWrapping kPromptWrapping[] = { + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // RecurrentGemma + PromptWrapping::GEMMA_IT, // Gemma Tiny + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 2B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, 
// Gemma2 9B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B + PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448 + PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448 + PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448 }; constexpr size_t kNumModelFlags = std::size(kModelFlags); static_assert(kNumModelFlags == std::size(kModelTypes)); -static_assert(kNumModelFlags == std::size(kModelTraining)); +static_assert(kNumModelFlags == std::size(kPromptWrapping)); -const char* ParseModelTypeAndTraining(const std::string& model_flag, - Model& model, ModelTraining& training) { +const char* ParseModelTypeAndWrapping(const std::string& model_flag, + Model& model, PromptWrapping& wrapping) { static std::string kErrorMessageBuffer = "Invalid or missing model flag, need to specify one of "; for (size_t i = 0; i + 1 < kNumModelFlags; ++i) { @@ -93,21 +93,21 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag, for (size_t i = 0; i < kNumModelFlags; ++i) { if (kModelFlags[i] == model_type_lc) { model = kModelTypes[i]; - training = kModelTraining[i]; - HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc); + wrapping = kPromptWrapping[i]; + HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc); return nullptr; } } return kErrorMessageBuffer.c_str(); } -const char* ModelString(Model model, ModelTraining training) { +const char* ModelString(Model model, PromptWrapping wrapping) { for (size_t i = 0; i < kNumModelFlags; i++) { - if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping) return kModelFlags[i]; } - HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model), - static_cast<int>(training)); + HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model), - if (kModelTypes[i] == model && kModelTraining[i] == training) + static_cast<int>(wrapping)); } const char* StringFromType(Type type) { @@ -139,7 +139,7 @@ const char* ParseType(const 
std::string& type_string, Type& type) { void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { // Instruction-tuned models are trained to expect control tokens. - if (info.training == ModelTraining::GEMMA_IT) { + if (info.wrapping == PromptWrapping::GEMMA_IT) { // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation. const std::string start = (pos == 0) ? "<start_of_turn>user\n" diff --git a/gemma/common.h b/gemma/common.h index 6b5539a..984b0ba 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -20,7 +20,7 @@ #include -#include "compression/shared.h" // ModelTraining +#include "compression/shared.h" // PromptWrapping #include "gemma/configs.h" // IWYU pragma: export #include "hwy/base.h" // ConvertScalarTo @@ -29,18 +29,18 @@ namespace gcpp { // Struct to bundle model information. struct ModelInfo { Model model; - ModelTraining training; + PromptWrapping wrapping; Type weight; }; // Returns error string or nullptr if OK. // Thread-hostile. -const char* ParseModelTypeAndTraining(const std::string& model_flag, - Model& model, ModelTraining& training); +const char* ParseModelTypeAndWrapping(const std::string& model_flag, + Model& model, PromptWrapping& wrapping); const char* ParseType(const std::string& type_string, Type& type); -// Inverse of ParseModelTypeAndTraining. -const char* ModelString(Model model, ModelTraining training); +// Inverse of ParseModelTypeAndWrapping. +const char* ModelString(Model model, PromptWrapping wrapping); const char* StringFromType(Type type); // Wraps the given prompt using the expected control tokens for IT models. 
diff --git a/gemma/configs.cc b/gemma/configs.cc index 37eed6a..8a714c1 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -366,13 +366,13 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial, bool ModelConfig::TestEqual(const ModelConfig& other, bool partial, bool debug) const { bool result = true; - // We don't care about model_name, model, training, or weight being different, + // We don't care about model_name, model, wrapping, or weight being different, // but will output in debug mode if they are. if (debug) { WARN_IF_NOT_EQUAL(model_name, other.model_name); WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model)); - WARN_IF_NOT_EQUAL(static_cast<int>(training), - static_cast<int>(other.training)); + WARN_IF_NOT_EQUAL(static_cast<int>(wrapping), + static_cast<int>(other.wrapping)); WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight)); } TEST_EQUAL(model_dim, other.model_dim); diff --git a/gemma/configs.h b/gemma/configs.h index 52bf8b9..9c33b17 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -200,7 +200,7 @@ struct ModelConfig { std::string model_name; Model model; - ModelTraining training; + PromptWrapping wrapping; Type weight; size_t num_layers = 0; size_t model_dim = 0; diff --git a/gemma/gemma.h b/gemma/gemma.h index 5b84053..1ad8717 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -244,7 +244,7 @@ class Gemma { ModelInfo info_; }; -// Adds BOS token and possibly 'turn' annotations, which depend on `training` +// Adds BOS token and possibly 'turn' annotations, which depend on `info` // and `pos`, the number of tokens decoded so far; returns the corresponding // tokens. Asserts that tokenization is successful. std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, diff --git a/gemma/run.cc b/gemma/run.cc index 659be81..b7082e4 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -22,7 +22,7 @@ #include // Placeholder for internal header, do not modify. 
-#include "compression/shared.h" // ModelTraining +#include "compression/shared.h" // PromptWrapping #include "evals/benchmark_helper.h" #include "gemma/common.h" #include "gemma/gemma.h" // Gemma @@ -96,7 +96,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, if (have_image) { image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim)); - HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); + HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(image.ReadPPM(args.image_file.path)); const size_t image_size = model.GetModelConfig().image_size; image.Resize(image_size, image_size); @@ -207,7 +207,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, std::cout << "\n\n"; // Prepare for the next turn. - if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) { + if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) { abs_pos = 0; // Start a new turn at position 0. InitGenerator(args, gen); } else { diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc index d6247ca..ffd71ae 100644 --- a/gemma/tokenizer.cc +++ b/gemma/tokenizer.cc @@ -22,7 +22,7 @@ #include #include "compression/io.h" // Path -#include "compression/shared.h" // ModelTraining +#include "compression/shared.h" // PromptWrapping #include "gemma/common.h" // Wrap #include "hwy/base.h" // HWY_ASSERT #include "hwy/profiler.h" @@ -110,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, } // PaliGemma separator. The SEP token "\n" is always tokenized separately. 
- if (info.training == ModelTraining::PALIGEMMA) { + if (info.wrapping == PromptWrapping::PALIGEMMA) { std::vector<int> sep_tokens; HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens)); tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end()); diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index 53c3b79..f5d5304 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -53,7 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) { image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, model.GetModelConfig().model_dim)); Image image; - HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); + HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA); HWY_ASSERT(image.ReadPPM(path)); const size_t image_size = model.GetModelConfig().image_size; image.Resize(image_size, image_size); diff --git a/util/app.h b/util/app.h index 8736ecd..aa17567 100644 --- a/util/app.h +++ b/util/app.h @@ -136,8 +136,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> { // Returns error string or nullptr if OK. const char* Validate() { - if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model, - info_.training)) { + if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model, + info_.wrapping)) { return err; } if (const char* err = ParseType(weight_type_str, info_.weight)) {