Skip to content

Commit

Permalink
Add support for 448px resolution to PaliGemma and PaliGemma2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702747814
  • Loading branch information
danielkeysers authored and copybara-github committed Dec 9, 2024
1 parent 66bb435 commit 4c0781a
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 44 deletions.
12 changes: 9 additions & 3 deletions gemma/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ constexpr const char* kModelFlags[] = {
"9b-pt", "9b-it", // Gemma2 9B
"27b-pt", "27b-it", // Gemma2 27B
"paligemma-224", // PaliGemma 224
"paligemma-448", // PaliGemma 448
"paligemma2-3b-224", // PaliGemma2 3B 224
"paligemma2-3b-448", // PaliGemma2 3B 448
"paligemma2-10b-224", // PaliGemma2 10B 224
"paligemma2-10b-448", // PaliGemma2 10B 448
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Expand All @@ -51,8 +54,11 @@ constexpr Model kModelTypes[] = {
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
Model::PALIGEMMA_224, // PaliGemma 224
Model::PALIGEMMA_448, // PaliGemma 448
Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
Expand All @@ -62,9 +68,9 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, // PaliGemma 224
ModelTraining::PALIGEMMA, // PaliGemma2 3B 224
ModelTraining::PALIGEMMA, // PaliGemma2 10B 224
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448
};

constexpr size_t kNumModelFlags =
Expand Down
34 changes: 32 additions & 2 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ static ModelConfig ConfigGriffin2B() {
}

// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config) {
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = 224;
config.image_size = image_size;
config.patch_width = 14;
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = false;
Expand Down Expand Up @@ -236,6 +236,14 @@ static ModelConfig ConfigPaliGemma_224() {
return config;
}

// Returns the config for the PaliGemma 448 model: the ConfigGemma2B() config,
// relabeled as PALIGEMMA_448, with the ViT settings added for an image_size
// of 448.
static ModelConfig ConfigPaliGemma_448() {
  ModelConfig pali_config = ConfigGemma2B();
  pali_config.model = Model::PALIGEMMA_448;
  pali_config.model_name = "PaliGemma_448";
  AddVitConfig(pali_config, /*image_size=*/448);
  return pali_config;
}


ModelConfig VitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim;
Expand All @@ -254,6 +262,14 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
return config;
}

// Returns the config for the PaliGemma2 3B 448 model: the ConfigGemma2_2B()
// config, relabeled as PALIGEMMA2_3B_448, with the ViT settings added for an
// image_size of 448.
static ModelConfig ConfigPaliGemma2_3B_448() {
  ModelConfig pali_config = ConfigGemma2_2B();
  pali_config.model = Model::PALIGEMMA2_3B_448;
  pali_config.model_name = "PaliGemma2_3B_448";
  AddVitConfig(pali_config, /*image_size=*/448);
  return pali_config;
}


static ModelConfig ConfigPaliGemma2_10B_224() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_224";
Expand All @@ -262,6 +278,14 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
return config;
}

// Returns the config for the PaliGemma2 10B 448 model: the ConfigGemma2_9B()
// config, relabeled as PALIGEMMA2_10B_448, with the ViT settings added for an
// image_size of 448.
static ModelConfig ConfigPaliGemma2_10B_448() {
  ModelConfig pali_config = ConfigGemma2_9B();
  pali_config.model = Model::PALIGEMMA2_10B_448;
  pali_config.model_name = "PaliGemma2_10B_448";
  AddVitConfig(pali_config, /*image_size=*/448);
  return pali_config;
}


ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:
Expand All @@ -280,10 +304,16 @@ ModelConfig ConfigFromModel(Model model) {
return ConfigGemmaTiny();
case Model::PALIGEMMA_224:
return ConfigPaliGemma_224();
case Model::PALIGEMMA_448:
return ConfigPaliGemma_448();
case Model::PALIGEMMA2_3B_224:
return ConfigPaliGemma2_3B_224();
case Model::PALIGEMMA2_3B_448:
return ConfigPaliGemma2_3B_448();
case Model::PALIGEMMA2_10B_224:
return ConfigPaliGemma2_10B_224();
case Model::PALIGEMMA2_10B_448:
return ConfigPaliGemma2_10B_448();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
Expand Down
7 changes: 6 additions & 1 deletion gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,20 @@ enum class Model {
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
PALIGEMMA_448,
PALIGEMMA2_3B_224,
PALIGEMMA2_3B_448,
PALIGEMMA2_10B_224,
PALIGEMMA2_10B_448,
};

// Allows the Model enum to be iterated over.
static constexpr Model kAllModels[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
Model::PALIGEMMA_224, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_10B_224,
Model::PALIGEMMA_224, Model::PALIGEMMA_448,
Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_3B_448,
Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448,
};

struct LayerConfig {
Expand Down
3 changes: 2 additions & 1 deletion gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
image.Resize();
const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
double image_tokens_start = hwy::platform::Now();
Expand Down
2 changes: 1 addition & 1 deletion gemma/tensor_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
.name = "img_pos_emb",
.source_names = {"img/pos_embedding"},
.axes = {0, 1},
.shape = {/*1,*/ 256, config.vit_model_dim},
.shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim},
.min_size = Type::kF32,
},
};
Expand Down
3 changes: 2 additions & 1 deletion gemma/weights.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ struct ModelWeightsPtrs {
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
config.patch_width * config.patch_width * 3),
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.model_dim,
config.vit_model_dim),
Expand Down
33 changes: 15 additions & 18 deletions paligemma/image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@

