Commit db0a749
Removed duplicated tensor sizes from weights.h by changing the constructor used for MatPtrT

PiperOrigin-RevId: 705045279
theraysmith authored and copybara-github committed Dec 11, 2024
1 parent aed1739 commit db0a749
Showing 3 changed files with 92 additions and 98 deletions.
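In short, MatPtrT members no longer receive explicit row/column counts; they look their shape up by name in a TensorIndex. As a before/after summary drawn from the weights.h diff below (this is only a restatement of the diff, not additional code in the commit):

    // Before: tensor dimensions spelled out at every member initializer.
    att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
    // After: dimensions resolved by tensor name via the TensorIndex.
    att_weights("att_w", tensor_index),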
8 changes: 5 additions & 3 deletions backprop/backward_scalar_test.cc
@@ -411,12 +411,14 @@ TEST(BackPropTest, LayerVJP) {
using T = double;
using TC = std::complex<T>;
ModelConfig config = TestConfig();
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
/*reshape_att=*/false);
const size_t kOutputSize = config.seq_len * config.model_dim;
LayerWeightsPtrs<T> weights(config.layer_configs[0]);
LayerWeightsPtrs<T> grad(config.layer_configs[0]);
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
MatStorageT<T> y("y", kOutputSize, 1);
MatStorageT<T> dy("dy", kOutputSize, 1);
30 changes: 18 additions & 12 deletions compression/compress.h
@@ -219,21 +219,27 @@ class MatPtrT : public MatPtr {
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
MatPtrT(const std::string& name, const TensorInfo* tensor)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
HWY_ASSERT(tensor != nullptr);
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
if (tensor == nullptr) {
cols_ = 0;
rows_ = 0;
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
}
}
}
stride_ = cols_;
num_elements_ = rows_ * cols_;
}

// Copying allowed as the metadata is small.
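As an illustration of the shape flattening above, the following is a standalone sketch, not the gemma.cpp API: TensorInfoSketch is a hypothetical stand-in exposing only the two fields the constructor reads (shape and cols_take_extra_dims), and FlattenShape mirrors the non-null branch of the new constructor.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for TensorInfo, keeping only the fields the
// MatPtrT constructor above actually reads.
struct TensorInfoSketch {
  std::vector<size_t> shape;
  bool cols_take_extra_dims;
};

// Mirrors the non-null branch: the trailing dimension always lands in cols;
// the remaining ("extra") dimensions are folded into rows or cols depending
// on cols_take_extra_dims.
void FlattenShape(const TensorInfoSketch& t, size_t& rows, size_t& cols) {
  cols = t.shape.back();
  rows = 1;
  if (t.cols_take_extra_dims) {
    // The columns eat the extra dimensions.
    rows = t.shape[0];
    for (size_t i = 1; i + 1 < t.shape.size(); ++i) cols *= t.shape[i];
  } else {
    // The rows eat the extra dimensions.
    for (size_t i = 0; i + 1 < t.shape.size(); ++i) rows *= t.shape[i];
  }
}

int main() {
  size_t rows, cols;
  // Example 3D shape {8, 16, 256}: rows eat the extra dims -> 128 x 256.
  FlattenShape({{8, 16, 256}, /*cols_take_extra_dims=*/false}, rows, cols);
  std::printf("rows=%zu cols=%zu\n", rows, cols);  // rows=128 cols=256
  // Same shape, cols eat the extra dims -> 8 x 4096.
  FlattenShape({{8, 16, 256}, /*cols_take_extra_dims=*/true}, rows, cols);
  std::printf("rows=%zu cols=%zu\n", rows, cols);  // rows=8 cols=4096
}

Either way the trailing dimension stays in cols; only the ownership of the leading dimensions changes, which is what the comments "the columns/rows eat the extra dimensions" describe.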
152 changes: 69 additions & 83 deletions gemma/weights.h
@@ -30,6 +30,7 @@
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -56,73 +57,48 @@ enum class ForEachType {
template <class Weight>
struct LayerWeightsPtrs {
// Large data is constructed separately.
explicit LayerWeightsPtrs(const LayerConfig& config)
: attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
config.qkv_dim),
qkv_einsum_w("qkv_ein",
(config.heads + 2 * config.kv_heads) * config.qkv_dim,
config.model_dim),
qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
config.model_dim),
qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
config.model_dim),
attention_output_biases(
"attn_ob", 1,
config.softmax_attn_output_biases ? config.model_dim : 0),
griffin(
{.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
config.griffin_dim},
.linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
.linear_y_w = {"gr_lin_y_w", config.griffin_dim,
config.griffin_dim},
.linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
.linear_out_w = {"gr_lin_out_w", config.griffin_dim,
config.griffin_dim},
.linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
.conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
.conv_biases = {"gr_conv_b", 1, config.griffin_dim},
.gate_w = {"gr_gate_w", 2 * config.griffin_dim,
config.griffin_dim / config.heads},
.gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
.a = {"gr_a", 1, config.griffin_dim}}),
explicit LayerWeightsPtrs(const LayerConfig& config,
const TensorIndex& tensor_index)
: attn_vec_einsum_w("att_ein", tensor_index),
qkv_einsum_w("qkv_ein", tensor_index),
qkv_einsum_w1("qkv1_w", tensor_index),
qkv_einsum_w2("qkv2_w", tensor_index),
attention_output_biases("attn_ob", tensor_index),
griffin({.linear_x_w = {"gr_lin_x_w", tensor_index},
.linear_x_biases = {"gr_lin_x_b", tensor_index},
.linear_y_w = {"gr_lin_y_w", tensor_index},
.linear_y_biases = {"gr_lin_y_b", tensor_index},
.linear_out_w = {"gr_lin_out_w", tensor_index},
.linear_out_biases = {"gr_lin_out_b", tensor_index},
.conv_w = {"gr_conv_w", tensor_index},
.conv_biases = {"gr_conv_b", tensor_index},
.gate_w = {"gr_gate_w", tensor_index},
.gate_biases = {"gr_gate_b", tensor_index},
.a = {"gr_a", tensor_index}}),
// MultiHeadDotProductAttention.
vit({.attn_out_w = {"attn_out_w", config.model_dim,
config.heads * config.qkv_dim},
.attn_out_b = {"attn_out_b", 1, config.model_dim},
.qkv_einsum_w = {"qkv_ein_w",
(config.heads + 2 * config.kv_heads) *
config.qkv_dim,
config.model_dim},
.qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
config.qkv_dim},
.linear_0_w = {"linear_0_w", config.ff_hidden_dim,
config.model_dim},
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
.linear_1_w = {"linear_1_w", config.model_dim,
config.ff_hidden_dim},
.linear_1_b = {"linear_1_b", 1, config.model_dim},
.layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
.layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
.layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
.layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
config.model_dim),
gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
post_attention_norm_scale(
"post_att_ns", 1,
config.post_norm == PostNormType::Scale ? config.model_dim : 0),
post_ffw_norm_scale(
"post_ff_ns", 1,
config.post_norm == PostNormType::Scale ? config.model_dim : 0),
ffw_gating_biases("ffw_gat_b", 1,
config.ff_biases ? 2 * config.ff_hidden_dim : 0),
ffw_output_biases("ffw_out_b", 1,
config.ff_biases ? config.model_dim : 0),
att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
vit({.attn_out_w = {"attn_out_w", tensor_index},
.attn_out_b = {"attn_out_b", tensor_index},
.qkv_einsum_w = {"qkv_ein_w", tensor_index},
.qkv_einsum_b = {"qkv_ein_b", tensor_index},
.linear_0_w = {"linear_0_w", tensor_index},
.linear_0_b = {"linear_0_b", tensor_index},
.linear_1_w = {"linear_1_w", tensor_index},
.linear_1_b = {"linear_1_b", tensor_index},
.layer_norm_0_bias = {"ln_0_bias", tensor_index},
.layer_norm_0_scale = {"ln_0_scale", tensor_index},
.layer_norm_1_bias = {"ln_1_bias", tensor_index},
.layer_norm_1_scale = {"ln_1_scale", tensor_index}}),
gating_einsum_w("gating_ein", tensor_index),
gating_einsum_w1("gating1_w", tensor_index),
gating_einsum_w2("gating2_w", tensor_index),
linear_w("linear_w", tensor_index),
pre_attention_norm_scale("pre_att_ns", tensor_index),
pre_ffw_norm_scale("pre_ff_ns", tensor_index),
post_attention_norm_scale("post_att_ns", tensor_index),
post_ffw_norm_scale("post_ff_ns", tensor_index),
ffw_gating_biases("ffw_gat_b", tensor_index),
ffw_output_biases("ffw_out_b", tensor_index),
att_weights("att_w", tensor_index),
layer_config(config) {}
~LayerWeightsPtrs() = default;

@@ -343,27 +319,37 @@ struct LayerWeightsPtrs {
template <class Weight>
struct ModelWeightsPtrs {
ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
: embedder_input_embedding("c_embedding", config.vocab_size,
config.model_dim),
final_norm_scale("c_final_norm", 1, config.model_dim),
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
config.patch_width * config.patch_width * 3),
vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.model_dim,
config.vit_model_dim),
: ModelWeightsPtrs(
config,
TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
/*reshape_att=*/false),
pool) {}
ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index,
hwy::ThreadPool& pool)
: embedder_input_embedding("c_embedding", tensor_index),
final_norm_scale("c_final_norm", tensor_index),
vit_encoder_norm_bias("enc_norm_bias", tensor_index),
vit_encoder_norm_scale("enc_norm_scale", tensor_index),
vit_img_embedding_bias("img_emb_bias", tensor_index),
vit_img_embedding_kernel("img_emb_kernel", tensor_index),
vit_img_pos_embedding("img_pos_emb", tensor_index),
vit_img_head_bias("img_head_bias", tensor_index),
vit_img_head_kernel("img_head_kernel", tensor_index),
scale_names(config.scale_names),
weights_config(config) {
c_layers.reserve(config.layer_configs.size());
for (const auto& layer_config : config.layer_configs) {
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
for (int index = 0; index < config.layer_configs.size(); ++index) {
const auto& layer_config = config.layer_configs[index];
TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
/*reshape_att=*/false);
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
}
for (const auto& layer_config : config.vit_layer_configs) {
vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
for (int index = 0; index < config.vit_layer_configs.size(); ++index) {
const auto& layer_config = config.vit_layer_configs[index];
TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
/*reshape_att=*/false);
vit_layers.push_back(
LayerWeightsPtrs<Weight>(layer_config, tensor_index));
}
}

