
Commit

Copybara import of the project:
--
91a911d by mkielo3 <[email protected]>:

Set device for local PyTorchGemmaLanguageModel

updated docstring

COPYBARA_INTEGRATE_REVIEW=#84 from mkielo3:main 89ed110
PiperOrigin-RevId: 681414971
Change-Id: I1bd9ef18404a2f4d9ea19868c91d3153d86a993d
mkielo3 authored and copybara-github committed Oct 2, 2024
1 parent 01e4ad4 commit 8a0e489
Showing 2 changed files with 11 additions and 6 deletions.
13 changes: 8 additions & 5 deletions concordia/language_model/pytorch_gemma_model.py
@@ -35,6 +35,7 @@ def __init__(
       *,
       measurements: measurements_lib.Measurements | None = None,
       channel: str = language_model.DEFAULT_STATS_CHANNEL,
+      device: str = 'cpu'
   ) -> None:
     """Initializes the instance.
@@ -43,14 +44,16 @@ def __init__(
         see transformers.AutoModelForCausalLM at huggingface.
       measurements: The measurements object to log usage statistics to.
       channel: The channel to write the statistics to.
+      device: Specifies whether to use cpu or cuda for model processing.
     """
     self._model_name = model_name
     self._tokenizer_name = model_name
+    self._device = device

     os.environ['TOKENIZERS_PARALLELISM'] = 'false'

     self._model = transformers.GemmaForCausalLM.from_pretrained(
-        self._model_name)
+        self._model_name).to(self._device)
     self._tokenizer = transformers.AutoTokenizer.from_pretrained(
         self._tokenizer_name)

@@ -80,14 +83,14 @@ def sample_text(
     inputs = self._tokenizer(prompt_with_system_message, return_tensors='pt')

     generated_tokens = self._model.generate(
-        inputs.input_ids,
+        inputs.input_ids.to(self._device),
         max_new_tokens=max_tokens,
         return_dict_in_generate=True,
         output_scores=True,
     )

     response = self._tokenizer.decode(
-        np.int64(generated_tokens.sequences[0]),
+        np.int64(generated_tokens.sequences[0].cpu()),
         skip_special_tokens=True,
         clean_up_tokenization_spaces=False,
     )
@@ -116,13 +119,13 @@ def sample_choice(

     inputs = self._tokenizer(prompt, return_tensors='pt')
     generated_tokens = self._model.generate(
-        inputs.input_ids,
+        inputs.input_ids.to(self._device),
         max_new_tokens=1,
         return_dict_in_generate=True,
         output_scores=True,
     )
     sample = self._tokenizer.batch_decode(
-        [np.argmax(generated_tokens.scores[0][0])],
+        [np.argmax(generated_tokens.scores[0][0].cpu())],
         skip_special_tokens=True,
         clean_up_tokenization_spaces=False)[0]
     answer = sampling.extract_choice_response(sample)
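In short, the new keyword lets callers place the Gemma weights on a GPU, while the added .to(self._device) and .cpu() calls keep inputs on the model's device and move generated tensors back to host memory before the numpy conversions. A minimal usage sketch follows, assuming torch is installed and using an illustrative model name that is not part of this commit:

import torch

from concordia.language_model import pytorch_gemma_model

# Assumption: prefer CUDA when a GPU is present, else keep the 'cpu' default.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = pytorch_gemma_model.PyTorchGemmaLanguageModel(
    'google/gemma-2b-it',  # illustrative Hugging Face model name (assumption)
    device=device,
)
# sample_text and sample_choice now run generation on the selected device.
print(model.sample_text('Write one sentence about the weather.'))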
4 changes: 3 additions & 1 deletion concordia/language_model/utils.py
@@ -32,6 +32,7 @@ def language_model_setup(
     model_name: str,
     api_key: str | None = None,
     disable_language_model: bool = False,
+    device: str = 'cpu'
 ) -> language_model.LanguageModel:
   """Get the wrapped language model.
@@ -43,6 +44,7 @@ def language_model_setup(
     disable_language_model: Optional, if True then disable the language model.
       This uses a model that returns an empty string whenever asked for a free
       text response and a random option when asked for a choice.
+    device: Specifies whether to use cpu or cuda for model processing.
   Returns:
     The wrapped language model.
@@ -69,7 +71,7 @@
   elif api_type == 'openai':
     return gpt_model.GptLanguageModel(model_name, api_key=api_key)
   elif api_type == 'pytorch_gemma':
-    return pytorch_gemma_model.PyTorchGemmaLanguageModel(model_name)
+    return pytorch_gemma_model.PyTorchGemmaLanguageModel(model_name, device=device)
   elif api_type == 'together_ai':
     return together_ai.Gemma2(model_name, api_key=api_key)
   else:
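A matching sketch of threading the new argument through the setup helper in utils.py; the api_type keyword is assumed from the branching shown above, and the model_name value is again an illustrative assumption. The device kwarg only takes effect on the 'pytorch_gemma' branch:

from concordia.language_model import utils

# Assumption: api_type selects the backend; only 'pytorch_gemma' consumes device.
model = utils.language_model_setup(
    api_type='pytorch_gemma',
    model_name='google/gemma-2b-it',  # illustrative model name (assumption)
    device='cuda',  # requires a CUDA-capable GPU; the default remains 'cpu'
)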