namespace gcpp {
namespace {
// Hardcoded for PaliGemma-224 ViT input.
// Hardcoded for PaliGemma ViT input.
constexpr size_t kPatchSize = 14;
constexpr size_t kImageSize = 224;
constexpr size_t kNumPatches = kImageSize / kPatchSize; // 16

// Returns the linearly scaled index in [0, to_size) closest to the
// value in [0, from_size).
Expand Down Expand Up @@ -174,9 +172,7 @@ void Image::Set(int width, int height, const float* data) {
}
}

void Image::Resize() {
int new_width = 224;
int new_height = kImageSize;
void Image::Resize(int new_width, int new_height) {
std::vector<float> new_data(new_width * new_height * 3);
// TODO: go to bilinear interpolation, or antialias.
// E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from
Expand Down Expand Up @@ -211,18 +207,19 @@ bool Image::WriteBinary(const std::string& filename) const {
return true;
}

// Image.data() is kImageSize x kImageSize x 3, H x W x C.
// We want the N-th patch (of 256) of size kPatchSize x kPatchSize x 3.
// Patches are numbered in usual "pixel-order".
// Image.data() is H x W x 3.
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
void Image::GetPatch(size_t patch_num, float* patch) const {
PROFILER_FUNC;
constexpr size_t kDataSize = kImageSize * kImageSize * 3;
const size_t kDataSize = width_ * height_ * 3;
HWY_ASSERT(size() == kDataSize);
constexpr size_t kPatchDataSize = kPatchSize * kPatchSize * 3;
size_t i_offs = patch_num / kNumPatches;
size_t j_offs = patch_num % kNumPatches;
HWY_ASSERT(0 <= i_offs && i_offs < kNumPatches);
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatches);
HWY_ASSERT(width_ % kPatchSize == 0);
HWY_ASSERT(height_ % kPatchSize == 0);
const size_t kNumPatchesPerRow = width_ / kPatchSize;
size_t i_offs = patch_num / kNumPatchesPerRow;
size_t j_offs = patch_num % kNumPatchesPerRow;
HWY_ASSERT(0 <= i_offs && i_offs < height_ / kPatchSize);
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatchesPerRow);
i_offs *= kPatchSize;
j_offs *= kPatchSize;
// This can be made faster, but let's first see whether it matters.
Expand All @@ -231,10 +228,10 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
for (size_t j = 0; j < kPatchSize; ++j) {
for (size_t k = 0; k < 3; ++k) {
const size_t patch_index = (i * kPatchSize + j) * 3 + k;
HWY_ASSERT(patch_index < kPatchDataSize);
HWY_DASSERT(patch_index < kPatchSize * kPatchSize * 3);
const size_t image_index =
((i + i_offs) * kImageSize + (j + j_offs)) * 3 + k;
HWY_ASSERT(image_index < kDataSize);
((i + i_offs) * width_ + (j + j_offs)) * 3 + k;
HWY_DASSERT(image_index < kDataSize);
patch[patch_index] = image_data[image_index];
}
}
Expand Down
17 changes: 9 additions & 8 deletions paligemma/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

namespace gcpp {

// Very basic image loading and processing for PaliGemma-224. Does not try to be
// generic at the moment, e.g. the size to normalize to is hardcoded.
// Very basic image loading and processing for PaliGemma.
class Image {
public:
Image() = default;
Expand All @@ -38,15 +37,17 @@ class Image {
// Sets the image content to the given data. The data is copied and normalized
// to [-1, 1]. The data is expected to be of size width * height * 3.
void Set(int width, int height, const float* data);
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
// be better).
void Resize();
// Resizes to width x height (nearest-neighbor for now, bilinear or antialias
// would be better).
void Resize(int width, int height);
// Writes the file as plain floats in binary. Useful to e.g. load in a colab.
bool WriteBinary(const std::string& filename) const;
// Stores the patch for the given patch number [0, 256) in `patch`.
// As sizes are hardcoded, the patch number is sufficient here.
// Stores the patch for the given patch number in `patch`.
// Patches are numbered in usual raster-order. E.g. for an image of size
// 224 x 224, there are 16 x 16 = 256 patches.
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
// Requires that Normalize() has been called.
// Requires that Normalize() has been called and that the image width and
// height are multiples of 14.
void GetPatch(size_t patch_num, float* patch) const;

float *data() { return data_.data(); }
Expand Down
68 changes: 64 additions & 4 deletions paligemma/image_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@

#include "paligemma/image.h"

#include <cstddef>
#include <string>
#include <vector>

#include "gtest/gtest.h"

namespace gcpp {
namespace {

float Normalize(int value) { return 2.0f * (value / 255.0f) - 1.0f; }
// Linearly maps `value` so that 0 becomes -1 and `max_value` becomes +1.
// `max_value` defaults to 255.
float Normalize(float value, float max_value = 255.0f) {
  const float unit_fraction = value / max_value;
  return 2.0f * unit_fraction - 1.0f;
}

TEST(ImageTest, BasicFunctionality) {
TEST(ImageTest, LoadResize224GetPatch) {
return; // Need to figure out how to get the external path for the test file.
std::string path;
Image image;
Expand All @@ -48,7 +52,7 @@ TEST(ImageTest, BasicFunctionality) {
EXPECT_EQ(image.data()[33], Normalize(164));
EXPECT_EQ(image.data()[34], Normalize(185));
EXPECT_EQ(image.data()[35], Normalize(191));
image.Resize();
image.Resize(224, 224);
// Check first and last pixel.
EXPECT_EQ(image.data()[0], Normalize(160));
EXPECT_EQ(image.data()[1], Normalize(184));
Expand All @@ -63,10 +67,66 @@ TEST(ImageTest, BasicFunctionality) {
EXPECT_EQ(patch[1], Normalize(184));
EXPECT_EQ(patch[2], Normalize(188));
image.GetPatch(18, patch);
for (int i = 0; i < 10; ++i) {
// Check the first row of the patch.
for (size_t i = 0; i < 14 * 3; ++i) {
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
}
}

// Exercises Set(), Resize() and GetPatch() on an image whose dimensions are
// not 224 x 224. The input is a synthetic 28 x 42 image whose float values are
// simply 0, 1, 2, ... so that every expected output value can be derived from
// its index.
TEST(ImageTest, Non224) {
  std::vector<float> data(28 * 42 * 3);
  // size_t index: data.size() is unsigned, so an int index would mix signed
  // and unsigned in the comparison.
  for (size_t i = 0; i < data.size(); ++i) {
    data[i] = static_cast<float>(i);
  }
  // The expectations below normalize by the largest input value rather than
  // by 255.
  float max_value = data.back();
  Image image;
  image.Set(28, 42, data.data());
  EXPECT_EQ(image.width(), 28);
  EXPECT_EQ(image.height(), 42);
  EXPECT_EQ(image.size(), data.size());
  // Resize 28 x 42 -> 56 x 42, "double" each pixel horizontally.
  image.Resize(/*new_width=*/56, /*new_height=*/42);
  // Check a few pixels.
  // The first two output pixels both hold the first input pixel (0, 1, 2);
  // the third output pixel starts the second input pixel (3, ...).
  EXPECT_NEAR(image.data()[0], Normalize(0.0f, max_value), 1e-6);
  EXPECT_NEAR(image.data()[1], Normalize(1.0f, max_value), 1e-6);
  EXPECT_NEAR(image.data()[2], Normalize(2.0f, max_value), 1e-6);
  EXPECT_NEAR(image.data()[3], Normalize(0.0f, max_value), 1e-6);
  EXPECT_NEAR(image.data()[4], Normalize(1.0f, max_value), 1e-6);
  EXPECT_NEAR(image.data()[5], Normalize(2.0f, max_value), 1e-6);
  EXPECT_NEAR(image.data()[6], Normalize(3.0f, max_value), 1e-6);
  // Third-to-last output pixel: expected to hold the second-to-last input
  // pixel.
  EXPECT_NEAR(image.data()[image.size() - 9],
              Normalize(data.size() - 6, max_value), 1e-6);
  EXPECT_NEAR(image.data()[image.size() - 8],
              Normalize(data.size() - 5, max_value), 1e-6);
  EXPECT_NEAR(image.data()[image.size() - 7],
              Normalize(data.size() - 4, max_value), 1e-6);
  // Last output pixel: expected to hold the last input pixel.
  EXPECT_NEAR(image.data()[image.size() - 3],
              Normalize(data.size() - 3, max_value), 1e-6);
  EXPECT_NEAR(image.data()[image.size() - 2],
              Normalize(data.size() - 2, max_value), 1e-6);
  EXPECT_NEAR(image.data()[image.size() - 1],
              Normalize(data.size() - 1, max_value), 1e-6);
  // Extract two patches.
  const size_t kPatchValues = 14 * 14 * 3;  // = 588
  float patch[kPatchValues];
  // Patch 0 is just the "start" of the image.
  image.GetPatch(0, patch);
  EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
  EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
  EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
  // The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
  // pixel coordinates are offset by (14, 28).
  image.GetPatch(6, patch);
  for (size_t n = 0; n < kPatchValues; ++n) {
    // Decompose the flat patch index n into (row i, column j, channel k).
    size_t k = n % 3;
    size_t j = ((n - k) / 3) % 14;
    size_t i = (n - k - j * 3) / (14 * 3);
    // Sanity check that the decomposition round-trips.
    EXPECT_EQ(n, (i * 14 + j) * 3 + k);
    // Shift to image coordinates; 56 is the resized image width.
    i += 14;
    j += 28;
    EXPECT_EQ(patch[n], image.data()[(i * 56 + j) * 3 + k]);
  }
}

} // namespace
} // namespace gcpp
Loading

0 comments on commit 4c0781a

Please sign in to comment.