From 3d1625d8c578605e3a618715f858714df3471867 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 21 Nov 2024 05:27:02 -0800 Subject: [PATCH] Improved consistency of compressor API, and added a universal method with a target type arg. Moved configs pybind up to root level. PiperOrigin-RevId: 698743417 --- compression/compress-inl.h | 3 +- compression/python/BUILD.bazel | 2 + compression/python/compression_clif_aux.cc | 42 +++++- compression/python/compression_clif_aux.h | 6 +- compression/python/compression_extension.cc | 14 +- compression/python/compression_test.py | 9 +- gemma/configs.cc | 3 + gemma/configs_test.cc | 1 - gemma/python/BUILD.bazel | 17 --- gemma/python/configs.cc | 156 -------------------- 10 files changed, 63 insertions(+), 190 deletions(-) delete mode 100644 gemma/python/BUILD.bazel delete mode 100644 gemma/python/configs.cc diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 7fd097c9..be4c5afa 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -690,12 +690,13 @@ class Compressor { } } - void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) { + BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) { const BlobError err = writer_.WriteAll(pool, blob_filename); if (err != 0) { fprintf(stderr, "Failed to write blobs to %s (error %d)\n", blob_filename.path.c_str(), err); } + return err; } private: diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 6b451bd1..f12d8bea 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -28,6 +28,7 @@ pybind_extension( deps = [ ":compression_clif_aux", "@abseil-cpp//absl/types:span", + "//compression:sfp", ], ) @@ -38,6 +39,7 @@ py_test( deps = [ ":compression", "//testing/pybase", + "//python:configs", "//third_party/py/numpy", ], ) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index ba91781e..e313bbee 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -31,7 +31,9 @@ class WriterInterface { public: virtual ~WriterInterface() = default; - virtual void Insert(std::string name, absl::Span weights) = 0; + virtual void Insert(std::string name, absl::Span weights, + Type type) = 0; + virtual void InsertSfp(std::string name, absl::Span weights) = 0; virtual void InsertNUQ(std::string name, absl::Span weights) = 0; virtual void InsertBfloat16(std::string name, absl::Span weights) = 0; @@ -39,7 +41,7 @@ class WriterInterface { absl::Span weights) = 0; virtual void AddScales(const std::vector& scales) = 0; - virtual void Write(std::string path) = 0; + virtual int Write(std::string path) = 0; }; } // namespace gcpp @@ -67,7 +69,27 @@ class SbsWriterImpl : public WriterInterface { public: SbsWriterImpl() : pool_(0), compressor_(pool_) {} - void Insert(std::string name, absl::Span weights) override { + void Insert(std::string name, absl::Span weights, + Type type) override { + switch (type) { + case Type::kSFP: + AllocateAndCompress(name, weights); + break; + case Type::kNUQ: + AllocateAndCompress(name, weights); + break; + case Type::kBF16: + AllocateAndCompress(name, weights); + break; + case Type::kF32: + AllocateAndCompress(name, weights); + break; + default: + HWY_ABORT("Unsupported type"); + } + } + + void InsertSfp(std::string name, absl::Span weights) override { AllocateAndCompress(name, weights); } @@ -90,8 +112,8 @@ class SbsWriterImpl : public WriterInterface { compressor_.AddScales(scales_.data(), scales_.size()); } - void Write(std::string path) override { - compressor_.WriteAll(pool_, gcpp::Path(path)); + int Write(std::string path) override { + return compressor_.WriteAll(pool_, gcpp::Path(path)); } hwy::ThreadPool pool_; @@ -115,8 +137,12 @@ HWY_EXPORT(NewSbsWriter); SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {} SbsWriter::~SbsWriter() = default; -void SbsWriter::Insert(std::string name, absl::Span weights) { - impl_->Insert(name, weights); +void SbsWriter::Insert(std::string name, absl::Span weights, + Type type) { + impl_->Insert(name, weights, type); +} +void SbsWriter::InsertSfp(std::string name, absl::Span weights) { + impl_->InsertSfp(name, weights); } void SbsWriter::InsertNUQ(std::string name, absl::Span weights) { impl_->InsertNUQ(name, weights); @@ -132,7 +158,7 @@ void SbsWriter::InsertFloat(std::string name, absl::Span weights) { void SbsWriter::AddScales(const std::vector& scales) { impl_->AddScales(scales); } -void SbsWriter::Write(std::string path) { impl_->Write(path); } +int SbsWriter::Write(std::string path) { return impl_->Write(path); } } // namespace gcpp #endif // HWY_ONCE diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index fd4efc8d..cd2e4f1a 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -6,6 +6,7 @@ #include #include "absl/types/span.h" +#include "compression/shared.h" namespace gcpp { @@ -16,13 +17,14 @@ class SbsWriter { SbsWriter(); ~SbsWriter(); - void Insert(std::string name, absl::Span weights); + void Insert(std::string name, absl::Span weights, Type type); + void InsertSfp(std::string name, absl::Span weights); void InsertNUQ(std::string name, absl::Span weights); void InsertBfloat16(std::string name, absl::Span weights); void InsertFloat(std::string name, absl::Span weights); void AddScales(const std::vector& scales); - void Write(std::string path); + int Write(std::string path); private: // Isolates Highway-dispatched types and other internals from CLIF. diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc index c2916a8b..c56f263e 100644 --- a/compression/python/compression_extension.cc +++ b/compression/python/compression_extension.cc @@ -1,11 +1,11 @@ #include -#include #include #include #include "absl/types/span.h" #include "compression/python/compression_clif_aux.h" +#include "compression/shared.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -22,6 +22,15 @@ void wrap_span(SbsWriter& writer, std::string name, py::array_t data) { } std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size())); } +template +void wrap_span_typed(SbsWriter& writer, std::string name, + py::array_t data, gcpp::Type type) { + if (data.ndim() != 1 || data.strides(0) != sizeof(float)) { + throw std::domain_error("Input array must be 1D and densely packed."); + } + std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()), + type); +} } // namespace PYBIND11_MODULE(compression, m) { @@ -29,7 +38,8 @@ PYBIND11_MODULE(compression, m) { .def(py::init<>()) // NOTE: Individual compression backends may impose constraints on the // array length, such as a minimum of (say) 32 elements. - .def("insert", wrap_span<&SbsWriter::Insert>) + .def("insert", wrap_span_typed<&SbsWriter::Insert>) + .def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>) .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>) .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>) .def("insert_float", wrap_span<&SbsWriter::InsertFloat>) diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py index 7b7ff12e..e25f06b7 100644 --- a/compression/python/compression_test.py +++ b/compression/python/compression_test.py @@ -4,6 +4,7 @@ import unittest from compression.python import compression +from gemma.python import configs class CompressionTest(unittest.TestCase): @@ -13,9 +14,11 @@ def test_sbs_writer(self): writer = compression.SbsWriter() writer.insert( - "foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32) + "foo", + np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32), + configs.Type.kSFP, ) - writer.insert( + writer.insert_sfp( "bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32) ) writer.insert_nuq( @@ -27,7 +30,7 @@ def test_sbs_writer(self): writer.insert_float( "quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32) ) - writer.write(temp_file.full_path) + self.assertEqual(writer.write(temp_file.full_path), 0) if __name__ == "__main__": diff --git a/gemma/configs.cc b/gemma/configs.cc index 7724c590..7a792cf0 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -49,6 +49,7 @@ static ModelConfig ConfigGemma2_27B() { .heads = 32, .kv_heads = 16, .qkv_dim = 128, + .optimized_gating = false, .post_norm = PostNormType::Scale}; config.layer_configs = {46, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); @@ -70,6 +71,7 @@ static ModelConfig ConfigGemma2_9B() { .heads = 16, .kv_heads = 8, .qkv_dim = 256, + .optimized_gating = false, .post_norm = PostNormType::Scale}; config.layer_configs = {42, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); @@ -91,6 +93,7 @@ static ModelConfig ConfigGemma2_2B() { .heads = 8, .kv_heads = 4, .qkv_dim = 256, + .optimized_gating = false, .post_norm = PostNormType::Scale}; config.layer_configs = {26, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 8128baf6..fa8d8700 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -374,7 +374,6 @@ void AssertMatch(const ModelConfig& config) { } ASSERT_EQ(TConfig::kVocabSize, config.vocab_size); ASSERT_EQ(TConfig::kSeqLen, config.seq_len); - // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value. ASSERT_EQ(TConfig::kAttCap, config.att_cap); ASSERT_EQ(TConfig::kFinalCap, config.final_cap); ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe); diff --git a/gemma/python/BUILD.bazel b/gemma/python/BUILD.bazel deleted file mode 100644 index d6b09b95..00000000 --- a/gemma/python/BUILD.bazel +++ /dev/null @@ -1,17 +0,0 @@ -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") - -package( - default_applicable_licenses = [ - "//:license", # Placeholder comment, do not modify - ], - default_visibility = ["//visibility:public"], -) - -pybind_extension( - name = "configs", - srcs = ["configs.cc"], - deps = [ - "//:common", - "//compression:sfp", - ], -) diff --git a/gemma/python/configs.cc b/gemma/python/configs.cc deleted file mode 100644 index 8c378407..00000000 --- a/gemma/python/configs.cc +++ /dev/null @@ -1,156 +0,0 @@ -#include "gemma/configs.h" - -#include -#include - -#include "compression/shared.h" -#include "gemma/tensor_index.h" -#include "pybind11/cast.h" - -using gcpp::ActivationType; -using gcpp::LayerAttentionType; -using gcpp::LayerConfig; -using gcpp::Model; -using gcpp::ModelConfig; -using gcpp::ModelTraining; -using gcpp::PostNormType; -using gcpp::PostQKType; -using gcpp::QueryScaleType; -using gcpp::ResidualType; -using gcpp::TensorIndex; -using gcpp::TensorInfo; -using gcpp::Type; - -namespace pybind11 { - -PYBIND11_MODULE(configs, py_module) { - enum_(py_module, "ModelTraining") - .value("GEMMA_IT", ModelTraining::GEMMA_IT) - .value("GEMMA_PT", ModelTraining::GEMMA_PT) - .value("PALIGEMMA", ModelTraining::PALIGEMMA); - - enum_(py_module, "Type") - .value("kUnknown", Type::kUnknown) - .value("kF32", Type::kF32) - .value("kBF16", Type::kBF16) - .value("kSFP", Type::kSFP) - .value("kNUQ", Type::kNUQ) - .value("kF64", Type::kF64) - .value("kC64", Type::kC64) - .value("kU128", Type::kU128); - - enum_(py_module, "LayerAttentionType") - .value("kGemma", LayerAttentionType::kGemma) - .value("kGriffinRecurrentBlock", - LayerAttentionType::kGriffinRecurrentBlock) - .value("kVit", LayerAttentionType::kVit); - - enum_(py_module, "PostNormType") - .value("NoPostNorm", PostNormType::None) - .value("Scale", PostNormType::Scale); - - enum_(py_module, "PostQKType") - .value("Rope", PostQKType::Rope) - .value("HalfRope", PostQKType::HalfRope); - - enum_(py_module, "ActivationType") - .value("Gelu", ActivationType::Gelu); - - enum_(py_module, "QueryScaleType") - .value("SqrtKeySize", QueryScaleType::SqrtKeySize) - .value("SqrtModelDimDivNumHeads", - QueryScaleType::SqrtModelDimDivNumHeads); - - enum_(py_module, "ResidualType") - .value("Add", ResidualType::Add); - - enum_(py_module, "Model") - .value("UNKNOWN", Model::UNKNOWN) - .value("GEMMA_2B", Model::GEMMA_2B) - .value("GEMMA_7B", Model::GEMMA_7B) - .value("GEMMA2_9B", Model::GEMMA2_9B) - .value("GEMMA2_27B", Model::GEMMA2_27B) - .value("GRIFFIN_2B", Model::GRIFFIN_2B) - .value("GEMMA_TINY", Model::GEMMA_TINY) - .value("GEMMA2_2B", Model::GEMMA2_2B) - .value("PALIGEMMA_224", Model::PALIGEMMA_224); - - class_(py_module, "TensorInfo") - .def(init()) - .def_readwrite("name", &TensorInfo::name) - .def_readwrite("source_names", &TensorInfo::source_names) - .def_readwrite("preshape", &TensorInfo::preshape) - .def_readwrite("axes", &TensorInfo::axes) - .def_readwrite("shape", &TensorInfo::shape) - .def_readwrite("concat_names", &TensorInfo::concat_names) - .def_readwrite("concat_axis", &TensorInfo::concat_axis) - .def_readwrite("min_size", &TensorInfo::min_size) - .def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus) - .def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims); - - class_(py_module, "TensorIndex") - .def(init()) - .def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path")); - - class_(py_module, "LayerConfig") - .def(init()) - .def_readwrite("model_dim", &LayerConfig::model_dim) - .def_readwrite("griffin_dim", &LayerConfig::griffin_dim) - .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim) - .def_readwrite("heads", &LayerConfig::heads) - .def_readwrite("kv_heads", &LayerConfig::kv_heads) - .def_readwrite("qkv_dim", &LayerConfig::qkv_dim) - .def_readwrite("conv1d_width", &LayerConfig::conv1d_width) - .def_readwrite("ff_biases", &LayerConfig::ff_biases) - .def_readwrite("softmax_attn_output_biases", - &LayerConfig::softmax_attn_output_biases) - .def_readwrite("optimized_gating", &LayerConfig::optimized_gating) - .def_readwrite("post_norm", &LayerConfig::post_norm) - .def_readwrite("type", &LayerConfig::type) - .def_readwrite("activation", &LayerConfig::activation) - .def_readwrite("post_qk", &LayerConfig::post_qk); - - class_(py_module, "ModelConfig") - .def(init()) - .def_readwrite("model_name", &ModelConfig::model_name) - .def_readwrite("model", &ModelConfig::model) - .def_readwrite("training", &ModelConfig::training) - .def_readwrite("weight", &ModelConfig::weight) - .def_readwrite("num_layers", &ModelConfig::num_layers) - .def_readwrite("model_dim", &ModelConfig::model_dim) - .def_readwrite("vit_model_dim", &ModelConfig::vit_model_dim) - .def_readwrite("vocab_size", &ModelConfig::vocab_size) - .def_readwrite("seq_len", &ModelConfig::seq_len) - .def_readwrite("vit_seq_len", &ModelConfig::vit_seq_len) - .def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales) - .def_readwrite("num_vit_scales", &ModelConfig::num_vit_scales) - .def_readwrite("att_cap", &ModelConfig::att_cap) - .def_readwrite("final_cap", &ModelConfig::final_cap) - .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) - .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) - .def_readwrite("query_scale", &ModelConfig::query_scale) - .def_readwrite("layer_configs", &ModelConfig::layer_configs) - .def_readwrite("attention_window_sizes", - &ModelConfig::attention_window_sizes) - .def_readwrite("vit_layer_configs", &ModelConfig::vit_layer_configs) - .def_readwrite("scale_names", &ModelConfig::scale_names) - .def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups) - .def_readwrite("model_family_version", &ModelConfig::model_family_version) - .def_readwrite("patch_width", &ModelConfig::patch_width) - .def_readwrite("image_size", &ModelConfig::image_size) - .def("add_layer_config", &ModelConfig::AddLayerConfig, - arg("layer_config")) - .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"), - arg("debug")); - - // Returns the config for the given model. - py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model")); - - // Returns the model for the given config, if it matches any standard model. - py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config")); - - // Returns the sub-config for the ViT model of the PaliGemma model. - py_module.def("vit_config", &gcpp::VitConfig, arg("config")); -} - -} // namespace pybind11