Skip to content

Commit

Permalink
Improved consistency of compressor API, and added a universal method …
Browse files Browse the repository at this point in the history
…with a target type arg.

Moved configs pybind up to root level.

PiperOrigin-RevId: 698743417
  • Loading branch information
theraysmith authored and copybara-github committed Nov 21, 2024
1 parent 73640d2 commit 3d1625d
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 190 deletions.
3 changes: 2 additions & 1 deletion compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -690,12 +690,13 @@ class Compressor {
}
}

void WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
return err;
}

private:
Expand Down
2 changes: 2 additions & 0 deletions compression/python/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pybind_extension(
deps = [
":compression_clif_aux",
"@abseil-cpp//absl/types:span",
"//compression:sfp",
],
)

Expand All @@ -38,6 +39,7 @@ py_test(
deps = [
":compression",
"//testing/pybase",
"//python:configs",
"//third_party/py/numpy",
],
)
42 changes: 34 additions & 8 deletions compression/python/compression_clif_aux.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ class WriterInterface {
public:
virtual ~WriterInterface() = default;

virtual void Insert(std::string name, absl::Span<const float> weights) = 0;
virtual void Insert(std::string name, absl::Span<const float> weights,
Type type) = 0;
virtual void InsertSfp(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertNUQ(std::string name, absl::Span<const float> weights) = 0;
virtual void InsertBfloat16(std::string name,
absl::Span<const float> weights) = 0;
virtual void InsertFloat(std::string name,
absl::Span<const float> weights) = 0;
virtual void AddScales(const std::vector<float>& scales) = 0;

virtual void Write(std::string path) = 0;
virtual int Write(std::string path) = 0;
};

} // namespace gcpp
Expand Down Expand Up @@ -67,7 +69,27 @@ class SbsWriterImpl : public WriterInterface {
public:
SbsWriterImpl() : pool_(0), compressor_(pool_) {}

void Insert(std::string name, absl::Span<const float> weights) override {
void Insert(std::string name, absl::Span<const float> weights,
Type type) override {
switch (type) {
case Type::kSFP:
AllocateAndCompress<SfpStream>(name, weights);
break;
case Type::kNUQ:
AllocateAndCompress<NuqStream>(name, weights);
break;
case Type::kBF16:
AllocateAndCompress<BF16>(name, weights);
break;
case Type::kF32:
AllocateAndCompress<float>(name, weights);
break;
default:
HWY_ABORT("Unsupported type");
}
}

void InsertSfp(std::string name, absl::Span<const float> weights) override {
AllocateAndCompress<SfpStream>(name, weights);
}

Expand All @@ -90,8 +112,8 @@ class SbsWriterImpl : public WriterInterface {
compressor_.AddScales(scales_.data(), scales_.size());
}

void Write(std::string path) override {
compressor_.WriteAll(pool_, gcpp::Path(path));
int Write(std::string path) override {
return compressor_.WriteAll(pool_, gcpp::Path(path));
}

hwy::ThreadPool pool_;
Expand All @@ -115,8 +137,12 @@ HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
SbsWriter::~SbsWriter() = default;

void SbsWriter::Insert(std::string name, absl::Span<const float> weights) {
impl_->Insert(name, weights);
void SbsWriter::Insert(std::string name, absl::Span<const float> weights,
Type type) {
impl_->Insert(name, weights, type);
}
void SbsWriter::InsertSfp(std::string name, absl::Span<const float> weights) {
impl_->InsertSfp(name, weights);
}
void SbsWriter::InsertNUQ(std::string name, absl::Span<const float> weights) {
impl_->InsertNUQ(name, weights);
Expand All @@ -132,7 +158,7 @@ void SbsWriter::InsertFloat(std::string name, absl::Span<const float> weights) {
void SbsWriter::AddScales(const std::vector<float>& scales) {
impl_->AddScales(scales);
}
void SbsWriter::Write(std::string path) { impl_->Write(path); }
int SbsWriter::Write(std::string path) { return impl_->Write(path); }

} // namespace gcpp
#endif // HWY_ONCE
6 changes: 4 additions & 2 deletions compression/python/compression_clif_aux.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <vector>

#include "absl/types/span.h"
#include "compression/shared.h"

namespace gcpp {

Expand All @@ -16,13 +17,14 @@ class SbsWriter {
SbsWriter();
~SbsWriter();

void Insert(std::string name, absl::Span<const float> weights);
void Insert(std::string name, absl::Span<const float> weights, Type type);
void InsertSfp(std::string name, absl::Span<const float> weights);
void InsertNUQ(std::string name, absl::Span<const float> weights);
void InsertBfloat16(std::string name, absl::Span<const float> weights);
void InsertFloat(std::string name, absl::Span<const float> weights);
void AddScales(const std::vector<float>& scales);

void Write(std::string path);
int Write(std::string path);

private:
// Isolates Highway-dispatched types and other internals from CLIF.
Expand Down
14 changes: 12 additions & 2 deletions compression/python/compression_extension.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include <pybind11/pybind11.h>

#include <exception>
#include <stdexcept>
#include <string>

#include "absl/types/span.h"
#include "compression/python/compression_clif_aux.h"
#include "compression/shared.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
Expand All @@ -22,14 +22,24 @@ void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
}
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
}
template <auto Func>
void wrap_span_typed(SbsWriter& writer, std::string name,
py::array_t<float> data, gcpp::Type type) {
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
throw std::domain_error("Input array must be 1D and densely packed.");
}
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()),
type);
}
} // namespace

PYBIND11_MODULE(compression, m) {
py::class_<SbsWriter>(m, "SbsWriter")
.def(py::init<>())
// NOTE: Individual compression backends may impose constraints on the
// array length, such as a minimum of (say) 32 elements.
.def("insert", wrap_span<&SbsWriter::Insert>)
.def("insert", wrap_span_typed<&SbsWriter::Insert>)
.def("insert_sfp", wrap_span<&SbsWriter::InsertSfp>)
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
Expand Down
9 changes: 6 additions & 3 deletions compression/python/compression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import unittest
from compression.python import compression
from gemma.python import configs


class CompressionTest(unittest.TestCase):
Expand All @@ -13,9 +14,11 @@ def test_sbs_writer(self):

writer = compression.SbsWriter()
writer.insert(
"foo", np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32)
"foo",
np.array([0.0012] * 128 + [0.001] * 64, dtype=np.float32),
configs.Type.kSFP,
)
writer.insert(
writer.insert_sfp(
"bar", np.array([0.000375] * 128 + [0.00009] * 128, dtype=np.float32)
)
writer.insert_nuq(
Expand All @@ -27,7 +30,7 @@ def test_sbs_writer(self):
writer.insert_float(
"quux", np.array([0.000375] * 128 + [0.00006] * 128, dtype=np.float32)
)
writer.write(temp_file.full_path)
self.assertEqual(writer.write(temp_file.full_path), 0)


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ static ModelConfig ConfigGemma2_27B() {
.heads = 32,
.kv_heads = 16,
.qkv_dim = 128,
.optimized_gating = false,
.post_norm = PostNormType::Scale};
config.layer_configs = {46, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
Expand All @@ -70,6 +71,7 @@ static ModelConfig ConfigGemma2_9B() {
.heads = 16,
.kv_heads = 8,
.qkv_dim = 256,
.optimized_gating = false,
.post_norm = PostNormType::Scale};
config.layer_configs = {42, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
Expand All @@ -91,6 +93,7 @@ static ModelConfig ConfigGemma2_2B() {
.heads = 8,
.kv_heads = 4,
.qkv_dim = 256,
.optimized_gating = false,
.post_norm = PostNormType::Scale};
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
Expand Down
1 change: 0 additions & 1 deletion gemma/configs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ void AssertMatch(const ModelConfig& config) {
}
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
// ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value.
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
Expand Down
17 changes: 0 additions & 17 deletions gemma/python/BUILD.bazel

This file was deleted.

Loading

0 comments on commit 3d1625d

Please sign in to comment.