diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 114d5a3..6fadb0f 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -157,13 +157,14 @@ TEST_F(GemmaTest, Multiturn) {
   Gemma* model = s_env->GetModel();
   ASSERT_NE(model, nullptr);
   size_t abs_pos = 0;
-  std::string dialog;
+  std::string response;
   auto stream_token = [&](int token, float) {
+    if (token == EOS_ID) return true;
     ++abs_pos;
     std::string token_text;
     EXPECT_TRUE(
         model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    dialog += token_text;
+    response += token_text;
     return true;
   };
   RuntimeConfig runtime_config{
@@ -180,18 +181,21 @@
                                             abs_pos, mutable_prompt);
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
+  // Note: we do not rewind any <end_of_turn> tokens here. If the model
+  // produced one and WrapAndTokenize() inserts another one, it will just be
+  // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
   tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
                            mutable_prompt);
-  // Reset the `dialog` string here, then check that the model actually has
+  // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
-  dialog.clear();
+  response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
-  fprintf(stderr, "decoded: %s\n", dialog.c_str());
+  fprintf(stderr, "decoded: %s\n", response.c_str());
   bool remembered_turquoise =
-      dialog.find("turquoise") != std::string::npos;  // NOLINT
-  bool remembered_car = dialog.find("car") != std::string::npos;  // NOLINT
+      response.find("turquoise") != std::string::npos;  // NOLINT
+  bool remembered_car = response.find("car") != std::string::npos;  // NOLINT
   EXPECT_TRUE(remembered_turquoise || remembered_car);
 }
 
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 51f2999..81a9469 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1249,16 +1249,9 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
 
   // Copy so we can increment without requiring users to pass in a mutable span.
   std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
                                        queries_pos_in.cend());
-  QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
-                                        queries_pos_copy.size());
-  // For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
-  // Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
-  // previous `Generate` called `StreamToken` for the last token (EOS), hence
-  // our caller's qpos is 1 too high. This must be corrected because we didn't
-  // write to the KV cache at that position, so MSAN would complain.
-  for (size_t& qpos : queries_mutable_pos) {
-    qpos = qpos == 0 ? 0 : qpos - 1;
-  }
+  const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
+                                              queries_pos_copy.size());
+  // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
diff --git a/gemma/run.cc b/gemma/run.cc
index d21c5d6..659be81 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -85,6 +85,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
+  bool end_of_turn_seen = false;
 
   std::mt19937 gen;
   InitGenerator(args, gen);
@@ -114,37 +115,44 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   // callback function invoked for each generated token.
   auto stream_token = [&](int token, float) {
     ++abs_pos;
-    ++tokens_generated_this_turn;
-    // <= since position is incremented before
-    if (tokens_generated_this_turn <= prompt_size) {
-      std::cerr << "." << std::flush;
-    } else if (token == EOS_ID) {
-      if (!args.multiturn) {
-        abs_pos = 0;
-        InitGenerator(args, gen);
-      }
+    if (token == EOS_ID) {
       if (app.verbosity >= 2) {
         std::cout << "\n[ End ]\n";
       }
-    } else {
-      std::string token_text;
-      HWY_ASSERT(
-          model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
-      // +1 since position is incremented above
-      if (tokens_generated_this_turn == prompt_size + 1) {
-        // first token of response
-        token_text.erase(0, token_text.find_first_not_of(" \t\n"));
-        if (app.verbosity >= 1) {
-          std::cout << "\n\n";
-        }
+      return true;
+    }
+    const bool in_prompt = tokens_generated_this_turn < prompt_size;
+    const bool first_response_token = tokens_generated_this_turn == prompt_size;
+    ++tokens_generated_this_turn;
+    if (in_prompt) {
+      if (app.verbosity >= 1) {
+        std::cerr << "." << std::flush;
+      }
+      return true;
+    }
+    std::string token_text;
+    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    if (first_response_token) {
+      token_text.erase(0, token_text.find_first_not_of(" \t\n"));
+      if (app.verbosity >= 1) {
+        std::cout << "\n\n";
      }
-      std::cout << token_text << std::flush;
    }
+    if (token_text == "<end_of_turn>") {
+      // We don't want to show the <end_of_turn> token to the user.
+      // We also need to remember that we've seen it, so that we can rewind
+      // abs_pos appropriately. We expect EOS as the next token.
+      end_of_turn_seen = true;
+      return true;
+    }
+    std::cout << token_text << std::flush;
     return true;
   };
 
   while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
+
+    // Read prompt and handle special commands.
     std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
     if (!std::cin) return;
     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
@@ -155,23 +163,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
         continue;
       }
     }
-
-    if (have_image && abs_pos != 0) {
-      // This occurs when we have hit max_generated.
-      abs_pos = 0;
+    if (prompt_string.empty()) {
+      std::cout << "Use '%q' to quit.\n";
+      continue;
     }
 
+    // Wrap, tokenize and maybe log prompt tokens.
     std::vector<int> prompt = WrapAndTokenize(
         model.Tokenizer(), model.Info(), abs_pos, prompt_string);
     prompt_size = prompt.size();
-    std::cerr << "\n"
-              << "[ Reading prompt ] " << std::flush;
     if constexpr (kVerboseLogTokens) {
       for (int i = 0; i < prompt_size; ++i) {
         fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
       }
     }
 
+    // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
                                     .verbosity = app.verbosity,
@@ -190,9 +197,38 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
     }
+
+    // Generate until EOS or max_generated_tokens.
+    if (app.verbosity >= 1) {
+      std::cerr << "\n[ Reading prompt ] " << std::flush;
+    }
     model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
                    timing_info);
     std::cout << "\n\n";
+
+    // Prepare for the next turn.
+    if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
+      abs_pos = 0;  // Start a new turn at position 0.
+      InitGenerator(args, gen);
+    } else {
+      // Either the last token was EOS, in which case it should be ignored
+      // because it is never part of the dialog; see Table 5 in the Gemma-2
+      // paper: https://arxiv.org/pdf/2408.00118
+      // Or we hit max_generated_tokens, in which case the last token is lost.
+      // (We could store it in stream_token and prepend it to the next turn,
+      // but it's not worth the complexity, as multi-turn with max_generated
+      // is not a common use case.)
+      // In either case, we need to rewind abs_pos by one.
+      HWY_ASSERT(abs_pos > 0);
+      abs_pos--;
+    }
+    if (end_of_turn_seen && abs_pos > 0) {
+      // If we have seen an end_of_turn token, we need to rewind abs_pos by
+      // one more, because we will prepend it again to the prompt in
+      // WrapAndTokenize.
+      abs_pos--;
+    }
+    end_of_turn_seen = false;
   }
 }
 