From 4493531d72045b6212db5b075ec37cc66438f153 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 20 Dec 2024 06:33:15 -0800 Subject: [PATCH] Infra improvements: allocator: support mmap, fixed Bind, add padding bench_matmul: Add PreventElision BUILD: add ops_test build target matmul.h: move ConstMat here; dynamic alloc of MatMulEnv matmul_test: remove benchmarking replace fprintf with HWY_WARN threading.cc: support splitting large clusters (disabled); package_idx->pkg_idx, smaller IndexRangePartition PiperOrigin-RevId: 708306224 --- BUILD.bazel | 20 +-- compression/compress.h | 9 +- compression/shared.h | 8 +- gemma/activations.h | 6 +- gemma/gemma-inl.h | 36 +++--- ops/bench_matmul.cc | 80 +++++++----- ops/matmul.h | 66 ++++++++++ ops/matmul_test.cc | 58 +++------ util/allocator.cc | 260 ++++++++++++++++++++++++++++----------- util/allocator.h | 273 +++++++++++++++++++++++++++++------------ util/basics.h | 143 +-------------------- util/threading.cc | 161 +++++++++++++++--------- util/threading.h | 139 +++++++++------------ util/threading_test.cc | 99 ++------------- 14 files changed, 722 insertions(+), 636 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 1dc6a1f..260b90c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -76,6 +76,12 @@ cc_test( ], ) +# For building all tests in one command, so we can test several. +test_suite( + name = "ops_tests", + tags = ["ops_tests"], +) + cc_library( name = "ops", hdrs = [ @@ -110,7 +116,7 @@ cc_test( srcs = ["ops/dot_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. - tags = ["hwy_ops_test"], + tags = ["ops_tests"], deps = [ ":allocator", ":ops", @@ -135,7 +141,7 @@ cc_test( srcs = ["ops/ops_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. - tags = ["hwy_ops_test"], + tags = ["ops_tests"], deps = [ ":allocator", ":common", @@ -157,7 +163,7 @@ cc_test( srcs = ["ops/gemma_matvec_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. - tags = ["hwy_ops_test"], + tags = ["ops_tests"], deps = [ ":ops", "@googletest//:gtest_main", # buildcleaner: keep @@ -175,7 +181,7 @@ cc_test( srcs = ["ops/matmul_unit_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. - tags = ["hwy_ops_test"], + tags = ["ops_tests"], deps = [ ":allocator", ":basics", @@ -195,7 +201,7 @@ cc_test( srcs = ["ops/matmul_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. - tags = ["hwy_ops_test"], + tags = ["ops_tests"], deps = [ ":allocator", ":basics", @@ -205,7 +211,6 @@ cc_test( "//compression:compress", "@highway//:hwy", "@highway//:hwy_test_util", - "@highway//:nanobenchmark", "@highway//:thread_pool", ], ) @@ -217,7 +222,7 @@ cc_test( srcs = ["ops/bench_matmul.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. - tags = ["hwy_ops_test"], + tags = ["ops_tests"], deps = [ ":allocator", ":basics", @@ -228,6 +233,7 @@ cc_test( "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", + "@highway//:profiler", "@highway//:thread_pool", ], ) diff --git a/compression/compress.h b/compression/compress.h index ddfd16c..233f8d7 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -309,13 +309,6 @@ decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { } } -template -ConstMat ConstMatFromWeights(const MatPtrT& m, size_t ofs = 0) { - ConstMat mat = MakeConstMat(const_cast(m.data()), m.Extents(), ofs); - mat.scale = m.scale(); - return mat; -} - // MatStorageT adds the actual data storage to MatPtrT. // TODO: use Extents2D instead of rows and cols. 
template @@ -361,7 +354,7 @@ class MatStorageT : public MatPtrT { } private: - hwy::AlignedFreeUniquePtr data_; + AlignedPtr data_; }; // MatStorage allows heterogeneous tensors to be stored in a single vector. diff --git a/compression/shared.h b/compression/shared.h index c9ce1c0..f9814a1 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -273,11 +273,11 @@ struct PackedSpan { // Ensures callers can read or write `num_accessible` elements starting at // `packed_ofs`. void BoundsCheck(size_t packed_ofs, size_t num_accessible) const { - // For NUQ, there can be fewer Packed than the number of elements, hence - // check the compressed count and ensure we have that many. - const size_t required = - CompressedArrayElements(packed_ofs + num_accessible); if constexpr (HWY_IS_DEBUG_BUILD) { + // For NUQ, there can be fewer Packed than the number of elements, hence + // check the compressed count and ensure we have that many. + const size_t required = + CompressedArrayElements(packed_ofs + num_accessible); if (num < required) { HWY_ABORT("PackedSpan: ofs %zu, want %zu, req %zu > %zu packed", packed_ofs, num_accessible, required, num); diff --git a/gemma/activations.h b/gemma/activations.h index 3863325..7a53aa1 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -19,6 +19,7 @@ #include #include +#include // std::unique_ptr #include "compression/shared.h" // BF16 #include "gemma/configs.h" @@ -63,7 +64,8 @@ struct Activations { // Rope RowVectorBatch inv_timescale; - MatMulEnv env; + // Dynamic because no default ctor and only initialized in `Allocate`. + std::unique_ptr env; PostQKType post_qk = PostQKType::Rope; // And the config. @@ -122,7 +124,7 @@ struct Activations { inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk); - env = MatMulEnv(pools); + env = std::make_unique(pools); } }; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 81a9469..feec7a2 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -81,7 +81,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Griffin"); KVCache& kv_cache = kv_caches[0]; - hwy::ThreadPool& pool = activations.env.Pool(); + hwy::ThreadPool& pool = activations.env->Pool(); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const size_t model_dim = layer_weights->layer_config.model_dim; @@ -252,7 +252,7 @@ class GemmaAttention { const size_t w1_rows = heads * layer_config_.QStride(); w_q1.ShrinkRows(w1_rows); MatMul(pre_att_rms_out, w_q1, - /*add=*/nullptr, activations_.env, RowPtrFromBatch(activations_.q)); + /*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q)); if (is_mha_) { // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. @@ -275,7 +275,7 @@ class GemmaAttention { RowPtrF kv_rows(kv, w_rows_kv_cols); kv_rows.SetStride(cache_pos_size_); MatMul(pre_att_rms_out, w_q2, - /*add=*/nullptr, activations_.env, kv_rows); + /*add=*/nullptr, *activations_.env, kv_rows); } else { // Proceed row by row because there will be wraparound. 
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; @@ -464,7 +464,7 @@ class GemmaAttention { : nullptr; MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out), ConstMatFromWeights(layer_weights_.att_weights), add, - activations_.env, RowPtrFromBatch(activations_.att_sums)); + *activations_.env, RowPtrFromBatch(activations_.att_sums)); } public: @@ -514,7 +514,7 @@ class GemmaAttention { layer_weights_(*layer_weights), div_seq_len_(div_seq_len), kv_caches_(kv_caches), - pool_(activations.env.Pool()) { + pool_(activations.env->Pool()) { HWY_DASSERT(num_queries_ <= kv_caches_.size()); HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0, "query heads must be a multiple of key-value heads"); @@ -587,7 +587,7 @@ class VitAttention { HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out), ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w), - layer_weights_.vit.qkv_einsum_b.data_scale1(), activations_.env, + layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env, RowPtrFromBatch(qkv)); } @@ -641,7 +641,7 @@ class VitAttention { auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out); auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w); auto att_sums = RowPtrFromBatch(activations_.att_sums); - MatMul(att_out, att_weights, bias, activations_.env, att_sums); + MatMul(att_out, att_weights, bias, *activations_.env, att_sums); } public: @@ -652,7 +652,7 @@ class VitAttention { activations_(activations), layer_weights_(*layer_weights), layer_config_(layer_weights->layer_config), - pool_(activations.env.Pool()) {} + pool_(activations.env->Pool()) {} HWY_INLINE void operator()() { ComputeQKV(); @@ -728,8 +728,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, auto w_output = ConstMatFromWeights(layer_weights->linear_w); // Compute the hidden layer activations. - MatMul(x, w1, bias1, activations.env, hidden_activations); - MatMul(x, w2, bias2, activations.env, multiplier); + MatMul(x, w1, bias1, *activations.env, hidden_activations); + MatMul(x, w2, bias2, *activations.env, multiplier); // Activation (Gelu) and maybe multiply by gate. Store activations in act. Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), @@ -739,7 +739,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, auto activations_mat = MakeConstMat( hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim)); - MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out); + MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); } // Same as FFWNoVit, but with different layer_weights members and no second @@ -769,7 +769,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w); // Compute the hidden layer activations. - MatMul(x, w1, bias1, activations.env, hidden_activations); + MatMul(x, w1, bias1, *activations.env, hidden_activations); // Activation (Gelu), store in act. 
RowPtrF multiplier = RowPtrF(nullptr, 0); @@ -780,7 +780,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, auto activations_mat = MakeConstMat( hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim)); - MatMul(activations_mat, w_output, output_bias, activations.env, ffw_out); + MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); } // `batch_idx` indicates which row of `x` to write to. @@ -1063,7 +1063,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // MatMul( // MatFromBatch(kVitSeqLen, image_patches), // MatFromWeights(weights.vit_img_embedding_kernel), - // weights.vit_img_embedding_bias.data_scale1(), activations.env, + // weights.vit_img_embedding_bias.data_scale1(), *activations.env, // RowPtrF(activations.x.All(), kVitModelDim)); // However, MatMul currently requires that // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 @@ -1073,7 +1073,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(), - activations.x.Batch(i), activations.env.Pool()); + activations.x.Batch(i), activations.env->Pool()); } // Add position embeddings. AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(), @@ -1107,7 +1107,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, // Apply head embedding into image_tokens of size of the LLM kModelDim. MatMul(ConstMatFromBatch(num_tokens, activations.x), ConstMatFromWeights(weights.vit_img_head_kernel), - weights.vit_img_head_bias.data_scale1(), activations.env, + weights.vit_img_head_bias.data_scale1(), *activations.env, RowPtrFromBatch(image_tokens)); } @@ -1280,7 +1280,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, Activations prefill_activations(weights.weights_config); if (use_prefill_activations) { prefill_activations.Allocate(runtime_config.prefill_tbatch_size, - activations.env.Pools()); + activations.env->Pools()); } Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, query_idx_start, weights, @@ -1325,7 +1325,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, // Compute logits from last layer activations. MatMul(ConstMatFromBatch(num_queries, activations.x), ConstMatFromWeights(weights.embedder_input_embedding), - /*add=*/nullptr, activations.env, + /*add=*/nullptr, *activations.env, RowPtrFromBatch(activations.logits)); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index be58d29..0fb3bcd 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -16,7 +16,8 @@ // Benchmark of large MatMul instances for which the MatMulSlow would be too // slow. This lacks a reference and is only useful for performance measurement. -#include "hwy/detect_compiler_arch.h" +#include "hwy/base.h" +#include "hwy/nanobenchmark.h" #ifndef HWY_DISABLED_TARGETS // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require // double-precision support. 
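The activations.h hunk above switches the `MatMulEnv env` member to `std::unique_ptr<MatMulEnv>` because the type has no default constructor and is only created inside `Allocate`; the gemma-inl.h call sites therefore dereference it as `*activations.env`. A minimal sketch of this deferred-construction pattern, using stand-in types (`Env`, `Acts`) that are not part of gemma.cpp:

// Sketch: a member without a default constructor is held via std::unique_ptr
// so the enclosing struct stays default-constructible; the member is created
// later, once its constructor arguments are available.
#include <memory>

struct Env {               // stand-in for MatMulEnv: no default ctor.
  explicit Env(int pools) : pools(pools) {}
  int pools;
};

struct Acts {              // stand-in for Activations.
  std::unique_ptr<Env> env;            // empty until Allocate() runs.
  void Allocate(int pools) { env = std::make_unique<Env>(pools); }
};

int main() {
  Acts acts;               // default-constructible despite Env's ctor.
  acts.Allocate(4);
  return acts.env->pools;  // callers now dereference, hence `*acts.env`.
}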
@@ -30,7 +31,10 @@ #include #include +#include #include +#include +#include #include "compression/compress.h" #include "compression/shared.h" @@ -38,7 +42,6 @@ #include "util/allocator.h" #include "util/basics.h" #include "util/threading.h" -#include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" @@ -51,6 +54,7 @@ // After highway.h #include "compression/compress-inl.h" #include "ops/matmul-inl.h" +#include "hwy/profiler.h" // also uses SIMD #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); @@ -74,7 +78,8 @@ MatStoragePtr GenerateMat(const Extents2D extents, std::make_unique>("mat", extents.rows, extents.cols); FloatPtr content = hwy::AllocateAligned(mat->NumElements()); HWY_ASSERT(content); - const float scale = SfpStream::kMax / (mat->NumElements()); + const float scale = + SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1); pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { for (size_t c = 0; c < extents.cols; c++) { float f = static_cast(r * extents.cols + c) * scale; @@ -96,7 +101,8 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, auto mat = std::make_unique>("trans", extents.rows, extents.cols); FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - const float scale = SfpStream::kMax / (mat->NumElements()); + const float scale = + SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1); pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { for (size_t c = 0; c < extents.cols; c++) { float f = static_cast(c * extents.rows + r) * scale; @@ -111,52 +117,63 @@ MatStoragePtr GenerateTransposedMat(const Extents2D extents, return mat; } -void PrintSpeed(const char* algo, const Extents2D& A_extents, - const Extents2D& B_extents, double elapsed) { +void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, + std::vector& times) { + std::sort(times.begin(), times.end()); + // Many measurements are with suboptimal configs, so report the best like + // bench_dnn, but also the ratio to the 3rd best. + const double elapsed = times[0]; + const double ratio = times[2] / HWY_MAX(elapsed, 1E-6); + const size_t num_b = B_extents.Area(); // 2x because of FMA. - fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo, - elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed); + fprintf(stderr, "%.1f\t%.2f\n", 2 * 1E-9 * A_extents.rows * num_b / elapsed, + ratio); } // Generates inputs and prints observed throughput of MatMul. +// M = A rows, K = A cols, N = C cols. 
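The BenchMatMul changes below replace the fixed three-repetition minimum with 20 timed runs, report the best time, and keep the compiler from optimizing the kernel away via `hwy::Unpredictable1` and `hwy::PreventElision`. A standalone sketch of that measurement pattern; `Work` is a placeholder kernel, not a gemma.cpp function:

// Run the kernel ~20 times, sort the elapsed times, report the best, and use
// hwy::PreventElision plus hwy::Unpredictable1 so the compiler cannot remove
// or constant-fold the work being measured.
#include <algorithm>
#include <cstdio>
#include <vector>

#include "hwy/nanobenchmark.h"  // hwy::Unpredictable1, hwy::PreventElision
#include "hwy/timer.h"          // hwy::platform::Now

double Work(double x) { return x * 1.000001 + 0.5; }  // placeholder kernel

int main() {
  std::vector<double> times;
  times.reserve(20);
  double result = 0.0;
  for (;;) {
    const double t0 = hwy::platform::Now();
    result += Work(static_cast<double>(hwy::Unpredictable1()));
    times.push_back(hwy::platform::Now() - t0);
    if (times.size() >= 20) break;
  }
  hwy::PreventElision(result);  // keep `result` (and thus Work) alive
  std::sort(times.begin(), times.end());
  // Best time, plus ratio of 3rd best to best as a stability indicator.
  printf("best %g s, ratio %g\n", times[0], times[2] / times[0]);
  return 0;
}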
template -void BenchMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, - MatMulEnv& env) { +void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { hwy::ThreadPool& pool = env.Pool(); - fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", - rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), - TypeName()); + fprintf(stderr, "BenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", M, + K, N, add, TypeName(), TypeName()); - const Extents2D A_extents(rows_ac, cols_a_rows_b); - const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed - const Extents2D C_extents(rows_ac, cols_bc); + const Extents2D A_extents(M, K); + const Extents2D B_extents(N, K); // already transposed + const Extents2D C_extents(M, N); - MatStoragePtr a = GenerateMat(A_extents, pool); - MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); RowVectorBatch c_slow_batch(C_extents); RowVectorBatch c_batch(C_extents); - HWY_ASSERT(a && b_trans); std::unique_ptr> add_storage; if (add) { - add_storage = GenerateMat(Extents2D(1, cols_bc), pool); + add_storage = GenerateMat(Extents2D(1, N), pool); HWY_ASSERT(add_storage); add_storage->set_scale(1.0f); } + MatStoragePtr a = GenerateMat(A_extents, pool); + MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); + HWY_ASSERT(a && b_trans); const auto A = ConstMatFromWeights(*a); const auto B = ConstMatFromWeights(*b_trans); + const float* add_row = add ? add_storage->data_scale1() : nullptr; const RowPtrF C = RowPtrFromBatch(c_batch); - double min_elapsed = hwy::HighestValue(); - for (int rep = 0; rep < 3; ++rep) { - const double start_tiled = hwy::platform::Now(); + std::vector times; + times.reserve(20); + double result = 0.0; + for (;;) { + const double t0 = hwy::platform::Now(); MatMul(A, B, add_row, env, C); - min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled); + times.push_back(hwy::platform::Now() - t0); + result += C.Row(0)[hwy::Unpredictable1()]; + if (times.size() >= 20) break; } - PrintSpeed("MatMul", A_extents, B_extents, min_elapsed); + hwy::PreventElision(result); + PrintSpeed(A_extents, B_extents, times); } using F32 = float; @@ -184,16 +201,15 @@ void BenchAllMatMul() { Allocator::Init(pools.Topology()); MatMulEnv env(pools); - for (size_t batch_size : {1, /*4, 64,*/ 128}) { - BenchMatMul(batch_size, 24576, 3072, /*add=*/false, env); - BenchMatMul(batch_size, 3072, 24576, /*add=*/false, env); - BenchMatMul(batch_size, 24576, 3072, /*add=*/false, env); - BenchMatMul(batch_size, 3072, 24576, /*add=*/false, env); - BenchMatMul(batch_size, 24576, 3072, /*add=*/false, env); - BenchMatMul(batch_size, 3072, 24576, /*add=*/false, env); + for (size_t batch_size : {1, /* 4, 128,*/ 512}) { + constexpr bool kAdd = false; + BenchMatMul(batch_size, 24576, 3072, kAdd, env); + BenchMatMul(batch_size, 3072, 24576, kAdd, env); } pools.MaybeStopSpinning(use_spinning); } + + PROFILER_PRINT_RESULTS(); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/matmul.h b/ops/matmul.h index e77bd93..00392ce 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -19,6 +19,8 @@ #include // IWYU pragma: begin_exports +#include "compression/compress.h" +#include "util/allocator.h" #include "util/basics.h" #include "util/threading.h" #include "hwy/base.h" @@ -81,6 +83,70 @@ class MatMulEnv { NestedPools* pools_; }; +// Used for the A and B arguments of `MatMul`, which are always const. +// Create via MakeConstMat. 
This differs from `RowPtr` in that it supports the +// `ofs` required for compressed T. +template +struct ConstMat { + ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0) + : ptr(ptr), extents(extents), ofs(ofs) { + HWY_DASSERT(ptr != nullptr); + } + // TODO: support stride for page alignment. + size_t Row(size_t r) const { + if constexpr (HWY_IS_DEBUG_BUILD) { + if (r >= extents.rows) { + HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows); + } + } + return ofs + extents.cols * r; + } + + const Extents2D& Extents() const { return extents; } + size_t Stride() const { return extents.cols; } + + // Shrinks the row-extent of this matrix view, i.e. reduces the view to a + // subrange of the original rows starting at row 0. + void ShrinkRows(size_t rows) { + HWY_ASSERT(rows <= extents.rows); + extents.rows = rows; + } + + const T* HWY_RESTRICT ptr; + Extents2D extents; + + // `scale` allows expanding the smaller range of `SfpStream` to the original + // values. MatFromWeights sets this from `MatPtr`. + float scale = 1.0f; + + // Offset to add to `ptr`; separate because T=NuqStream does not support + // pointer arithmetic. + size_t ofs; +}; + +// For deducing T. +template +ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, + size_t ofs = 0) { + return ConstMat(ptr, extents, ofs); +} + +// For A argument to MatMul (activations). +template +ConstMat ConstMatFromBatch(size_t batch_size, + const RowVectorBatch& row_vectors) { + HWY_DASSERT(batch_size <= row_vectors.BatchSize()); + return MakeConstMat(const_cast(row_vectors.Const()), + Extents2D(batch_size, row_vectors.Cols())); +} + +template +ConstMat ConstMatFromWeights(const MatPtrT& m, size_t ofs = 0) { + ConstMat mat = MakeConstMat(const_cast(m.data()), m.Extents(), ofs); + mat.scale = m.scale(); + return mat; +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index ac79e4b..7239912 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -39,7 +39,6 @@ #include "util/threading.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/timer.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -55,7 +54,7 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { -// For running TestBatchSizes only once. Defined within HWY_ONCE. +// For running TestTiny only once. Defined within HWY_ONCE. extern int64_t first_target; namespace HWY_NAMESPACE { @@ -144,10 +143,10 @@ void AssertClose(const ConstMat& A, const ConstMat& B, const hn::ScalableTag df; const size_t num_a = A.extents.Area(); const size_t num_b = B.extents.Area(); - HWY_ASSERT(num_a % hn::Lanes(df) == 0); // for DecompressAndZeroPad - HWY_ASSERT(num_b % hn::Lanes(df) == 0); // for DecompressAndZeroPad - FloatPtr a = hwy::AllocateAligned(num_a); - FloatPtr b_trans = hwy::AllocateAligned(num_b); + const size_t N = hn::Lanes(df); + // Round up for DecompressAndZeroPad. + FloatPtr a = hwy::AllocateAligned(hwy::RoundUpTo(num_a, N)); + FloatPtr b_trans = hwy::AllocateAligned(hwy::RoundUpTo(num_b, N)); HWY_ASSERT(a && b_trans); HWY_ASSERT(A.ofs == 0 && B.ofs == 0); DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a); @@ -164,13 +163,11 @@ void AssertClose(const ConstMat& A, const ConstMat& B, double tolerance = 8 * norm * eps_f32; // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the // tolerance there. 
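`ConstMat` (moved here from basics.h) identifies a row by element offset rather than by pointer, so packed types such as `NuqStream`, which do not support pointer arithmetic, can be addressed uniformly. A small usage sketch, assuming the gemma.cpp headers are on the include path; `SumRow` is a hypothetical helper:

// Row() returns an element offset into mat.ptr, not a pointer.
#include <cstddef>
#include <vector>

#include "ops/matmul.h"

float SumRow(const std::vector<float>& storage, size_t rows, size_t cols,
             size_t r) {
  const gcpp::ConstMat<float> mat = gcpp::MakeConstMat(
      const_cast<float*>(storage.data()), gcpp::Extents2D(rows, cols));
  float sum = 0.0f;
  const size_t row_ofs = mat.Row(r);  // element offset, not a pointer
  for (size_t c = 0; c < cols; ++c) sum += mat.ptr[row_ofs + c] * mat.scale;
  return sum;
}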
- if (IsF32() && IsF32()) { + if (IsF32() && !IsF32()) { tolerance += 4 * max_abs * eps_bf16; } - EXPECT_GE(tolerance, 1E-4); - if (tolerance > 4.0) { - fprintf(stderr, "WARN: high tolerance %f norm %f maxabs %f\n", tolerance, - norm, max_abs); + if (tolerance > 8.0) { + HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); } for (size_t r = 0; r < A.extents.rows; r++) { @@ -182,11 +179,10 @@ void AssertClose(const ConstMat& A, const ConstMat& B, if (!(expected_value - tolerance <= actual_value && actual_value <= expected_value + tolerance)) { - fprintf(stderr, - "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " - "tolerance %f\n", - r, c, expected_value, actual_value, norm, max_abs, tolerance); - return; + HWY_ABORT( + "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " + "tolerance %f\n", + r, c, expected_value, actual_value, norm, max_abs, tolerance); } } } @@ -217,7 +213,7 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, get_row_c, all_packages, [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR { hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx); - const size_t multiple = Allocator::Alignment() / sizeof(MatTB); + const size_t multiple = Allocator::QuantumBytes() / sizeof(MatTB); const IndexRangePartition get_col_c = StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); ParallelizeOneRange( @@ -248,7 +244,6 @@ template void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulEnv& env) { hwy::ThreadPool& pool = env.Pool(); - const bool want_bench = cols_bc > 2000; // avoid spam for small matrices fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), TypeName()); @@ -276,32 +271,17 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch); const RowPtrF C = RowPtrFromBatch(c_batch); - const double start_slow = hwy::platform::Now(); MatMulSlow(A, B, add_row, env, C_slow); - if (want_bench) { - PrintSpeed("MatMulSlow", A_extents, B_extents, - hwy::platform::Now() - start_slow); - } - - double min_elapsed = hwy::HighestValue(); - for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) { - const double start_tiled = hwy::platform::Now(); - MatMul(A, B, add_row, env, C); - min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled); - } - if (want_bench) { - PrintSpeed("MatMul", A_extents, B_extents, min_elapsed); - } - + MatMul(A, B, add_row, env, C); AssertClose(A, B, C_slow, C); } using F32 = float; using SFP = SfpStream; -// Sweep batch_size for a single input type and Highway target, to verify the -// row partitioning. -void TestBatchSizes() { +// Sweep all dimensions for a single input type and Highway target, to verify +// the remainder handling. +void TestTiny() { if (first_target == 0) first_target = HWY_TARGET; if (HWY_TARGET != first_target) return; @@ -315,7 +295,7 @@ void TestBatchSizes() { // If less than the limit, we have already tested all num_packages. 
if (pools.Topology().FullTopology().packages.size() < max_packages) break; #endif - fprintf(stderr, "TestBatchSizes %zu: %s %s\n", max_packages, + fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages, pools.TopologyString(), pools.PinString()); Tristate use_spinning = Tristate::kDefault; @@ -405,7 +385,7 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { int64_t first_target = 0; // none run yet HWY_BEFORE_TEST(MatMulTest); -HWY_EXPORT_AND_TEST_P(MatMulTest, TestBatchSizes); +HWY_EXPORT_AND_TEST_P(MatMulTest, TestTiny); HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul); HWY_AFTER_TEST(); diff --git a/util/allocator.cc b/util/allocator.cc index 8f41f00..5274785 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -17,45 +17,154 @@ #include +#include +#include #include #include "util/basics.h" // MaybeCheckInitialized +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/per_target.h" // VectorBytes -#if GEMMA_NUMA -#if HWY_OS_WIN -#ifndef NOMINMAX -#define NOMINMAX +// To avoid a dependency on libnuma, use syscalls directly. We require six +// arguments, which has been supported by glibc since around 2010. +#if defined(__GLIBC__) && defined(__GLIBC_PREREQ) +#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11) +#define GEMMA_LINUX_SYSCALL6 #endif -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN #endif -#include -#elif HWY_OS_LINUX + +#ifndef GEMMA_BIND // allow override +#if defined(GEMMA_LINUX_SYSCALL6) +#define GEMMA_BIND 1 +#else +#define GEMMA_BIND 0 +#endif +#endif // GEMMA_BIND + +#if GEMMA_BIND && HWY_OS_LINUX +// `move_pages` requires anonymous/private mappings, hence mmap. +#include #include #include -#endif // HWY_OS_* -#endif // GEMMA_NUMA +#endif // GEMMA_BIND && HWY_OS_LINUX namespace gcpp { +namespace { -/*static*/ size_t Allocator::bytes_per_page_; -/*static*/ bool Allocator::use_numa_; -/*static*/ size_t Allocator::alignment_; - -/*static*/ size_t Allocator::DetectPageSize() { -#if HWY_OS_WIN - SYSTEM_INFO sys_info; - GetSystemInfo(&sys_info); - return sys_info.dwPageSize; -#elif HWY_OS_LINUX - return sysconf(_SC_PAGESIZE); +size_t DetectLineBytes() { + if (const hwy::Cache* caches = hwy::DataCaches()) { + // Might not have an L3. + return HWY_MAX(caches[2].bytes_per_line, caches[3].bytes_per_line); + } else { + return HWY_ALIGNMENT; + } +} + +size_t DetectPageSize() { +#if HWY_OS_LINUX + size_t page_bytes = static_cast(sysconf(_SC_PAGESIZE)); + HWY_ASSERT(page_bytes <= (4 << 20)); + return page_bytes; #else return 0; #endif } -#if GEMMA_NUMA && HWY_OS_LINUX +} // namespace + +static size_t line_bytes_; +static size_t vector_bytes_; +static size_t quantum_bytes_; +static size_t l1_bytes_; +static size_t l2_bytes_; +static bool should_bind_ = false; + +size_t Allocator::LineBytes() { return line_bytes_; } +size_t Allocator::VectorBytes() { return vector_bytes_; } +size_t Allocator::QuantumBytes() { return quantum_bytes_; } +size_t Allocator::L1Bytes() { return l1_bytes_; } +size_t Allocator::L2Bytes() { return l2_bytes_; } +bool Allocator::ShouldBind() { return should_bind_; } + +void Allocator::Init(const BoundedTopology& topology) { + line_bytes_ = DetectLineBytes(); + vector_bytes_ = hwy::VectorBytes(); + quantum_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); // may overwrite below + + if (const hwy::Cache* caches = hwy::DataCaches()) { + l1_bytes_ = caches[1].size_kib << 10; + l2_bytes_ = caches[2].size_kib << 10; + } else { // Unknown, make reasonable assumptions. 
+ const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0); + l1_bytes_ = 32 << 10; + l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10; + } + + // Prerequisites for binding: + // - supported by the OS (currently Linux only), + // - the page size is known and 'reasonably small', preferably less than + // a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB. + // - we successfully detected topology and there are multiple nodes; + // - there are multiple packages, because we shard by package_idx. + if constexpr (GEMMA_BIND) { + const size_t page_bytes = DetectPageSize(); + if ((page_bytes != 0 && page_bytes <= 16 * 1024) && + topology.NumNodes() > 1 && topology.NumPackages() > 1) { + // Ensure pages meet the alignment requirements of `AllocBytes`. + HWY_ASSERT(page_bytes >= quantum_bytes_); + quantum_bytes_ = page_bytes; + should_bind_ = true; + } + } +} + +Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) { + // If we are not binding, the Highway allocator is cheaper than `mmap`, and + // defends against 2K aliasing. However, we can only use it if the alignment + // >= `QuantumBytes()`. + if (!should_bind_ && HWY_ALIGNMENT >= QuantumBytes()) { + auto p = hwy::AllocateAligned(bytes); + // The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the + // alignment scheme in aligned_allocator.cc and does not work for + // already-aligned pointers as returned by `mmap`, hence we wrap the Highway + // pointer in our own deleter. + auto call_free = [](void* ptr, size_t /*bytes*/) { + hwy::FreeAlignedBytes(ptr, nullptr, nullptr); + }; + return PtrAndDeleter{p.release(), Deleter(call_free, bytes)}; + } + + // Binding, or large vector/cache line size: use platform-specific allocator. + +#if HWY_OS_LINUX + // `move_pages` is documented to require an anonymous/private mapping or + // `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`. + // `Init` verified that the page size is a multiple of `QuantumBytes()`. + const int prot = PROT_READ | PROT_WRITE; + const int flags = MAP_ANONYMOUS | MAP_PRIVATE; + const int fd = -1; + // Encourage transparent hugepages by rounding up to a multiple of 2 MiB. + bytes = hwy::RoundUpTo(bytes, 2ull << 20); + void* p = mmap(0, bytes, prot, flags, fd, off_t{0}); + if (p == MAP_FAILED) p = nullptr; + const auto call_munmap = [](void* ptr, size_t bytes) { + const int ret = munmap(ptr, bytes); + HWY_ASSERT(ret == 0); + }; + return PtrAndDeleter{p, Deleter(call_munmap, bytes)}; +#elif HWY_OS_WIN + const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); }; + const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_); + return PtrAndDeleter{_aligned_malloc(bytes, alignment), + Deleter(call_free, bytes)}; +#else + return PtrAndDeleter{nullptr, Deleter(nullptr, 0)}; +#endif +} + +#if GEMMA_BIND && HWY_OS_LINUX using Ret = long; // NOLINT(runtime/int) using UL = unsigned long; // NOLINT(runtime/int) @@ -76,90 +185,91 @@ struct SyscallWrappers { MaybeCheckInitialized(status, count * sizeof(int)); return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags); } + + static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr, + unsigned flags) { + return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags); + } }; +// Returns the number of pages that are currently busy (hence not yet moved), +// and warns if there are any other reasons for not moving a page. Note that +// `move_pages` can return 0 regardless of whether all pages were moved. 
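`AllocBytes` above pairs `mmap` with a deleter that remembers the mapping size, because `munmap` needs the length that `free`-style deleters do not carry. A reduced Linux-only sketch of that idea using a custom `std::unique_ptr` deleter; `Unmap` and `AllocFloats` are illustrative names, and the 2 MiB hugepage rounding and NUMA logic are omitted:

// The deleter must remember the mapping size because munmap requires it.
#include <sys/mman.h>

#include <cstddef>
#include <cstdio>
#include <memory>

struct Unmap {
  size_t bytes = 0;
  void operator()(void* p) const {
    if (p != nullptr) munmap(p, bytes);
  }
};

std::unique_ptr<float[], Unmap> AllocFloats(size_t num) {
  const size_t bytes = num * sizeof(float);
  void* p = mmap(nullptr, bytes, PROT_READ | PROT_WRITE,
                 MAP_ANONYMOUS | MAP_PRIVATE, /*fd=*/-1, /*offset=*/0);
  if (p == MAP_FAILED) p = nullptr;
  return std::unique_ptr<float[], Unmap>(static_cast<float*>(p), Unmap{bytes});
}

int main() {
  auto floats = AllocFloats(1 << 20);  // page-aligned, zero-initialized
  if (!floats) return 1;
  floats[0] = 1.0f;
  printf("%f\n", floats[0]);
  return 0;
}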
size_t CountBusyPages(size_t num_pages, size_t node, void** pages, const int* status) { - // Return value 0 does not actually guarantee all pages were moved. size_t num_busy = 0; for (size_t i = 0; i < num_pages; ++i) { if (status[i] == -EBUSY) { ++num_busy; - // Touch - hwy::ZeroBytes(pages[i], 8); } else if (status[i] != static_cast(node)) { - fprintf(stderr, "Error %d moving pages[%zu]=%p to node %zu (errno %d)\n", - status[i], i, pages[i], node, errno); + static std::atomic_flag first = ATOMIC_FLAG_INIT; + if (!first.test_and_set()) { + HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).", + status[i], i, pages[i], node, errno); + } } } return num_busy; } -// Attempts to move(!) memory to the given NUMA node, typically obtained from -// `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. Using `mbind` -// directly is easier than calling libnuma's `numa_move_pages`, which requires -// an array of pages. Note that `numa_tonode_memory` is insufficient because -// it does not specify the `MPOL_MF_MOVE` flag, so it only sets the policy, -// which means it would have to be called before pages are faulted in, but -// `aligned_allocator.h` modifies the first bytes for its bookkeeping. -// May overwrite some of the memory with zeros. -void BindMemory(void* ptr, size_t bytes, size_t node) { +bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) { + HWY_DASSERT(should_bind_); constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough" + + if constexpr (HWY_IS_DEBUG_BUILD) { + // Ensure the requested `node` is allowed. + UL nodes[kMaxNodes / 64] = {0}; + const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED + HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes, + nullptr, flags) == 0); + HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64))); + } + // Avoid mbind because it does not report why it failed, which is most likely // because pages are busy, in which case we want to know which. -#if 0 - // nodemask with only the given node set. - UL nodes[hwy::DivCeil(kMaxNodes, ULBits)] = {}; - nodes[node / ULBits] = 1ULL << (node % ULBits); - - const int mode = 2; // MPOL_BIND - const unsigned flags = 3; // MPOL_MF_MOVE | MPOL_MF_STRICT - const int ret = - SyscallWrappers::mbind(ptr, bytes, mode, nodes, kMaxNodes, flags); - if (ret != 0) { - fprintf(stderr, "Failed to bind %p %zu to node %zu (errno %d)\n", ptr, - bytes, node, errno); - } -#elif 1 + + // `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set. const unsigned flags = 2; // MPOL_MF_MOVE - const size_t bytes_per_page = static_cast(sysconf(_SC_PAGESIZE)); - HWY_ASSERT(bytes % bytes_per_page == 0); - const size_t num_pages = bytes / bytes_per_page; + HWY_ASSERT(bytes % quantum_bytes_ == 0); + const size_t num_pages = bytes / quantum_bytes_; std::vector pages; pages.reserve(num_pages); for (size_t i = 0; i < num_pages; ++i) { - pages.push_back(static_cast(ptr) + i * bytes_per_page); + pages.push_back(static_cast(ptr) + i * quantum_bytes_); + // Ensure the page is faulted in to prevent `move_pages` from failing, + // because freshly allocated pages may be mapped to a shared 'zero page'. 
+ hwy::ZeroBytes(pages.back(), 8); } std::vector nodes(num_pages, node); std::vector status(num_pages, static_cast(kMaxNodes)); + Ret ret = SyscallWrappers::move_pages( /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); - size_t num_busy = - CountBusyPages(num_pages, node, pages.data(), status.data()); - if (num_busy != 0) { - // Try again - ret = SyscallWrappers::move_pages( - /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); - const size_t num_busy_before = num_busy; - num_busy = CountBusyPages(num_pages, node, pages.data(), status.data()); - fprintf( - stderr, - "second try still %zu busy, was %zu. 2nd ret %d status %d %d %d %d\n", - num_busy, num_busy_before, static_cast(ret), status[0], status[1], - status[2], status[3]); + if (ret < 0) { + HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr, + bytes, node, errno, status[0]); + return false; } - if (ret < 0) { - fprintf(stderr, - "Failed to bind %p %zu to node %zu (errno %d) status %d %d\n", ptr, - bytes, node, errno, status[0], status[1]); + const size_t num_busy = + CountBusyPages(num_pages, node, pages.data(), status.data()); + if (HWY_UNLIKELY(num_busy != 0)) { + // Trying again is usually enough to succeed. + usleep(5); // NOLINT(runtime/sleep) + (void)SyscallWrappers::move_pages( + /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); + const size_t still_busy = + CountBusyPages(num_pages, node, pages.data(), status.data()); + if (HWY_UNLIKELY(still_busy != 0)) { + HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.", + still_busy, num_busy); + } } -#endif + return true; } #else -// TODO: support other OSes. -void BindMemory(void*, size_t, size_t) {} -#endif // GEMMA_NUMA && HWY_OS_LINUX +void Allocator::BindMemory(void*, size_t, size_t) {} +#endif // GEMMA_BIND && HWY_OS_LINUX } // namespace gcpp diff --git a/util/allocator.h b/util/allocator.h index cf1e161..5c99476 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -19,114 +19,233 @@ #include #include -#include // std::aligned_alloc / _aligned_malloc - // IWYU pragma: begin_exports +#include + #include "util/basics.h" #include "util/threading.h" -#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: end_exports -#ifndef GEMMA_NUMA -// The check below requires two #if, hence start with 0 and redefine to 1. -#define GEMMA_NUMA 0 - -// To avoid a dependency on libnuma, use syscalls directly. We require six -// arguments, which has been supported by glibc since around 2010. -#if defined(__GLIBC__) && defined(__GLIBC_PREREQ) -#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11) -#undef GEMMA_NUMA -#define GEMMA_NUMA 1 -#endif -#endif - -#endif // GEMMA_NUMA +#include "hwy/aligned_allocator.h" namespace gcpp { -using ByteStorageT = hwy::AlignedFreeUniquePtr; +// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The +// `bytes` argument is required for the latter. +using FreeFunc = void (*)(void* mem, size_t bytes); -template -ByteStorageT AllocateSizeof() { - return hwy::AllocateAligned(sizeof(T)); -} - -// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for -// convenience - avoids passing around a reference. -class Allocator { +// Custom deleter for std::unique_ptr that calls `FreeFunc`. 
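`BindMemory` above migrates already-faulted pages with the `move_pages` syscall and inspects the per-page status array, retrying once for `-EBUSY`. A hypothetical, simplified sketch of the same call sequence (fixed 4 KiB page size, single attempt, no glibc-version guard):

// Fault each page in first, then ask the kernel to migrate it and check the
// per-page status. move_pages may return 0 even if some pages did not move.
#include <sys/syscall.h>
#include <unistd.h>

#include <cerrno>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

bool BindToNode(void* ptr, size_t bytes, int node) {
  const size_t page_bytes = 4096;  // assumption; query sysconf(_SC_PAGESIZE)
  const size_t num_pages = bytes / page_bytes;
  std::vector<void*> pages;
  std::vector<int> nodes(num_pages, node);
  std::vector<int> status(num_pages, -1);
  pages.reserve(num_pages);
  for (size_t i = 0; i < num_pages; ++i) {
    void* page = static_cast<char*>(ptr) + i * page_bytes;
    memset(page, 0, 1);  // fault in, so the page is not the shared zero page
    pages.push_back(page);
  }
  const long ret = syscall(__NR_move_pages, /*pid=*/0, num_pages, pages.data(),
                           nodes.data(), status.data(), /*MPOL_MF_MOVE=*/2);
  if (ret < 0) {
    fprintf(stderr, "move_pages failed: errno %d\n", errno);
    return false;
  }
  size_t moved = 0;
  for (size_t i = 0; i < num_pages; ++i) moved += (status[i] == node);
  return moved == num_pages;  // some pages may still be busy (-EBUSY)
}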
+class Deleter { public: - static void Init(const BoundedTopology& topology) { - bytes_per_page_ = DetectPageSize(); - HWY_ASSERT(bytes_per_page_ <= (4 << 20)); - - // NUMA only makes sense if: - // - the page size is known and 'reasonably small', preferably less than - // a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB. - // - we successfully detected topology and there are multiple nodes; - // - there are multiple packages, because we shard by package_idx. - use_numa_ = (bytes_per_page_ != 0 && bytes_per_page_ <= 16 * 1024) && - topology.NumNodes() > 1 && topology.NumPackages() > 1; - // TODO: remove once tensors are page-aligned. - use_numa_ = false; - fprintf(stderr, "Warning: disabling use_numa_\n"); - - alignment_ = use_numa_ ? bytes_per_page_ : HWY_ALIGNMENT; + // `MatStorageT` requires this to be default-constructible. + Deleter() : free_func_(nullptr), bytes_(0) {} + Deleter(FreeFunc free_func, size_t bytes) + : free_func_(free_func), bytes_(bytes) {} + + template + void operator()(T* p) const { + free_func_(p, bytes_); } - static bool UseNUMA() { return use_numa_; } + private: + FreeFunc free_func_; + size_t bytes_; +}; - // BindTensor requires row pointers and lengths be a multiple of this. - static size_t Alignment() { return alignment_; } +// Unique (move-only) pointer to an aligned array of POD T. +template +using AlignedPtr = std::unique_ptr; +// Both allocation, binding, and row accessors depend on the sizes of memory +// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we +// use `Monostate` (static members). +class Allocator { + public: + // Must be called at least once before any other function. Not thread-safe, + // hence only call this from the main thread. + static void Init(const BoundedTopology& topology); + + // Bytes per cache line, or a reasonable guess if unknown. Used to choose + // ranges such that there will be no false sharing. + static size_t LineBytes(); + // Bytes per full vector. Used to compute loop steps. + static size_t VectorBytes(); + // Granularity of regions processed by different threads. Their start and + // length of regions should be divisible by this, which is at least + // `HWY_MAX(LineBytes(), VectorBytes())`. + static size_t QuantumBytes(); + static size_t L1Bytes(); + static size_t L2Bytes(); + + // Returns pointer aligned to `QuantumBytes()`. template - static hwy::AlignedFreeUniquePtr Alloc(size_t num) { - // For non-NUMA, use the Highway allocator because it defends against 2k - // aliasing. - if (!use_numa_) return hwy::AllocateAligned(num); - + static AlignedPtr Alloc(size_t num) { constexpr size_t kSize = sizeof(T); - // Ensure the `bytes = num * kSize` computation did not overflow. constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; constexpr size_t kBits = hwy::detail::ShiftCount(kSize); static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); const size_t bytes = kIsPow2 ? num << kBits : num * kSize; + // Fail if the `bytes = num * kSize` computation overflowed. const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize; - if (check != num) { - return hwy::AlignedFreeUniquePtr(); // overflowed + if (check != num) return AlignedPtr(); + + PtrAndDeleter pd = AllocBytes(bytes); + return AlignedPtr(static_cast(pd.p), pd.deleter); + } + + // Returns whether `BindMemory` can/should be called, i.e. we have page-level + // control over memory placement and multiple packages and NUMA nodes. + static bool ShouldBind(); + + // Attempts to move(!) 
`[p, p + bytes)` to the given NUMA node, which is + // typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. + // Writes zeros to SOME of the memory. Only call if `ShouldBind()`. + // `p` and `bytes` must be multiples of `QuantumBytes()`. + static bool BindMemory(void* p, size_t bytes, size_t node); + + private: + // Type-erased so this can be implemented in allocator.cc. + struct PtrAndDeleter { + void* p; + Deleter deleter; + }; + static PtrAndDeleter AllocBytes(size_t bytes); +}; + +// Owns dynamically-allocated aligned memory for a batch of row vectors. +// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns +// the memory. +template +class RowVectorBatch { + public: + // Default ctor for Activations ctor. + RowVectorBatch() = default; + // Main ctor, called from Activations::Allocate. If `stride` = 0, the default, + // we default to tightly packed rows (`stride = cols`). + // WARNING: not all call sites support `stride` != cols. + // TODO: once they do, remove stride and behave like AllocateAlignedRows here. + RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) { + if (stride == 0) { + stride_ = extents_.cols; + } else { + HWY_ASSERT(stride >= extents_.cols); + stride_ = stride; } + mem_ = Allocator::Alloc(extents_.rows * stride_); + } - // AlignedFreeUniquePtr has a deleter that can call an arbitrary `free`, but - // with an extra opaque pointer, which we discard via `call_free`. -#if defined(__ANDROID_API__) && __ANDROID_API__ < 28 - const auto call_free = [](void* ptr, void*) { std::free(ptr); }; - void* mem = nullptr; - int err = posix_memalign(&mem, Alignment(), bytes); - HWY_ASSERT(err == 0); - T* p = static_cast(mem); -#elif HWY_OS_WIN - const auto call_free = [](void* ptr, void*) { _aligned_free(ptr); }; - T* p = static_cast(_aligned_malloc(bytes, Alignment())); -#else - const auto call_free = [](void* ptr, void*) { std::free(ptr); }; - T* p = static_cast(std::aligned_alloc(Alignment(), bytes)); -#endif - return hwy::AlignedFreeUniquePtr( - p, hwy::AlignedFreer(call_free, nullptr)); + // Move-only + RowVectorBatch(RowVectorBatch&) noexcept = delete; + RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; + RowVectorBatch(RowVectorBatch&&) noexcept = default; + RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; + + size_t BatchSize() const { return extents_.rows; } + size_t Cols() const { return extents_.cols; } + size_t Stride() const { return stride_; } + Extents2D Extents() const { return extents_; } + + // Returns the given row vector of length `Cols()`. + T* Batch(size_t batch_idx) { + HWY_DASSERT(batch_idx < BatchSize()); + return mem_.get() + batch_idx * stride_; + } + const T* Batch(size_t batch_idx) const { + HWY_DASSERT(batch_idx < BatchSize()); + return mem_.get() + batch_idx * stride_; } + // For MatMul or other operations that process the entire batch at once. + // TODO: remove once we only use Mat. + T* All() { return mem_.get(); } + const T* Const() const { return mem_.get(); } + size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); } + private: - static size_t DetectPageSize(); + AlignedPtr mem_; + Extents2D extents_; + size_t stride_; +}; + +// Returns `num` rounded up to an odd number of cache lines. This is used to +// compute strides. An odd number of cache lines prevents 2K aliasing and is +// coprime with the cache associativity, which reduces conflict misses. 
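The padding helper introduced next rounds a row up to an odd number of cache lines; a quick worked example of the arithmetic, with a 64-byte line size assumed for illustration:

// With 64-byte cache lines, a row of 512 floats (2048 bytes = 32 lines) is
// padded to 33 lines = 528 floats, so consecutive rows start at different
// cache sets and do not alias 2 KiB apart.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t line_bytes = 64;  // assumption for the example
  const size_t num = 512;        // floats per row
  const size_t lines = (num * sizeof(float) + line_bytes - 1) / line_bytes;
  const size_t padded = (lines | 1) * line_bytes / sizeof(float);
  printf("lines=%zu padded=%zu\n", lines, padded);  // lines=32 padded=528
  return 0;
}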
+template +static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) { + HWY_DASSERT(line_bytes >= 32); + HWY_DASSERT(line_bytes % sizeof(T) == 0); + const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes); + const size_t padded_num = (lines | 1) * line_bytes / sizeof(T); + HWY_DASSERT(padded_num >= num); + return padded_num; +} + +// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because +// it is always float and does not support compressed T, but does support an +// arbitrary stride >= cols. +#pragma pack(push, 1) // power of two size +template +class RowPtr { + public: + RowPtr() = default; // for `MMPtrs`. + RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) + : row0_(row0), + stride_(stride), + step_(static_cast( + HWY_MAX(Allocator::LineBytes(), Allocator::VectorBytes()))), + cols_(static_cast(cols)), + row_mask_(Allocator::QuantumBytes() / step_ - 1) { + HWY_DASSERT(stride >= cols); + HWY_DASSERT(row_mask_ != ~size_t{0}); + row_mask_ = 0; // TODO: remove + } + RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {} + + T* HWY_RESTRICT Row(size_t r) const { + // How much of the previous row's padding to consume. + const size_t pad_bytes = (r & row_mask_) * step_; + HWY_DASSERT(pad_bytes < Allocator::QuantumBytes()); + return row0_ + stride_ * r - pad_bytes; + } + size_t Cols() const { return cols_; } + + size_t Stride() const { return stride_; } + void SetStride(size_t stride) { + HWY_DASSERT(stride >= Cols()); + stride_ = stride; + // The caller might not have padded enough, so disable the padding in Row(). + // Rows will now be exactly `stride` elements apart. This is used when + // writing to the KV cache via MatMul. + row_mask_ = 0; + } - // Required for BindMemory. Usually 4K, but can differ on Arm. - static size_t bytes_per_page_; - static bool use_numa_; - static size_t alignment_; + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. + RowPtr View(size_t r, size_t c, size_t cols) const { + HWY_DASSERT(c < cols_); + HWY_DASSERT(cols <= cols_ - c); + return RowPtr(Row(r) + c, cols, stride_); + } + + private: + T* HWY_RESTRICT row0_; + size_t stride_; + uint32_t step_; // Copy from Allocator::LineBytes() to improve locality. + uint32_t cols_; + size_t row_mask_; }; +#pragma pack(pop) -// For future NUMA support. TODO: use. -void BindMemory(void* ptr, size_t bytes, size_t node); +using RowPtrBF = RowPtr; +using RowPtrF = RowPtr; +using RowPtrD = RowPtr; + +// For C argument to MatMul. +template +RowPtr RowPtrFromBatch(RowVectorBatch& row_vectors) { + return RowPtr(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride()); +} } // namespace gcpp diff --git a/util/basics.h b/util/basics.h index 06f8d5d..b25a9ef 100644 --- a/util/basics.h +++ b/util/basics.h @@ -64,8 +64,8 @@ struct TokenAndProb { // Entire size of a 2D array. 
struct Extents2D { - Extents2D() : rows(0), cols(0) {} - Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) { + constexpr Extents2D() : rows(0), cols(0) {} + constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) { HWY_DASSERT(rows != 0); HWY_DASSERT(cols != 0); } @@ -77,6 +77,7 @@ struct Extents2D { }; struct IndexRange { + IndexRange() = default; IndexRange(size_t begin, size_t end) : begin_(begin), end_(end) { HWY_DASSERT(begin < end); } @@ -113,144 +114,6 @@ static inline IndexRange MakeIndexRange(size_t begin, size_t end, size_t max_size) { return IndexRange(begin, HWY_MIN(begin + max_size, end)); } - -// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because -// it is always float and does not support compressed T, but does support an -// arbitrary stride >= cols. -template -class RowPtr { - public: - RowPtr(T* HWY_RESTRICT row0, size_t cols) - : row0_(row0), cols_(cols), stride_(cols) {} - RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) - : row0_(row0), cols_(cols), stride_(stride) { - HWY_DASSERT(stride >= cols); - } - - T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } - size_t Cols() const { return cols_; } - - size_t Stride() const { return stride_; } - void SetStride(size_t stride) { - HWY_DASSERT(stride >= Cols()); - stride_ = stride; - } - - private: - T* HWY_RESTRICT row0_; - size_t stride_; - size_t cols_; -}; - -using RowPtrF = RowPtr; - -// Owns dynamically-allocated aligned memory for a batch of row vectors. -// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns -// the memory. -template -class RowVectorBatch { - public: - // Default ctor for Activations ctor. - RowVectorBatch() = default; - // Main ctor, called from Activations::Allocate. - RowVectorBatch(Extents2D extents) : extents_(extents) { - mem_ = hwy::AllocateAligned(extents_.rows * extents_.cols); - } - - // Move-only - RowVectorBatch(RowVectorBatch&) noexcept = delete; - RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; - RowVectorBatch(RowVectorBatch&&) noexcept = default; - RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; - - size_t BatchSize() const { return extents_.rows; } - size_t Cols() const { return extents_.cols; } - Extents2D Extents() const { return extents_; } - - // Returns the given row vector of length `Cols()`. - T* Batch(size_t batch_idx) { - HWY_DASSERT(batch_idx < BatchSize()); - return mem_.get() + batch_idx * Cols(); - } - const T* Batch(size_t batch_idx) const { - HWY_DASSERT(batch_idx < BatchSize()); - return mem_.get() + batch_idx * Cols(); - } - - // For MatMul or other operations that process the entire batch at once. - // TODO: remove once we only use Mat. - T* All() { return mem_.get(); } - const T* Const() const { return mem_.get(); } - size_t NumBytes() const { return BatchSize() * Cols() * sizeof(T); } - - private: - hwy::AlignedFreeUniquePtr mem_; - Extents2D extents_; -}; - -// Used for the A and B arguments of `MatMul`, which are always const. -// Create via MakeConstMat. This differs from `RowPtr` in that it supports the -// `ofs` required for compressed T. -template -struct ConstMat { - ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0) - : ptr(ptr), extents(extents), ofs(ofs) { - HWY_DASSERT(ptr != nullptr); - } - // TODO: support stride for page alignment. 
- size_t Row(size_t r) const { - if constexpr (HWY_IS_DEBUG_BUILD) { - if (r >= extents.rows) { - HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows); - } - } - return ofs + extents.cols * r; - } - - const Extents2D& Extents() const { return extents; } - size_t Stride() const { return extents.cols; } - - // Shrinks the row-extent of this matrix view, i.e. reduces the view to a - // subrange of the original rows starting at row 0. - void ShrinkRows(size_t rows) { - HWY_ASSERT(rows <= extents.rows); - extents.rows = rows; - } - - const T* HWY_RESTRICT ptr; - Extents2D extents; - - // `scale` allows expanding the smaller range of `SfpStream` to the original - // values. MatFromWeights sets this from `MatPtr`. - float scale = 1.0f; - - // Offset to add to `ptr`; separate because T=NuqStream does not support - // pointer arithmetic. - size_t ofs; -}; - -// For deducing T. -template -ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, - size_t ofs = 0) { - return ConstMat(ptr, extents, ofs); -} - -// For A argument to MatMul (activations). -template -ConstMat ConstMatFromBatch(size_t batch_size, - const RowVectorBatch& row_vectors) { - HWY_DASSERT(batch_size <= row_vectors.BatchSize()); - return MakeConstMat(const_cast(row_vectors.Const()), - Extents2D(batch_size, row_vectors.Cols())); -} - -// For C argument to MatMul. -template -RowPtr RowPtrFromBatch(RowVectorBatch& row_vectors) { - return RowPtr(row_vectors.All(), row_vectors.Cols()); -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_ diff --git a/util/threading.cc b/util/threading.cc index 3c0ff0d..a10862e 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -55,10 +55,8 @@ class Pinning { LPS enabled_lps; if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) { const size_t num_lps = hwy::TotalLogicalProcessors(); - fprintf( - stderr, - "Warning, unknown OS affinity, considering all %zu LPs enabled\n.", - num_lps); + HWY_WARN("unknown OS affinity, considering all %zu LPs enabled.", + num_lps); for (size_t lp = 0; lp < num_lps; ++lp) { enabled_lps.Set(lp); } @@ -71,8 +69,7 @@ class Pinning { const size_t lp = enabled_lps.First(); enabled_lps = LPS(); enabled_lps.Set(lp); - fprintf(stderr, - "Warning, threads not supported, using only the main thread\n."); + HWY_WARN("Warning, threads not supported, using only the main thread."); } original_affinity_ = enabled_lps; @@ -155,23 +152,10 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice, HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0); } -// Topology is unknown, rely on OS affinity and user-specified slice. -BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, - BoundedSlice lp_slice) { - // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so - // we honor both the OS affinity and the user-specified slice. Note that - // this can be used to exclude hyperthreads because Linux groups LPs by - // sibling index. For example, the first `num_cores` are not siblings. - const size_t detected = enabled_lps.Count(); - size_t enabled_idx = 0; - enabled_lps.Foreach([&](size_t lp) { - if (lp_slice.Contains(detected, enabled_idx++)) { - AddLP(lp); - } - }); - - // lp_slice can only reduce the number of `enabled_lps`, and not below 1. - HWY_ASSERT(num_workers_ != 0); +// Topology is unknown, take the given set of LPs. 
+BoundedTopology::Cluster::Cluster(const LPS& lps) { + lps_ = lps; + num_workers_ = lps.Count(); } BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, @@ -183,7 +167,9 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, // Skip if not first-hyperthread or disabled. if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return; - AddLP(lp); + HWY_ASSERT(!lps_.Get(lp)); // Foreach ensures uniqueness + lps_.Set(lp); + ++num_workers_; // Set fields once, and ensure subsequent LPs match - we assume there // is only one NUMA node per cluster, with the same L2/L3 size. @@ -198,30 +184,63 @@ BoundedTopology::Cluster::Cluster(const LPS& enabled_lps, if (HWY_LIKELY(!warned)) { if (HWY_UNLIKELY(lp_node != node_)) { warned = true; - fprintf(stderr, "WARNING: lp %zu on node %zu != cluster node %zu.\n", - lp, lp_node, node_); + HWY_WARN("lp %zu on node %zu != cluster node %zu.", lp, lp_node, + node_); } if (HWY_UNLIKELY(private_kib_ != tcluster.private_kib)) { warned = true; - fprintf(stderr, "WARNING: lp %zu private_kib %zu != cluster %zu.\n", - lp, private_kib_, tcluster.private_kib); + HWY_WARN("lp %zu private_kib %zu != cluster %zu.", lp, private_kib_, + tcluster.private_kib); } if (HWY_UNLIKELY(shared_kib_ != tcluster.shared_kib)) { warned = true; - fprintf(stderr, "WARNING: lp %zu shared_kib %zu != cluster %zu.\n", - lp, shared_kib_, tcluster.shared_kib); + HWY_WARN("lp %zu shared_kib %zu != cluster %zu.", lp, shared_kib_, + tcluster.shared_kib); } } // !warned } }); } +// CPUs without clusters are rarely more than dozens of cores, and 6 is a +// decent number of threads in a per-cluster pool. +constexpr bool kSplitLargeClusters = false; +constexpr size_t kMaxClusters = 8; +constexpr size_t kMaxLPsPerCluster = 6; + +// Topology is unknown, rely on OS affinity and user-specified slice. +BoundedTopology::Package::Package(const LPS& enabled_lps, + BoundedSlice lp_slice) { + LPS clusters_lps[kMaxClusters]; + const size_t num_clusters = + kSplitLargeClusters + ? HWY_MIN(kMaxClusters, + hwy::DivCeil(enabled_lps.Count(), kMaxLPsPerCluster)) + : 1; + + // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so + // we honor both the OS affinity and the user-specified slice. Note that + // this can be used to exclude hyperthreads because Linux groups LPs by + // sibling index. For example, the first `num_cores` are not siblings. + const size_t detected = enabled_lps.Count(); + size_t enabled_idx = 0; + enabled_lps.Foreach([&](size_t lp) { + if (lp_slice.Contains(detected, enabled_idx)) { + clusters_lps[enabled_idx % num_clusters].Set(lp); + } + ++enabled_idx; + }); + + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + clusters.push_back(Cluster(clusters_lps[cluster_idx])); + } +} + // NOTE: caller is responsible for checking whether `clusters` is empty. BoundedTopology::Package::Package(const LPS& enabled_lps, - const hwy::Topology& topology, - size_t package_idx, + const hwy::Topology& topology, size_t pkg_idx, BoundedSlice cluster_slice) { - const hwy::Topology::Package& tpackage = topology.packages[package_idx]; + const hwy::Topology::Package& tpackage = topology.packages[pkg_idx]; // Populate `clusters` with the subset of clusters in `cluster_slice` that // have any enabled LPs. If `clusters` remains empty, the caller will // skip this `Package`. @@ -233,10 +252,34 @@ BoundedTopology::Package::Package(const LPS& enabled_lps, // Skip if empty, i.e. too few `enabled_lps`. 
if (HWY_LIKELY(cluster.Size() != 0)) { - clusters.push_back(std::move(cluster)); + clusters.push_back(cluster); } }); SortByDescendingSize(clusters); + + // If there is only one large cluster, split it into smaller ones. + if (kSplitLargeClusters && clusters.size() == 1 && + enabled_lps.Count() >= 16) { + const LPS lps = clusters[0].LPSet(); // copy so we can clear + clusters.clear(); + + // Split `lps` into several clusters. + LPS clusters_lps[kMaxClusters]; + const size_t num_clusters = + HWY_MIN(kMaxClusters, hwy::DivCeil(lps.Count(), kMaxLPsPerCluster)); + size_t num_lps = 0; + lps.Foreach( + [&](size_t lp) { clusters_lps[num_lps++ % num_clusters].Set(lp); }); + HWY_DASSERT(num_lps == lps.Count()); + + // Create new clusters, just inserting the new LPS. + hwy::Topology::Cluster tcluster = tpackage.clusters[0]; // modifiable copy + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + tcluster.lps = clusters_lps[cluster_idx]; + // Keep same `private_kib` and `shared_kib`. + clusters.push_back(Cluster(enabled_lps, topology.lps, tcluster)); + } + } } #if !GEMMA_DISABLE_TOPOLOGY @@ -256,10 +299,9 @@ static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters, max_tclusters = 0; max_tcluster_cores = 0; max_tcluster_lps = 0; - for (size_t package_idx = 0; package_idx < topology_.packages.size(); - ++package_idx) { + for (size_t pkg_idx = 0; pkg_idx < topology_.packages.size(); ++pkg_idx) { const std::vector& tclusters = - topology_.packages[package_idx].clusters; + topology_.packages[pkg_idx].clusters; max_tclusters = HWY_MAX(max_tclusters, tclusters.size()); size_t tcluster_cores = 0; size_t tcluster_lps = 0; @@ -272,10 +314,10 @@ static void ScanTClusters(hwy::Topology& topology_, size_t& max_tclusters, } if (tclusters.size() > 1 && tcluster_cores > 8) { - fprintf(stderr, - "Package %zu: multiple clusters with max size %zu, whereas CCX " - "only have 8, may indicate a bug in hwy::Topology.\n", - package_idx, tcluster_cores); + HWY_WARN( + "Package %zu: multiple clusters with max size %zu, whereas CCX " + "only have 8, may indicate a bug in hwy::Topology.", + pkg_idx, tcluster_cores); } max_tcluster_cores = HWY_MAX(max_tcluster_cores, tcluster_cores); max_tcluster_lps = HWY_MAX(max_tcluster_lps, tcluster_lps); @@ -294,8 +336,8 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps, // (Possibly empty) subset of `Topology` packages that have `enabled_lps`. package_slice.Foreach( - "package", topology_.packages.size(), [&](size_t package_idx) { - Package package(enabled_lps, topology_, package_idx, cluster_slice); + "package", topology_.packages.size(), [&](size_t pkg_idx) { + Package package(enabled_lps, topology_, pkg_idx, cluster_slice); // Skip if empty, i.e. too few `enabled_lps`. if (HWY_LIKELY(!package.clusters.empty())) { packages_.push_back(std::move(package)); @@ -313,18 +355,18 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps, // Scan for max BoundedTopology clusters and their size, for topology_string_. 
size_t all_max_cluster_size = 0; - for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) { + for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) { size_t max_cluster_size = 0; - for (size_t cluster_idx = 0; cluster_idx < NumClusters(package_idx); + for (size_t cluster_idx = 0; cluster_idx < NumClusters(pkg_idx); ++cluster_idx) { - max_cluster_size = HWY_MAX(max_cluster_size, - GetCluster(package_idx, cluster_idx).Size()); + max_cluster_size = + HWY_MAX(max_cluster_size, GetCluster(pkg_idx, cluster_idx).Size()); } - if (NumClusters(package_idx) > 1 && max_cluster_size > 8) { - fprintf(stderr, - "Package %zu: multiple clusters with max size %zu, whereas CCX " - "only have 8, may indicate a bug in BoundedTopology.\n", - package_idx, max_cluster_size); + if (NumClusters(pkg_idx) > 1 && max_cluster_size > 8) { + HWY_WARN( + "Package %zu: multiple clusters with max size %zu, whereas CCX " + "only have 8, may indicate a bug in BoundedTopology.", + pkg_idx, max_cluster_size); } all_max_cluster_size = HWY_MAX(all_max_cluster_size, max_cluster_size); } @@ -382,10 +424,10 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin, // calling thread of an all_clusters->Run, and hence pinned to one of the // `cluster.lps` if `pin`. all_packages_->Run( - 0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) { - HWY_ASSERT(package_idx == thread); // each thread has one task - packages_[package_idx] = - Package(topology_, package_idx, max_workers_per_package, lp_slice); + 0, all_packages_->NumWorkers(), [&](uint64_t pkg_idx, size_t thread) { + HWY_ASSERT(pkg_idx == thread); // each thread has one task + packages_[pkg_idx] = + Package(topology_, pkg_idx, max_workers_per_package, lp_slice); }); all_pinned_ = GetPinning().AllPinned(&pin_string_); @@ -405,12 +447,11 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin, HWY_ASSERT(max_workers_per_cluster_ <= 256); } -NestedPools::Package::Package(const BoundedTopology& topology, - size_t package_idx, +NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx, size_t max_workers_per_package, BoundedSlice lp_slice) { // Pre-allocate because elements are set concurrently. - clusters_.resize(topology.NumClusters(package_idx)); + clusters_.resize(topology.NumClusters(pkg_idx)); const size_t max_workers_per_cluster = DivideMaxAcross(max_workers_per_package, clusters_.size()); @@ -421,7 +462,7 @@ NestedPools::Package::Package(const BoundedTopology& topology, 0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) { HWY_ASSERT(cluster_idx == thread); // each thread has one task const BoundedTopology::Cluster& cluster = - topology.GetCluster(package_idx, cluster_idx); + topology.GetCluster(pkg_idx, cluster_idx); clusters_[cluster_idx] = MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster)); // Pin workers AND the calling thread from `all_clusters`. 
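The cluster-splitting path above is disabled by default (kSplitLargeClusters is false). For illustration only, a minimal standalone sketch (not part of the patch) of the round-robin LP-to-cluster assignment it performs, ignoring the `lp_slice` filter and using std::vector as a stand-in for the LPS bitset; both stand-ins are assumptions:

// Standalone sketch: mirrors clusters_lps[enabled_idx % num_clusters].Set(lp)
// with num_clusters = HWY_MIN(kMaxClusters, DivCeil(count, kMaxLPsPerCluster)).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<std::vector<size_t>> SplitRoundRobin(
    const std::vector<size_t>& enabled_lps) {
  constexpr size_t kMaxClusters = 8;       // as in threading.cc
  constexpr size_t kMaxLPsPerCluster = 6;  // as in threading.cc
  const size_t num_clusters =
      std::min(kMaxClusters, (enabled_lps.size() + kMaxLPsPerCluster - 1) /
                                 kMaxLPsPerCluster);
  std::vector<std::vector<size_t>> clusters(num_clusters);
  for (size_t i = 0; i < enabled_lps.size(); ++i) {
    clusters[i % num_clusters].push_back(enabled_lps[i]);  // round-robin
  }
  return clusters;
}

int main() {
  std::vector<size_t> lps(16);
  for (size_t i = 0; i < lps.size(); ++i) lps[i] = i;
  const auto clusters = SplitRoundRobin(lps);
  // 16 enabled LPs -> DivCeil(16, 6) = 3 clusters of sizes 6, 5, 5.
  printf("%zu clusters, sizes %zu %zu %zu\n", clusters.size(),
         clusters[0].size(), clusters[1].size(), clusters[2].size());
  return 0;
}

Round-robin assignment keeps cluster sizes within one LP of each other, so the per-cluster pools receive similar amounts of work.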
diff --git a/util/threading.h b/util/threading.h index 604882e..72d60ec 100644 --- a/util/threading.h +++ b/util/threading.h @@ -108,7 +108,7 @@ class BoundedTopology { class Cluster { public: - Cluster(const LPS& enabled_lps, BoundedSlice lp_slice); + Cluster(const LPS& lps); Cluster(const LPS& enabled_lps, const std::vector& all_lps, const hwy::Topology::Cluster& tcluster); @@ -124,17 +124,12 @@ class BoundedTopology { return lps; } + const LPS& LPSet() const { return lps_; } size_t Node() const { return node_; } size_t PrivateKiB() const { return private_kib_; } size_t SharedKiB() const { return shared_kib_; } private: - void AddLP(size_t lp) { - HWY_ASSERT(!lps_.Get(lp)); // Foreach ensures uniqueness - lps_.Set(lp); - ++num_workers_; - } - // Enabled LPs; if topology is known, only the ones in this cluster. LPS lps_; // How many workers in the per-cluster pool. If 0, this Cluster is removed. @@ -147,19 +142,19 @@ class BoundedTopology { size_t shared_kib_ = 0; }; // Cluster - size_t NumClusters(size_t package_idx) const { - HWY_ASSERT(package_idx < NumPackages()); - return packages_[package_idx].clusters.size(); + size_t NumClusters(size_t pkg_idx) const { + HWY_ASSERT(pkg_idx < NumPackages()); + return packages_[pkg_idx].clusters.size(); } - const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const { - HWY_ASSERT(package_idx < NumPackages()); - const Package& package = packages_[package_idx]; + const Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) const { + HWY_ASSERT(pkg_idx < NumPackages()); + const Package& package = packages_[pkg_idx]; HWY_ASSERT(cluster_idx < package.clusters.size()); return package.clusters[cluster_idx]; } - Cluster& GetCluster(size_t package_idx, size_t cluster_idx) { - HWY_ASSERT(package_idx < NumPackages()); - Package& package = packages_[package_idx]; + Cluster& GetCluster(size_t pkg_idx, size_t cluster_idx) { + HWY_ASSERT(pkg_idx < NumPackages()); + Package& package = packages_[pkg_idx]; HWY_ASSERT(cluster_idx < package.clusters.size()); return package.clusters[cluster_idx]; } @@ -170,13 +165,9 @@ class BoundedTopology { private: struct Package { - // Topology is unknown, rely on OS affinity and user-specified slice. - Package(const LPS& enabled_lps, BoundedSlice lp_slice) { - clusters.push_back(Cluster(enabled_lps, lp_slice)); - } - + Package(const LPS& enabled_lps, BoundedSlice lp_slice); Package(const LPS& enabled_lps, const hwy::Topology& topology, - size_t package_idx, BoundedSlice cluster_slice); + size_t pkg_idx, BoundedSlice cluster_slice); // For SortByDescendingSize. size_t Size() const { return clusters.size(); } @@ -257,33 +248,36 @@ class NestedPools { } } + size_t NumPackages() const { return packages_.size(); } hwy::ThreadPool& AllPackages() { return *all_packages_; } - hwy::ThreadPool& AllClusters(size_t package_idx) { - HWY_DASSERT(package_idx < packages_.size()); - return packages_[package_idx].AllClusters(); + hwy::ThreadPool& AllClusters(size_t pkg_idx) { + HWY_DASSERT(pkg_idx < NumPackages()); + return packages_[pkg_idx].AllClusters(); } - hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) { - HWY_DASSERT(package_idx < packages_.size()); - return packages_[package_idx].Cluster(cluster_idx); + hwy::ThreadPool& Cluster(size_t pkg_idx, size_t cluster_idx) { + HWY_DASSERT(pkg_idx < NumPackages()); + return packages_[pkg_idx].Cluster(cluster_idx); } // For binding to NUMA nodes. 
- size_t Node(size_t package_idx, size_t cluster_idx) const { - return topology_.GetCluster(package_idx, cluster_idx).Node(); + size_t Node(size_t pkg_idx, size_t cluster_idx) const { + return topology_.GetCluster(pkg_idx, cluster_idx).Node(); } - // Reasonably tight upper bound for allocating thread-local storage (TLS). - size_t MaxWorkers() const { - return packages_.size() * max_clusters_per_package_ * - max_workers_per_cluster_; + // Reasonably tight upper bounds for allocating thread-local storage (TLS). + size_t MaxWorkersPerCluster() const { return max_workers_per_cluster_; } + size_t MaxWorkersPerPackage() const { + return max_clusters_per_package_ * MaxWorkersPerCluster(); } - // Returns the first of `cluster.NumWorkers()` TLS indices, to which callers - // add the worker index given by `cluster.Run`. - size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const { - HWY_DASSERT(package_idx < packages_.size()); - HWY_DASSERT(cluster_idx < packages_[package_idx].NumClusters()); - return (package_idx * max_clusters_per_package_ + cluster_idx) * - max_workers_per_cluster_; + size_t MaxWorkers() const { return NumPackages() * MaxWorkersPerPackage(); } + + // Actual number of workers. + size_t TotalWorkers() const { + size_t total_workers = 0; + for (size_t pkg_idx = 0; pkg_idx < NumPackages(); ++pkg_idx) { + total_workers += packages_[pkg_idx].TotalWorkers(); + } + return total_workers; } // For Allocator @@ -296,20 +290,20 @@ class NestedPools { // if there is more than one, which maximizes available memory bandwidth, or // the first cluster, which is typically the whole package. For use by callers // that only have a single parallel-for. - hwy::ThreadPool& Pool(size_t package_idx = 0) { + hwy::ThreadPool& Pool(size_t pkg_idx = 0) { // Only one cluster: use its pool, typically a whole socket. - if (AllClusters(package_idx).NumWorkers() == 1) { - return Cluster(package_idx, 0); + if (AllClusters(pkg_idx).NumWorkers() == 1) { + return Cluster(pkg_idx, 0); } // One worker per cluster to maximize bandwidth availability. - return AllClusters(package_idx); + return AllClusters(pkg_idx); } private: class Package { public: Package() = default; // for vector - Package(const BoundedTopology& topology, size_t package_idx, + Package(const BoundedTopology& topology, size_t pkg_idx, size_t max_workers_per_package, BoundedSlice lp_slice); size_t NumClusters() const { return clusters_.size(); } @@ -321,6 +315,13 @@ class NestedPools { } return max_workers_per_cluster; } + size_t TotalWorkers() const { + size_t total_workers = 0; + for (const PoolPtr& cluster : clusters_) { + total_workers += cluster->NumWorkers(); + } + return total_workers; + } hwy::ThreadPool& AllClusters() { return *all_clusters_; } hwy::ThreadPool& Cluster(size_t cluster_idx) { @@ -365,32 +366,33 @@ class NestedPools { // functions below. class IndexRangePartition { public: + IndexRangePartition() = default; // for MMPartitions IndexRangePartition(const IndexRange& range, const size_t task_size) - : range_(range), task_size_(task_size) { - const size_t num = range.Num(); + : range_(range), task_size_(static_cast(task_size)) { + const uint32_t num = static_cast(range.Num()); HWY_DASSERT(task_size_ != 0); num_tasks_ = hwy::DivCeil(num, task_size_); HWY_DASSERT(num_tasks_ != 0); if constexpr (HWY_IS_DEBUG_BUILD) { - const size_t handled = num_tasks_ * task_size_; + const uint32_t handled = num_tasks_ * task_size_; // The last task may extend beyond items, but at most by (task_size_ - 1). 
HWY_DASSERT(num <= handled && handled < num + task_size_); } } - size_t TaskSize() const { return task_size_; } - size_t NumTasks() const { return num_tasks_; } + size_t TaskSize() const { return static_cast(task_size_); } + size_t NumTasks() const { return static_cast(num_tasks_); } IndexRange Range(size_t task_idx) const { HWY_DASSERT(task_idx < NumTasks()); - return MakeIndexRange(range_.begin() + task_idx * task_size_, range_.end(), - task_size_); + return MakeIndexRange(range_.begin() + task_idx * TaskSize(), range_.end(), + TaskSize()); } private: IndexRange range_; - size_t task_size_; - size_t num_tasks_; + uint32_t task_size_; + uint32_t num_tasks_; }; // Starts with `max_size` and rounds DOWN to a multiple of `size_multiple` @@ -455,33 +457,6 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1, }); } -// As above, for three ranges. -template -void ParallelizeThreeRanges(const IndexRangePartition& get1, - const IndexRangePartition& get2, - const IndexRangePartition& get3, - hwy::ThreadPool& pool, const Func& func) { - const hwy::Divisor div1(static_cast(get1.NumTasks())); - const size_t num12 = get1.NumTasks() * get2.NumTasks(); - const hwy::Divisor div12(static_cast(num12)); - - const size_t num_tasks = num12 * get3.NumTasks(); - pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) { - HWY_DASSERT(task < (uint64_t{1} << 32)); - const size_t idx3 = div12.Divide(static_cast(task)); - const size_t task12 = div12.Remainder(static_cast(task)); - const size_t idx2 = div1.Divide(static_cast(task12)); - const size_t idx1 = div1.Remainder(static_cast(task12)); - HWY_DASSERT(idx1 < get1.NumTasks()); - HWY_DASSERT(idx2 < get2.NumTasks()); - HWY_DASSERT(idx3 < get3.NumTasks()); - const IndexRange range1 = get1.Range(idx1); - const IndexRange range2 = get2.Range(idx2); - const IndexRange range3 = get3.Range(idx3); - func(range1, range2, range3, thread); - }); -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ diff --git a/util/threading_test.cc b/util/threading_test.cc index 2190e7e..8cdf02e 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -138,6 +138,13 @@ TEST(ThreadingTest, TestMaxSizePartition) { HWY_ASSERT(partition.TaskSize() == 55); HWY_ASSERT(partition.NumTasks() == 2); } + // `size_multiple` almost as large as range: imbalanced + { + const IndexRangePartition partition = + MaxSizePartition(IndexRange(0, 6), 6, 4); + HWY_ASSERT(partition.TaskSize() == 4); + HWY_ASSERT(partition.NumTasks() == 2); + } // Small `max_size`: small tasks { const IndexRangePartition partition = MaxSizePartition(range, 2, 1); @@ -244,97 +251,5 @@ TEST(ThreadingTest, TestParallelizeTwoRanges) { } } -TEST(ThreadingTest, TestParallelizeThreeRanges) { - // Named according to number of tasks. 
- const IndexRangePartition partition3 = - StaticPartition(IndexRange(0, 8), 3, 1); // [0, 3) [3, 6) [6, 8) - HWY_ASSERT(partition3.NumTasks() == 3); - const IndexRangePartition partition2 = - MaxSizePartition(IndexRange(10, 30), 10, 10); // [10, 20), [20, 30) - HWY_ASSERT(partition2.NumTasks() == 2); - const IndexRangePartition partition4 = - MaxSizePartition(IndexRange(100, 500), 100, 100); // 100, 200, 300, 400 - HWY_ASSERT(partition4.NumTasks() == 4); - - const auto check_ranges = [&](const IndexRange& range3, - const IndexRange& range2, - const IndexRange& range4) { - HWY_ASSERT(range3.begin() == 0 || range3.begin() == 3 || - range3.begin() == 6); - HWY_ASSERT(range2.begin() == 10 || range2.begin() == 20); - HWY_ASSERT(range4.begin() % 100 == 0); - }; - - hwy::ThreadPool null_pool(0); - // All 6 permutations of the three ranges to test the Remainder() logic: - // 3, 2, 4 - { - size_t calls = 0; - ParallelizeThreeRanges( - partition3, partition2, partition4, null_pool, - [&](IndexRange range3, IndexRange range2, IndexRange range4, size_t) { - ++calls; - check_ranges(range3, range2, range4); - }); - HWY_ASSERT(calls == 3 * 2 * 4); - } - // 3, 4, 2 - { - size_t calls = 0; - ParallelizeThreeRanges( - partition3, partition4, partition2, null_pool, - [&](IndexRange range3, IndexRange range4, IndexRange range2, size_t) { - ++calls; - check_ranges(range3, range2, range4); - }); - HWY_ASSERT(calls == 3 * 2 * 4); - } - - // 4, 2, 3 - { - size_t calls = 0; - ParallelizeThreeRanges( - partition4, partition2, partition3, null_pool, - [&](IndexRange range4, IndexRange range2, IndexRange range3, size_t) { - ++calls; - check_ranges(range3, range2, range4); - }); - HWY_ASSERT(calls == 3 * 2 * 4); - } - // 4, 3, 2 - { - size_t calls = 0; - ParallelizeThreeRanges( - partition4, partition3, partition2, null_pool, - [&](IndexRange range4, IndexRange range3, IndexRange range2, size_t) { - ++calls; - check_ranges(range3, range2, range4); - }); - HWY_ASSERT(calls == 3 * 2 * 4); - } - - // 2, 3, 4 - { - size_t calls = 0; - ParallelizeThreeRanges( - partition2, partition3, partition4, null_pool, - [&](IndexRange range2, IndexRange range3, IndexRange range4, size_t) { - ++calls; - check_ranges(range3, range2, range4); - }); - HWY_ASSERT(calls == 3 * 2 * 4); - } - // 2, 4, 3 - { - size_t calls = 0; - ParallelizeThreeRanges( - partition2, partition4, partition3, null_pool, - [&](IndexRange range2, IndexRange range4, IndexRange range3, size_t) { - ++calls; - check_ranges(range3, range2, range4); - }); - HWY_ASSERT(calls == 3 * 2 * 4); - } -} } // namespace } // namespace gcpp
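For reference, a usage sketch (not part of the patch) of the now 32-bit IndexRangePartition, matching the new imbalanced-split case added to TestMaxSizePartition above; the include path is an assumption:

#include <cstddef>
#include <cstdio>

#include "util/threading.h"  // IndexRange, IndexRangePartition, MaxSizePartition

int main() {
  // `size_multiple` (4) is nearly as large as the 6-element range, so
  // MaxSizePartition produces an imbalanced split: TaskSize() == 4 and
  // NumTasks() == 2, as asserted in the test.
  const gcpp::IndexRangePartition partition = gcpp::MaxSizePartition(
      gcpp::IndexRange(0, 6), /*max_size=*/6, /*size_multiple=*/4);
  for (size_t task_idx = 0; task_idx < partition.NumTasks(); ++task_idx) {
    const gcpp::IndexRange range = partition.Range(task_idx);
    printf("task %zu: [%zu, %zu)\n", task_idx, range.begin(), range.end());
  }
  return 0;
}

The expected ranges are [0, 4) and the clamped remainder [4, 6), consistent with the comment that the last task may extend past the range by at most task_size - 1.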