diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index b5e39db..d99a067 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -411,12 +411,14 @@ TEST(BackPropTest, LayerVJP) { using T = double; using TC = std::complex; ModelConfig config = TestConfig(); + TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1, + /*reshape_att=*/false); const size_t kOutputSize = config.seq_len * config.model_dim; - LayerWeightsPtrs weights(config.layer_configs[0]); - LayerWeightsPtrs grad(config.layer_configs[0]); + LayerWeightsPtrs weights(config.layer_configs[0], tensor_index); + LayerWeightsPtrs grad(config.layer_configs[0], tensor_index); ForwardLayer forward(config.layer_configs[0], config.seq_len); ForwardLayer backward(config.layer_configs[0], config.seq_len); - LayerWeightsPtrs c_weights(config.layer_configs[0]); + LayerWeightsPtrs c_weights(config.layer_configs[0], tensor_index); ForwardLayer c_forward(config.layer_configs[0], config.seq_len); MatStorageT y("y", kOutputSize, 1); MatStorageT dy("dy", kOutputSize, 1); diff --git a/compression/compress.h b/compression/compress.h index b717ac8..8d4635b 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -219,21 +219,27 @@ class MatPtrT : public MatPtr { : MatPtrT(name, tensor_index.FindName(name)) {} MatPtrT(const std::string& name, const TensorInfo* tensor) : MatPtr(name, TypeEnum(), sizeof(MatT), 0, 0) { - HWY_ASSERT(tensor != nullptr); - cols_ = tensor->shape.back(); - rows_ = 1; - if (tensor->cols_take_extra_dims) { - // The columns eat the extra dimensions. - rows_ = tensor->shape[0]; - for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { - cols_ *= tensor->shape[i]; - } + if (tensor == nullptr) { + cols_ = 0; + rows_ = 0; } else { - // The rows eat the extra dimensions. - for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { - rows_ *= tensor->shape[i]; + cols_ = tensor->shape.back(); + rows_ = 1; + if (tensor->cols_take_extra_dims) { + // The columns eat the extra dimensions. + rows_ = tensor->shape[0]; + for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { + cols_ *= tensor->shape[i]; + } + } else { + // The rows eat the extra dimensions. + for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { + rows_ *= tensor->shape[i]; + } } } + stride_ = cols_; + num_elements_ = rows_ * cols_; } // Copying allowed as the metadata is small. diff --git a/gemma/weights.h b/gemma/weights.h index ecd917b..8410b81 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -30,6 +30,7 @@ #include "compression/shared.h" #include "gemma/common.h" #include "gemma/configs.h" +#include "gemma/tensor_index.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -56,73 +57,48 @@ enum class ForEachType { template struct LayerWeightsPtrs { // Large data is constructed separately. - explicit LayerWeightsPtrs(const LayerConfig& config) - : attn_vec_einsum_w("att_ein", config.heads * config.model_dim, - config.qkv_dim), - qkv_einsum_w("qkv_ein", - (config.heads + 2 * config.kv_heads) * config.qkv_dim, - config.model_dim), - qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim, - config.model_dim), - qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim, - config.model_dim), - attention_output_biases( - "attn_ob", 1, - config.softmax_attn_output_biases ? config.model_dim : 0), - griffin( - {.linear_x_w = {"gr_lin_x_w", config.griffin_dim, - config.griffin_dim}, - .linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim}, - .linear_y_w = {"gr_lin_y_w", config.griffin_dim, - config.griffin_dim}, - .linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim}, - .linear_out_w = {"gr_lin_out_w", config.griffin_dim, - config.griffin_dim}, - .linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim}, - .conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim}, - .conv_biases = {"gr_conv_b", 1, config.griffin_dim}, - .gate_w = {"gr_gate_w", 2 * config.griffin_dim, - config.griffin_dim / config.heads}, - .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2}, - .a = {"gr_a", 1, config.griffin_dim}}), + explicit LayerWeightsPtrs(const LayerConfig& config, + const TensorIndex& tensor_index) + : attn_vec_einsum_w("att_ein", tensor_index), + qkv_einsum_w("qkv_ein", tensor_index), + qkv_einsum_w1("qkv1_w", tensor_index), + qkv_einsum_w2("qkv2_w", tensor_index), + attention_output_biases("attn_ob", tensor_index), + griffin({.linear_x_w = {"gr_lin_x_w", tensor_index}, + .linear_x_biases = {"gr_lin_x_b", tensor_index}, + .linear_y_w = {"gr_lin_y_w", tensor_index}, + .linear_y_biases = {"gr_lin_y_b", tensor_index}, + .linear_out_w = {"gr_lin_out_w", tensor_index}, + .linear_out_biases = {"gr_lin_out_b", tensor_index}, + .conv_w = {"gr_conv_w", tensor_index}, + .conv_biases = {"gr_conv_b", tensor_index}, + .gate_w = {"gr_gate_w", tensor_index}, + .gate_biases = {"gr_gate_b", tensor_index}, + .a = {"gr_a", tensor_index}}), // MultiHeadDotProductAttention. - vit({.attn_out_w = {"attn_out_w", config.model_dim, - config.heads * config.qkv_dim}, - .attn_out_b = {"attn_out_b", 1, config.model_dim}, - .qkv_einsum_w = {"qkv_ein_w", - (config.heads + 2 * config.kv_heads) * - config.qkv_dim, - config.model_dim}, - .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads), - config.qkv_dim}, - .linear_0_w = {"linear_0_w", config.ff_hidden_dim, - config.model_dim}, - .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim}, - .linear_1_w = {"linear_1_w", config.model_dim, - config.ff_hidden_dim}, - .linear_1_b = {"linear_1_b", 1, config.model_dim}, - .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim}, - .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim}, - .layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim}, - .layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}), - gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim, - config.model_dim), - gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim), - gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim), - linear_w("linear_w", config.model_dim, config.ff_hidden_dim), - pre_attention_norm_scale("pre_att_ns", 1, config.model_dim), - pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim), - post_attention_norm_scale( - "post_att_ns", 1, - config.post_norm == PostNormType::Scale ? config.model_dim : 0), - post_ffw_norm_scale( - "post_ff_ns", 1, - config.post_norm == PostNormType::Scale ? config.model_dim : 0), - ffw_gating_biases("ffw_gat_b", 1, - config.ff_biases ? 2 * config.ff_hidden_dim : 0), - ffw_output_biases("ffw_out_b", 1, - config.ff_biases ? config.model_dim : 0), - att_weights("att_w", config.model_dim, config.heads * config.qkv_dim), + vit({.attn_out_w = {"attn_out_w", tensor_index}, + .attn_out_b = {"attn_out_b", tensor_index}, + .qkv_einsum_w = {"qkv_ein_w", tensor_index}, + .qkv_einsum_b = {"qkv_ein_b", tensor_index}, + .linear_0_w = {"linear_0_w", tensor_index}, + .linear_0_b = {"linear_0_b", tensor_index}, + .linear_1_w = {"linear_1_w", tensor_index}, + .linear_1_b = {"linear_1_b", tensor_index}, + .layer_norm_0_bias = {"ln_0_bias", tensor_index}, + .layer_norm_0_scale = {"ln_0_scale", tensor_index}, + .layer_norm_1_bias = {"ln_1_bias", tensor_index}, + .layer_norm_1_scale = {"ln_1_scale", tensor_index}}), + gating_einsum_w("gating_ein", tensor_index), + gating_einsum_w1("gating1_w", tensor_index), + gating_einsum_w2("gating2_w", tensor_index), + linear_w("linear_w", tensor_index), + pre_attention_norm_scale("pre_att_ns", tensor_index), + pre_ffw_norm_scale("pre_ff_ns", tensor_index), + post_attention_norm_scale("post_att_ns", tensor_index), + post_ffw_norm_scale("post_ff_ns", tensor_index), + ffw_gating_biases("ffw_gat_b", tensor_index), + ffw_output_biases("ffw_out_b", tensor_index), + att_weights("att_w", tensor_index), layer_config(config) {} ~LayerWeightsPtrs() = default; @@ -343,27 +319,37 @@ struct LayerWeightsPtrs { template struct ModelWeightsPtrs { ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool) - : embedder_input_embedding("c_embedding", config.vocab_size, - config.model_dim), - final_norm_scale("c_final_norm", 1, config.model_dim), - vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim), - vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim), - vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim), - vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim, - config.patch_width * config.patch_width * 3), - vit_img_pos_embedding("img_pos_emb", config.vit_seq_len, - config.vit_model_dim), - vit_img_head_bias("img_head_bias", 1, config.model_dim), - vit_img_head_kernel("img_head_kernel", config.model_dim, - config.vit_model_dim), + : ModelWeightsPtrs( + config, + TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1, + /*reshape_att=*/false), + pool) {} + ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index, + hwy::ThreadPool& pool) + : embedder_input_embedding("c_embedding", tensor_index), + final_norm_scale("c_final_norm", tensor_index), + vit_encoder_norm_bias("enc_norm_bias", tensor_index), + vit_encoder_norm_scale("enc_norm_scale", tensor_index), + vit_img_embedding_bias("img_emb_bias", tensor_index), + vit_img_embedding_kernel("img_emb_kernel", tensor_index), + vit_img_pos_embedding("img_pos_emb", tensor_index), + vit_img_head_bias("img_head_bias", tensor_index), + vit_img_head_kernel("img_head_kernel", tensor_index), scale_names(config.scale_names), weights_config(config) { c_layers.reserve(config.layer_configs.size()); - for (const auto& layer_config : config.layer_configs) { - c_layers.push_back(LayerWeightsPtrs(layer_config)); + for (int index = 0; index < config.layer_configs.size(); ++index) { + const auto& layer_config = config.layer_configs[index]; + TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1, + /*reshape_att=*/false); + c_layers.push_back(LayerWeightsPtrs(layer_config, tensor_index)); } - for (const auto& layer_config : config.vit_layer_configs) { - vit_layers.push_back(LayerWeightsPtrs(layer_config)); + for (int index = 0; index < config.vit_layer_configs.size(); ++index) { + const auto& layer_config = config.vit_layer_configs[index]; + TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index, + /*reshape_att=*/false); + vit_layers.push_back( + LayerWeightsPtrs(layer_config, tensor_index)); } }