Skip to content

Commit

Permalink
Rename ModelTraining to PromptWrapping which is a more accurate name.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705594320
  • Loading branch information
danielkeysers authored and copybara-github committed Dec 13, 2024
1 parent 6254f2e commit e180599
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 49 deletions.
2 changes: 1 addition & 1 deletion backprop/optimize_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ TEST(OptimizeTest, GradientDescent) {

const ModelInfo info = {
.model = Model::GEMMA_TINY,
.training = ModelTraining::GEMMA_IT,
.wrapping = PromptWrapping::GEMMA_IT,
.weight = Type::kF32,
};
ModelConfig config = ConfigFromModel(info.model);
Expand Down
12 changes: 6 additions & 6 deletions compression/compress_weights.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
#include <vector>

#include "compression/compress.h"
#include "compression/shared.h" // ModelTraining
#include "compression/io.h" // Path
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/allocator.h"
Expand Down Expand Up @@ -74,8 +74,8 @@ struct Args : public ArgsBase<Args> {

// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
model_training_)) {
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
prompt_wrapping_)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
Expand Down Expand Up @@ -127,12 +127,12 @@ struct Args : public ArgsBase<Args> {

// Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; }
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
gcpp::Type WeightType() const { return weight_type_; }

private:
Model model_type_;
ModelTraining model_training_;
PromptWrapping prompt_wrapping_;
Type weight_type_;
};

Expand Down Expand Up @@ -212,7 +212,7 @@ namespace gcpp {

void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Model model_type = args.ModelType();
Expand Down
2 changes: 1 addition & 1 deletion compression/shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ constexpr bool IsNuqStream() {
}

// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };

// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
Expand Down
2 changes: 1 addition & 1 deletion evals/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ int main(int argc, char** argv) {
const std::string golden_path =
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.GetModel()->Info().model,
env.GetModel()->Info().training) +
env.GetModel()->Info().wrapping) +
".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {
Expand Down
42 changes: 21 additions & 21 deletions gemma/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,25 @@ constexpr Model kModelTypes[] = {
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448
constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // RecurrentGemma
PromptWrapping::GEMMA_IT, // Gemma Tiny
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
};

constexpr size_t kNumModelFlags = std::size(kModelFlags);
static_assert(kNumModelFlags == std::size(kModelTypes));
static_assert(kNumModelFlags == std::size(kModelTraining));
static_assert(kNumModelFlags == std::size(kPromptWrapping));

const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping) {
static std::string kErrorMessageBuffer =
"Invalid or missing model flag, need to specify one of ";
for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
Expand All @@ -93,21 +93,21 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
for (size_t i = 0; i < kNumModelFlags; ++i) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
wrapping = kPromptWrapping[i];
HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
return nullptr;
}
}
return kErrorMessageBuffer.c_str();
}

const char* ModelString(Model model, ModelTraining training) {
const char* ModelString(Model model, PromptWrapping wrapping) {
for (size_t i = 0; i < kNumModelFlags; i++) {
if (kModelTypes[i] == model && kModelTraining[i] == training)
if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
return kModelFlags[i];
}
HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
static_cast<int>(training));
HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
static_cast<int>(wrapping));
}

const char* StringFromType(Type type) {
Expand Down Expand Up @@ -139,7 +139,7 @@ const char* ParseType(const std::string& type_string, Type& type) {
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {

// Instruction-tuned models are trained to expect control tokens.
if (info.training == ModelTraining::GEMMA_IT) {
if (info.wrapping == PromptWrapping::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"
Expand Down
12 changes: 6 additions & 6 deletions gemma/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include <string>

#include "compression/shared.h" // ModelTraining
#include "compression/shared.h" // PromptWrapping
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo

Expand All @@ -29,18 +29,18 @@ namespace gcpp {
// Struct to bundle model information.
struct ModelInfo {
Model model;
ModelTraining training;
PromptWrapping wrapping;
Type weight;
};

// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping);
const char* ParseType(const std::string& type_string, Type& type);

// Inverse of ParseModelTypeAndTraining.
const char* ModelString(Model model, ModelTraining training);
// Inverse of ParseModelTypeAndWrapping.
const char* ModelString(Model model, PromptWrapping wrapping);
const char* StringFromType(Type type);

// Wraps the given prompt using the expected control tokens for IT models.
Expand Down
6 changes: 3 additions & 3 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,13 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const {
bool result = true;
// We don't care about model_name, model, training, or weight being different,
// We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are.
if (debug) {
WARN_IF_NOT_EQUAL(model_name, other.model_name);
WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
WARN_IF_NOT_EQUAL(static_cast<int>(training),
static_cast<int>(other.training));
WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
static_cast<int>(other.wrapping));
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
}
TEST_EQUAL(model_dim, other.model_dim);
Expand Down
2 changes: 1 addition & 1 deletion gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ struct ModelConfig {

std::string model_name;
Model model;
ModelTraining training;
PromptWrapping wrapping;
Type weight;
size_t num_layers = 0;
size_t model_dim = 0;
Expand Down
2 changes: 1 addition & 1 deletion gemma/gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class Gemma {
ModelInfo info_;
};

// Adds BOS token and possibly 'turn' annotations, which depend on `training`
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
Expand Down
6 changes: 3 additions & 3 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <vector>

// Placeholder for internal header, do not modify.
#include "compression/shared.h" // ModelTraining
#include "compression/shared.h" // PromptWrapping
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
Expand Down Expand Up @@ -96,7 +96,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size);
Expand Down Expand Up @@ -207,7 +207,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
std::cout << "\n\n";

// Prepare for the next turn.
if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
} else {
Expand Down
4 changes: 2 additions & 2 deletions gemma/tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <vector>

#include "compression/io.h" // Path
#include "compression/shared.h" // ModelTraining
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Wrap
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/profiler.h"
Expand Down Expand Up @@ -110,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
}

// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.training == ModelTraining::PALIGEMMA) {
if (info.wrapping == PromptWrapping::PALIGEMMA) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
Expand Down
2 changes: 1 addition & 1 deletion paligemma/paligemma_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim));
Image image;
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size);
Expand Down
4 changes: 2 additions & 2 deletions util/app.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {

// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
info_.training)) {
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping)) {
return err;
}
if (const char* err = ParseType(weight_type_str, info_.weight)) {
Expand Down

0 comments on commit e180599

Please sign in to comment.