Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 448px resolution to PaliGemma and PaliGemma2. #459

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions gemma/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ constexpr const char* kModelFlags[] = {
"9b-pt", "9b-it", // Gemma2 9B
"27b-pt", "27b-it", // Gemma2 27B
"paligemma-224", // PaliGemma 224
"paligemma-448", // PaliGemma 448
"paligemma2-3b-224", // PaliGemma2 3B 224
"paligemma2-3b-448", // PaliGemma2 3B 448
"paligemma2-10b-224", // PaliGemma2 10B 224
"paligemma2-10b-448", // PaliGemma2 10B 448
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Expand All @@ -51,8 +54,11 @@ constexpr Model kModelTypes[] = {
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
Model::PALIGEMMA_224, // PaliGemma 224
Model::PALIGEMMA_448, // PaliGemma 448
Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
Expand All @@ -62,9 +68,9 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, // PaliGemma 224
ModelTraining::PALIGEMMA, // PaliGemma2 3B 224
ModelTraining::PALIGEMMA, // PaliGemma2 10B 224
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448
};

constexpr size_t kNumModelFlags =
Expand Down
34 changes: 32 additions & 2 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ static ModelConfig ConfigGriffin2B() {
}

// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config) {
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = 224;
config.image_size = image_size;
config.patch_width = 14;
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = false;
Expand Down Expand Up @@ -236,6 +236,14 @@ static ModelConfig ConfigPaliGemma_224() {
return config;
}

static ModelConfig ConfigPaliGemma_448() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_448";
config.model = Model::PALIGEMMA_448;
AddVitConfig(config, /*image_size=*/448);
return config;
}

ModelConfig VitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim;
Expand All @@ -254,6 +262,14 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
return config;
}

static ModelConfig ConfigPaliGemma2_3B_448() {
ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_448";
config.model = Model::PALIGEMMA2_3B_448;
AddVitConfig(config, /*image_size=*/448);
return config;
}

static ModelConfig ConfigPaliGemma2_10B_224() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_224";
Expand All @@ -262,6 +278,14 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
return config;
}

static ModelConfig ConfigPaliGemma2_10B_448() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_448";
config.model = Model::PALIGEMMA2_10B_448;
AddVitConfig(config, /*image_size=*/448);
return config;
}

ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:
Expand All @@ -280,10 +304,16 @@ ModelConfig ConfigFromModel(Model model) {
return ConfigGemmaTiny();
case Model::PALIGEMMA_224:
return ConfigPaliGemma_224();
case Model::PALIGEMMA_448:
return ConfigPaliGemma_448();
case Model::PALIGEMMA2_3B_224:
return ConfigPaliGemma2_3B_224();
case Model::PALIGEMMA2_3B_448:
return ConfigPaliGemma2_3B_448();
case Model::PALIGEMMA2_10B_224:
return ConfigPaliGemma2_10B_224();
case Model::PALIGEMMA2_10B_448:
return ConfigPaliGemma2_10B_448();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
Expand Down
7 changes: 6 additions & 1 deletion gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,20 @@ enum class Model {
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
PALIGEMMA_448,
PALIGEMMA2_3B_224,
PALIGEMMA2_3B_448,
PALIGEMMA2_10B_224,
PALIGEMMA2_10B_448,
};

// Allows the Model enum to be iterated over.
static constexpr Model kAllModels[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
Model::PALIGEMMA_224, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_10B_224,
Model::PALIGEMMA_224, Model::PALIGEMMA_448,
Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_3B_448,
Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448,
};

struct LayerConfig {
Expand Down
3 changes: 2 additions & 1 deletion gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
image.Resize();
const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
double image_tokens_start = hwy::platform::Now();
Expand Down
2 changes: 1 addition & 1 deletion gemma/tensor_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
.name = "img_pos_emb",
.source_names = {"img/pos_embedding"},
.axes = {0, 1},
.shape = {/*1,*/ 256, config.vit_model_dim},
.shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim},
.min_size = Type::kF32,
},
};
Expand Down
3 changes: 2 additions & 1 deletion gemma/weights.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ struct ModelWeightsPtrs {
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
config.patch_width * config.patch_width * 3),
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.model_dim,
config.vit_model_dim),
Expand Down
33 changes: 15 additions & 18 deletions paligemma/image.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@

