Infra improvements:
allocator: support mmap, fixed Bind, add padding
bench_matmul: Add PreventElision
BUILD: add ops_test build target
matmul.h: move ConstMat here; dynamic alloc of MatMulEnv
matmul_test: remove benchmarking
replace fprintf with HWY_WARN
threading.cc: support splitting large clusters (disabled); package_idx->pkg_idx, smaller IndexRangePartition
PiperOrigin-RevId: 708306224
jan-wassenberg authored and copybara-github committed Dec 20, 2024
1 parent 9d40f01 commit 4493531
Showing 14 changed files with 722 additions and 636 deletions.
20 changes: 13 additions & 7 deletions BUILD.bazel
@@ -76,6 +76,12 @@ cc_test(
],
)

# For building all tests in one command, so we can test several.
test_suite(
name = "ops_tests",
tags = ["ops_tests"],
)

cc_library(
name = "ops",
hdrs = [
@@ -110,7 +116,7 @@ cc_test(
srcs = ["ops/dot_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":ops",
@@ -135,7 +141,7 @@ cc_test(
srcs = ["ops/ops_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":common",
@@ -157,7 +163,7 @@ cc_test(
srcs = ["ops/gemma_matvec_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":ops",
"@googletest//:gtest_main", # buildcleaner: keep
@@ -175,7 +181,7 @@ cc_test(
srcs = ["ops/matmul_unit_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
@@ -195,7 +201,7 @@ cc_test(
srcs = ["ops/matmul_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
@@ -205,7 +211,6 @@ cc_test(
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@@ -217,7 +222,7 @@ cc_test(
srcs = ["ops/bench_matmul.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["hwy_ops_test"],
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
@@ -228,6 +233,7 @@ cc_test(
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
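
Note on the new ops_tests target above: the test_suite specifies only tags = ["ops_tests"] and no explicit tests list, so Bazel collects every test in this BUILD file that carries the tag. All of the retagged ops tests can then presumably be run together with a single command along the lines of bazel test :ops_tests.
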
9 changes: 1 addition & 8 deletions compression/compress.h
@@ -309,13 +309,6 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
}
}

template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
mat.scale = m.scale();
return mat;
}

// MatStorageT adds the actual data storage to MatPtrT.
// TODO: use Extents2D instead of rows and cols.
template <typename MatT>
@@ -361,7 +354,7 @@ class MatStorageT : public MatPtrT<MatT> {
}

private:
hwy::AlignedFreeUniquePtr<MatT[]> data_;
AlignedPtr<MatT> data_;
};

// MatStorage allows heterogeneous tensors to be stored in a single vector.
8 changes: 4 additions & 4 deletions compression/shared.h
@@ -273,11 +273,11 @@ struct PackedSpan {
// Ensures callers can read or write `num_accessible` elements starting at
// `packed_ofs`.
void BoundsCheck(size_t packed_ofs, size_t num_accessible) const {
// For NUQ, there can be fewer Packed than the number of elements, hence
// check the compressed count and ensure we have that many.
const size_t required =
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
if constexpr (HWY_IS_DEBUG_BUILD) {
// For NUQ, there can be fewer Packed than the number of elements, hence
// check the compressed count and ensure we have that many.
const size_t required =
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
if (num < required) {
HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed",
packed_ofs, num_accessible, required, num);
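
The shared.h hunk above moves the computation of required inside the if constexpr (HWY_IS_DEBUG_BUILD) block, so release builds no longer compute it at all. Below is a minimal standalone sketch of the same pattern; the names (IS_DEBUG_BUILD, PackedView) are illustrative stand-ins, not the actual gemma.cpp types.

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Illustrative stand-in for the debug-build detection macro.
#define IS_DEBUG_BUILD 1

struct PackedView {
  size_t num;  // number of packed elements actually available

  void BoundsCheck(size_t ofs, size_t num_accessible) const {
    // With if constexpr, the entire check (including the computation of
    // required) is discarded at compile time when the condition is false,
    // so release builds do no work here.
    if constexpr (IS_DEBUG_BUILD) {
      const size_t required = ofs + num_accessible;  // simplified packing math
      if (num < required) {
        fprintf(stderr, "PackedView: ofs %zu, want %zu, req %zu > %zu\n", ofs,
                num_accessible, required, num);
        abort();
      }
    }
  }
};

int main() {
  PackedView v{16};
  v.BoundsCheck(8, 8);     // within bounds
  // v.BoundsCheck(8, 16); // would print and abort in a debug build
  return 0;
}

The real BoundsCheck additionally accounts for NUQ packing via CompressedArrayElements and reports via HWY_ABORT, as shown in the diff.
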
6 changes: 4 additions & 2 deletions gemma/activations.h
@@ -19,6 +19,7 @@
#include <stddef.h>

#include <cmath>
#include <memory> // std::unique_ptr

#include "compression/shared.h" // BF16
#include "gemma/configs.h"
@@ -63,7 +64,8 @@ struct Activations {
// Rope
RowVectorBatch<float> inv_timescale;

MatMulEnv env;
// Dynamic because no default ctor and only initialized in `Allocate`.
std::unique_ptr<MatMulEnv> env;

PostQKType post_qk = PostQKType::Rope;
// And the config.
@@ -122,7 +124,7 @@ struct Activations {

inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);

env = MatMulEnv(pools);
env = std::make_unique<MatMulEnv>(pools);
}
};

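
The activations.h hunk above replaces the by-value MatMulEnv member with std::unique_ptr<MatMulEnv>, because MatMulEnv has no default constructor and can only be built once the thread pools are known in Allocate(). A minimal sketch of the pattern follows; Pools and Env are hypothetical stand-ins, not the real gemma.cpp classes.

#include <memory>

struct Pools {};  // stand-in for the thread pools passed to Allocate()

struct Env {
  explicit Env(Pools& pools) : pools(pools) {}  // no default constructor
  Pools& pools;
};

struct Activations {
  // Held by pointer because Env cannot be default-constructed; it is only
  // created in Allocate(), mirroring the MatMulEnv change in this commit.
  std::unique_ptr<Env> env;

  void Allocate(Pools& pools) { env = std::make_unique<Env>(pools); }
};

int main() {
  Pools pools;
  Activations act;      // act.env is null until Allocate() runs
  act.Allocate(pools);
  return act.env ? 0 : 1;
}

Call sites then dereference the pointer, e.g. *activations.env when passing the environment to MatMul and activations.env->Pool() for member access, which is the mechanical change visible throughout the gemma-inl.h diff below.
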
36 changes: 18 additions & 18 deletions gemma/gemma-inl.h
@@ -81,7 +81,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
KVCache& kv_cache = kv_caches[0];
hwy::ThreadPool& pool = activations.env.Pool();
hwy::ThreadPool& pool = activations.env->Pool();
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const size_t model_dim = layer_weights->layer_config.model_dim;
@@ -252,7 +252,7 @@ class GemmaAttention {
const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows);
MatMul(pre_att_rms_out, w_q1,
/*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q));
/*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q));

if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
@@ -275,7 +275,7 @@ class GemmaAttention {
RowPtrF kv_rows(kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMul(pre_att_rms_out, w_q2,
/*add=*/nullptr, activations_.env, kv_rows);
/*add=*/nullptr, *activations_.env, kv_rows);
} else {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@@ -464,7 +464,7 @@ class GemmaAttention {
: nullptr;
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
ConstMatFromWeights(layer_weights_.att_weights), add,
activations_.env, RowPtrFromBatch(activations_.att_sums));
*activations_.env, RowPtrFromBatch(activations_.att_sums));
}

public:
@@ -514,7 +514,7 @@ class GemmaAttention {
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(activations.env.Pool()) {
pool_(activations.env->Pool()) {
HWY_DASSERT(num_queries_ <= kv_caches_.size());
HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
@@ -587,7 +587,7 @@ class VitAttention {
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env,
layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
RowPtrFromBatch(qkv));
}

@@ -641,7 +641,7 @@ class VitAttention {
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
auto att_sums = RowPtrFromBatch(activations_.att_sums);
MatMul(att_out, att_weights, bias, activations_.env, att_sums);
MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
}

public:
@@ -652,7 +652,7 @@ class VitAttention {
activations_(activations),
layer_weights_(*layer_weights),
layer_config_(layer_weights->layer_config),
pool_(activations.env.Pool()) {}
pool_(activations.env->Pool()) {}

HWY_INLINE void operator()() {
ComputeQKV();
@@ -728,8 +728,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
auto w_output = ConstMatFromWeights(layer_weights->linear_w);

// Compute the hidden layer activations.
MatMul(x, w1, bias1, activations.env, hidden_activations);
MatMul(x, w2, bias2, activations.env, multiplier);
MatMul(x, w1, bias1, *activations.env, hidden_activations);
MatMul(x, w2, bias2, *activations.env, multiplier);

// Activation (Gelu) and maybe multiply by gate. Store activations in act.
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
@@ -739,7 +739,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));

MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
}

// Same as FFWNoVit, but with different layer_weights members and no second
Expand Down Expand Up @@ -769,7 +769,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);

// Compute the hidden layer activations.
MatMul(x, w1, bias1, activations.env, hidden_activations);
MatMul(x, w1, bias1, *activations.env, hidden_activations);

// Activation (Gelu), store in act.
RowPtrF multiplier = RowPtrF(nullptr, 0);
@@ -780,7 +780,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));

MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out);
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
}

// `batch_idx` indicates which row of `x` to write to.
@@ -1063,7 +1063,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// MatMul(
// MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel),
// weights.vit_img_embedding_bias.data_scale1(), activations.env,
// weights.vit_img_embedding_bias.data_scale1(), *activations.env,
// RowPtrF(activations.x.All(), kVitModelDim));
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
@@ -1073,7 +1073,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(),
weights.vit_img_embedding_bias.data_scale1(),
activations.x.Batch(i), activations.env.Pool());
activations.x.Batch(i), activations.env->Pool());
}
// Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
@@ -1107,7 +1107,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul(ConstMatFromBatch(num_tokens, activations.x),
ConstMatFromWeights(weights.vit_img_head_kernel),
weights.vit_img_head_bias.data_scale1(), activations.env,
weights.vit_img_head_bias.data_scale1(), *activations.env,
RowPtrFromBatch(image_tokens));
}

@@ -1280,7 +1280,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
Activations prefill_activations(weights.weights_config);
if (use_prefill_activations) {
prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
activations.env.Pools());
activations.env->Pools());
}
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights,
@@ -1325,7 +1325,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
// Compute logits from last layer activations.
MatMul(ConstMatFromBatch(num_queries, activations.x),
ConstMatFromWeights(weights.embedder_input_embedding),
/*add=*/nullptr, activations.env,
/*add=*/nullptr, *activations.env,
RowPtrFromBatch(activations.logits));
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");