Make prompt wrapping more consistent and fix duplicated tokens for multi-turn.

Do not echo <end_of_turn> tokens to the user.
Have verbosity=0 show only the dialog.

PiperOrigin-RevId: 704643323
danielkeysers authored and copybara-github committed Dec 10, 2024
1 parent 642fc97 commit b414aaf
Showing 3 changed files with 77 additions and 44 deletions.
18 changes: 11 additions & 7 deletions evals/gemma_test.cc
@@ -157,13 +157,14 @@ TEST_F(GemmaTest, Multiturn) {
Gemma* model = s_env->GetModel();
ASSERT_NE(model, nullptr);
size_t abs_pos = 0;
std::string dialog;
std::string response;
auto stream_token = [&](int token, float) {
if (token == EOS_ID) return true;
++abs_pos;
std::string token_text;
EXPECT_TRUE(
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
dialog += token_text;
response += token_text;
return true;
};
RuntimeConfig runtime_config{
@@ -180,18 +181,21 @@ TEST_F(GemmaTest, Multiturn) {
abs_pos, mutable_prompt);
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
mutable_prompt = "Please repeat all prior statements.";
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
mutable_prompt);
// Reset the `dialog` string here, then check that the model actually has
// Reset the `response` string here, then check that the model actually has
// access to the previous turn by asking to reproduce.
dialog.clear();
response.clear();
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
fprintf(stderr, "decoded: %s\n", dialog.c_str());
fprintf(stderr, "decoded: %s\n", response.c_str());
bool remembered_turquoise =
dialog.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = dialog.find("car") != std::string::npos; // NOLINT
response.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = response.find("car") != std::string::npos; // NOLINT
EXPECT_TRUE(remembered_turquoise || remembered_car);
}
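The test above deliberately tolerates a duplicated <end_of_turn>: when abs_pos > 0, WrapAndTokenize() prepends one for the follow-up turn even if the model already emitted one. A caller that wants to hide the marker from the user, and to remember that it appeared so it can rewind, can filter it in the stream callback, as run.cc does below. The following sketch is illustrative only and not part of the commit; it reuses the names from the test (model, abs_pos, EOS_ID, EXPECT_TRUE):

bool end_of_turn_seen = false;
std::string response;
auto stream_token = [&](int token, float /*probability*/) {
  if (token == EOS_ID) return true;  // EOS is never part of the dialog.
  ++abs_pos;
  std::string token_text;
  EXPECT_TRUE(
      model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
  if (token_text == "<end_of_turn>") {
    end_of_turn_seen = true;  // Remember it so the caller can rewind abs_pos,
    return true;              // but do not append it to the visible response.
  }
  response += token_text;
  return true;
};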

13 changes: 3 additions & 10 deletions gemma/gemma-inl.h
@@ -1249,16 +1249,9 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
// Copy so we can increment without requiring users to pass in a mutable span.
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
queries_pos_in.cend());
QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size());
// For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
// Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
// previous `Generate` called `StreamToken` for the last token (EOS), hence
// our caller's qpos is 1 too high. This must be corrected because we didn't
// write to the KV cache at that position, so MSAN would complain.
for (size_t& qpos : queries_mutable_pos) {
qpos = qpos == 0 ? 0 : qpos - 1;
}
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size());

// Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx];
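With the rewind loop removed, GenerateT no longer second-guesses the positions it is handed: callers must pass positions that count only tokens actually written to the KV cache. A worked example with made-up token counts (not taken from the commit) of what that means for a second turn:

#include <cstddef>

// Suppose the wrapped first-turn prompt is 10 tokens and the model streams
// 7 reply tokens, the last of which is EOS. EOS is streamed to the callback
// but never written to the KV cache, so the position for the next turn is
// 10 + 7 - 1 = 16. Previously GenerateT subtracted that 1 itself; now the
// caller either skips EOS in its callback (gemma_test.cc) or rewinds
// afterwards (run.cc). The extra rewind for <end_of_turn> is a separate
// concern, handled in run.cc below.
constexpr std::size_t kPromptTokens = 10;
constexpr std::size_t kReplyTokens = 7;  // ends with EOS
constexpr std::size_t kNextTurnPos = kPromptTokens + kReplyTokens - 1;
static_assert(kNextTurnPos == 16, "position excludes the streamed EOS");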
90 changes: 63 additions & 27 deletions gemma/run.cc
@@ -85,6 +85,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
bool end_of_turn_seen = false;

std::mt19937 gen;
InitGenerator(args, gen);
@@ -114,37 +115,44 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
// callback function invoked for each generated token.
auto stream_token = [&](int token, float) {
++abs_pos;
++tokens_generated_this_turn;
// <= since position is incremented before
if (tokens_generated_this_turn <= prompt_size) {
std::cerr << "." << std::flush;
} else if (token == EOS_ID) {
if (!args.multiturn) {
abs_pos = 0;
InitGenerator(args, gen);
}
if (token == EOS_ID) {
if (app.verbosity >= 2) {
std::cout << "\n[ End ]\n";
}
} else {
std::string token_text;
HWY_ASSERT(
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
// +1 since position is incremented above
if (tokens_generated_this_turn == prompt_size + 1) {
// first token of response
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (app.verbosity >= 1) {
std::cout << "\n\n";
}
return true;
}
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
if (in_prompt) {
if (app.verbosity >= 1) {
std::cerr << "." << std::flush;
}
return true;
}
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (app.verbosity >= 1) {
std::cout << "\n\n";
}
std::cout << token_text << std::flush;
}
if (token_text == "<end_of_turn>") {
// We don't want to show the <end_of_turn> token to the user.
// We also need to remember that we've seen it, so that we can rewind
// abs_pos appropriately. We expect EOS as the next token.
end_of_turn_seen = true;
return true;
}
std::cout << token_text << std::flush;
return true;
};

while (true) { // Loop until user quits.
tokens_generated_this_turn = 0;

// Read prompt and handle special commands.
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
if (!std::cin) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
@@ -155,23 +163,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
continue;
}
}

if (have_image && abs_pos != 0) {
// This occurs when we have hit max_generated.
abs_pos = 0;
if (prompt_string.empty()) {
std::cout << "Use '%q' to quit.\n";
continue;
}

// Wrap, tokenize and maybe log prompt tokens.
std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size();
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}

// Set up runtime config.
TimingInfo timing_info = {.verbosity = app.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = app.verbosity,
@@ -190,9 +197,38 @@
// We need to look at all the tokens for the prefix.
runtime_config.prefill_tbatch_size = prompt_size;
}

// Generate until EOS or max_generated_tokens.
if (app.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush;
}
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
timing_info);
std::cout << "\n\n";

// Prepare for the next turn.
if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
} else {
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
// https://arxiv.org/pdf/2408.00118
// Or we have hit max_generated_tokens, then the last token will be lost.
// (We could store it in stream_token, and then prepend to the next turn,
// but it's not worth the complexity, as multi-turn with max_generated is
// not a common use case.)
// In either case, we need to rewind abs_pos by one.
HWY_ASSERT(abs_pos > 0);
abs_pos--;
}
if (end_of_turn_seen && abs_pos > 0) {
// If we have seen an end_of_turn token, we need to rewind abs_pos by one
// more, because we will pre-pend it again to the prompt in
// WrapAndTokenize.
abs_pos--;
}
end_of_turn_seen = false;
}
}
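Condensing the per-turn bookkeeping above, here is a short multi-turn driver sketch. It is not part of the commit: it assumes the text-only Generate overload used in gemma_test.cc, and that model, kv_cache, timing_info, runtime_config, abs_pos and end_of_turn_seen are set up as in run.cc, with the stream callback advancing abs_pos and setting end_of_turn_seen when it decodes "<end_of_turn>":

auto ask = [&](std::string prompt_text) {
  std::vector<int> tokens = WrapAndTokenize(model.Tokenizer(), model.Info(),
                                            abs_pos, prompt_text);
  model.Generate(runtime_config, tokens, abs_pos, kv_cache, timing_info);
  // The last streamed token is either EOS or was cut off by
  // max_generated_tokens; it is not part of the dialog, so step back one.
  HWY_ASSERT(abs_pos > 0);
  abs_pos--;
  // If the reply ended with <end_of_turn>, step back once more, because
  // WrapAndTokenize prepends <end_of_turn> again when wrapping the next turn.
  if (end_of_turn_seen && abs_pos > 0) abs_pos--;
  end_of_turn_seen = false;
};
// Illustrative prompts only.
ask("tell me about the color turquoise");
ask("Please repeat all prior statements.");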