namespace gcpp {
namespace {
// Hardcoded for PaliGemma-224 ViT input.
// Hardcoded for PaliGemma ViT input.
constexpr size_t kPatchSize = 14;
constexpr size_t kImageSize = 224;
constexpr size_t kNumPatches = kImageSize / kPatchSize; // 16

// Returns the linearly scaled index in [0, to_size) closest to the
// value in [0, from_size).
Expand Down Expand Up @@ -174,9 +172,7 @@ void Image::Set(int width, int height, const float* data) {
}
}

void Image::Resize() {
int new_width = 224;
int new_height = kImageSize;
void Image::Resize(int new_width, int new_height) {
std::vector<float> new_data(new_width * new_height * 3);
// TODO: go to bilinear interpolation, or antialias.
// E.g. consider WeightsSymmetric3Lowpass and SlowSymmetric3 from
Expand Down Expand Up @@ -211,18 +207,19 @@ bool Image::WriteBinary(const std::string& filename) const {
return true;
}

// Image.data() is kImageSize x kImageSize x 3, H x W x C.
// We want the N-th patch (of 256) of size kPatchSize x kPatchSize x 3.
// Patches are numbered in usual "pixel-order".
// Image.data() is H x W x 3.
// We want the N-th patch of size kPatchSize x kPatchSize x 3.
void Image::GetPatch(size_t patch_num, float* patch) const {
PROFILER_FUNC;
constexpr size_t kDataSize = kImageSize * kImageSize * 3;
const size_t kDataSize = width_ * height_ * 3;
HWY_ASSERT(size() == kDataSize);
constexpr size_t kPatchDataSize = kPatchSize * kPatchSize * 3;
size_t i_offs = patch_num / kNumPatches;
size_t j_offs = patch_num % kNumPatches;
HWY_ASSERT(0 <= i_offs && i_offs < kNumPatches);
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatches);
HWY_ASSERT(width_ % kPatchSize == 0);
HWY_ASSERT(height_ % kPatchSize == 0);
const size_t kNumPatchesPerRow = width_ / kPatchSize;
size_t i_offs = patch_num / kNumPatchesPerRow;
size_t j_offs = patch_num % kNumPatchesPerRow;
HWY_ASSERT(0 <= i_offs && i_offs < height_ / kPatchSize);
HWY_ASSERT(0 <= j_offs && j_offs < kNumPatchesPerRow);
i_offs *= kPatchSize;
j_offs *= kPatchSize;
// This can be made faster, but let's first see whether it matters.
Expand All @@ -231,10 +228,10 @@ void Image::GetPatch(size_t patch_num, float* patch) const {
for (size_t j = 0; j < kPatchSize; ++j) {
for (size_t k = 0; k < 3; ++k) {
const size_t patch_index = (i * kPatchSize + j) * 3 + k;
HWY_ASSERT(patch_index < kPatchDataSize);
HWY_DASSERT(patch_index < kPatchSize * kPatchSize * 3);
const size_t image_index =
((i + i_offs) * kImageSize + (j + j_offs)) * 3 + k;
HWY_ASSERT(image_index < kDataSize);
((i + i_offs) * width_ + (j + j_offs)) * 3 + k;
HWY_DASSERT(image_index < kDataSize);
patch[patch_index] = image_data[image_index];
}
}
Expand Down
17 changes: 9 additions & 8 deletions paligemma/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

namespace gcpp {

// Very basic image loading and processing for PaliGemma-224. Does not try to be
// generic at the moment, e.g. the size to normalize to is hardcoded.
// Very basic image loading and processing for PaliGemma.
class Image {
public:
Image() = default;
Expand All @@ -38,15 +37,17 @@ class Image {
// Sets the image content to the given data. The data is copied and normalized
// to [-1, 1]. The data is expected to be of size width * height * 3.
void Set(int width, int height, const float* data);
// Resizes to 224x224 (nearest-neighbor for now, bilinear or antialias would
// be better).
void Resize();
// Resizes to width x height (nearest-neighbor for now, bilinear or antialias
// would be better).
void Resize(int width, int height);
// Writes the file as plain floats in binary. Useful to e.g. load in a colab.
bool WriteBinary(const std::string& filename) const;
// Stores the patch for the given patch number [0, 256) in `patch`.
// As sizes are hardcoded, the patch number is sufficient here.
// Stores the patch for the given patch number in `patch`.
// Patches are numbered in usual raster-order. E.g. for an image of size
// 224 x 224, there are 16 x 16 = 256 patches.
// `patch` should have space for at least 14 * 14 * 3 = 588 floats.
// Requires that Normalize() has been called.
// Requires that Normalize() has been called and that the image width and
// height are multiples of 14.
void GetPatch(size_t patch_num, float* patch) const;

float *data() { return data_.data(); }
Expand Down
68 changes: 64 additions & 4 deletions paligemma/image_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@

#include "paligemma/image.h"

#include <cstddef>
#include <string>
#include <vector>

#include "gtest/gtest.h"

namespace gcpp {
namespace {

float Normalize(int value) { return 2.0f * (value / 255.0f) - 1.0f; }
float Normalize(float value, float max_value = 255.0f) {
return 2.0f * (value / max_value) - 1.0f;
}

TEST(ImageTest, BasicFunctionality) {
TEST(ImageTest, LoadResize224GetPatch) {
return; // Need to figure out how to get the external path for the test file.
std::string path;
Image image;
Expand All @@ -48,7 +52,7 @@ TEST(ImageTest, BasicFunctionality) {
EXPECT_EQ(image.data()[33], Normalize(164));
EXPECT_EQ(image.data()[34], Normalize(185));
EXPECT_EQ(image.data()[35], Normalize(191));
image.Resize();
image.Resize(224, 224);
// Check first and last pixel.
EXPECT_EQ(image.data()[0], Normalize(160));
EXPECT_EQ(image.data()[1], Normalize(184));
Expand All @@ -63,10 +67,66 @@ TEST(ImageTest, BasicFunctionality) {
EXPECT_EQ(patch[1], Normalize(184));
EXPECT_EQ(patch[2], Normalize(188));
image.GetPatch(18, patch);
for (int i = 0; i < 10; ++i) {
// Check the first row of the patch.
for (size_t i = 0; i < 14 * 3; ++i) {
EXPECT_EQ(patch[i], image.data()[(14 * 224 + 2 * 14) * 3 + i]);
}
}

TEST(ImageTest, Non224) {
std::vector<float> data(28 * 42 * 3);
for (int i = 0; i < data.size(); ++i) {
data[i] = static_cast<float>(i);
}
float max_value = data.back();
Image image;
image.Set(28, 42, data.data());
EXPECT_EQ(image.width(), 28);
EXPECT_EQ(image.height(), 42);
EXPECT_EQ(image.size(), data.size());
// Resize 28 x 42 -> 56 x 42, "double" each pixel horizontally.
image.Resize(/*new_width=*/56, /*new_height=*/42);
// Check a few pixels.
EXPECT_NEAR(image.data()[0], Normalize(0.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[1], Normalize(1.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[2], Normalize(2.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[3], Normalize(0.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[4], Normalize(1.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[5], Normalize(2.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[6], Normalize(3.0f, max_value), 1e-6);
EXPECT_NEAR(image.data()[image.size() - 9],
Normalize(data.size() - 6, max_value), 1e-6);
EXPECT_NEAR(image.data()[image.size() - 8],
Normalize(data.size() - 5, max_value), 1e-6);
EXPECT_NEAR(image.data()[image.size() - 7],
Normalize(data.size() - 4, max_value), 1e-6);
EXPECT_NEAR(image.data()[image.size() - 3],
Normalize(data.size() - 3, max_value), 1e-6);
EXPECT_NEAR(image.data()[image.size() - 2],
Normalize(data.size() - 2, max_value), 1e-6);
EXPECT_NEAR(image.data()[image.size() - 1],
Normalize(data.size() - 1, max_value), 1e-6);
// Extract two patches.
const size_t kPatchValues = 14 * 14 * 3; // = 588
float patch[kPatchValues];
// Patch 0 is just the "start" of the image.
image.GetPatch(0, patch);
EXPECT_NEAR(patch[0], Normalize(0.0f, max_value), 1e-6);
EXPECT_NEAR(patch[1], Normalize(1.0f, max_value), 1e-6);
EXPECT_NEAR(patch[2], Normalize(2.0f, max_value), 1e-6);
// The "image" has 4x3 patches, so patch 6 has coordinates (1, 2) and its
// pixel coordinates are offset by (14, 28).
image.GetPatch(6, patch);
for (size_t n = 0; n < kPatchValues; ++n) {
size_t k = n % 3;
size_t j = ((n - k) / 3) % 14;
size_t i = (n - k - j * 3) / (14 * 3);
EXPECT_EQ(n, (i * 14 + j) * 3 + k);
i += 14;
j += 28;
EXPECT_EQ(patch[n], image.data()[(i * 56 + j) * 3 + k]);
}
}

} // namespace
} // namespace gcpp
Loading
Loading