-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Concurrency support using model clone #564
Open
dtrawins
wants to merge
23
commits into
huggingface:main
Choose a base branch
from
dtrawins:concurrency_support_cloneall
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
9b55100
support for concurrency in llm models
dtrawins 9e4ab17
style fixes
dtrawins dcb2a8f
concurrency in seq2seq and stable diffusion classes
dtrawins 03b797f
merge from main
dtrawins eb5db08
merge from upstream
dtrawins e189e40
concurrency via model cloning in encoders and decoders
dtrawins 3395049
merge from upstream
dtrawins f75021b
fix clone performance
dtrawins e9bb941
init
mzegla 5f85cbb
fix next_beam_idx initialization
dtrawins 912cf3a
init version
dtrawins 03b548d
more tests
mzegla d034d97
Merge pull request #1 from dtrawins/multithreading_tests
dtrawins e5d9c75
running conncurrent execution of stable diffusion pipe with cloning
dtrawins ae50484
Merge remote-tracking branch 'dtrawins/stable-diff-test' into concurr…
dtrawins ed3e4a3
merge from main with fixes
dtrawins 34e7e28
add concurrency examples
dtrawins f4d21d8
preserve request attribure as deprecated
dtrawins 22b529e
merge from main
dtrawins c971473
drop not needed tests
dtrawins cbdb304
style fix
dtrawins bd84ca9
Merge remote-tracking branch 'origin/main' into concurrency_support_c…
dtrawins 30437ae
fix tests without gpu
dtrawins File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Execution in multi-threaded environment | ||
|
||
## Overview | ||
|
||
This example demonstrates how to execute the pipelines from Hugging Face transformers with multi concurency. | ||
A typical scenrio is with multi threaded application without duplicating the model allocation in the host memeory. | ||
|
||
By default, the execution of the transformers with OpenVINO Runtime backend is single threaded. Runing the generation process parallel can cause an error | ||
`RuntimeError: Infer Request is busy`. | ||
|
||
A simple technic can overcome this limitation using `clone` method on the model or a pipeline. It duplicates the execution object while sharing the OpenVINO compiled model in the host memory. The clone object should not change the model by reshaping, changing precision and recompiling. | ||
The snippet below applies this concept: | ||
|
||
```python | ||
pipe = OVStableDiffusionPipeline.from_pretrained( | ||
MODEL_PATH, ov_config=OV_CONFIG, compile=True | ||
) | ||
def thread(prompt, results): | ||
pipe_exec = pipe.clone() | ||
images = pipe_exec(prompt).images | ||
# Do something with images | ||
|
||
T1 = threading.Thread(target=thread, args=("my prompt")) | ||
T1.start() | ||
``` | ||
Note that the `clone` operation is quick and is not duplicating the memory usage. It just creates new context for the generating algorithm. | ||
|
||
Check the simple examples how it can be applied in practice. | ||
|
||
## Preparing python environment | ||
```bash | ||
pip install -r examples/openvino/multithreading/requirement.txt | ||
``` | ||
|
||
## Text generation | ||
|
||
```bash | ||
python examples/openvino/multithreading/gen_text.py | ||
``` | ||
## Image generation | ||
```bash | ||
python examples/openvino/multithreading/gen_text.py | ||
``` | ||
|
||
## Text translation with seq2seq | ||
|
||
```bash | ||
python examples/openvino/multithreading/gen_seq2seq.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import datetime | ||
import threading | ||
|
||
from optimum.intel.openvino import OVStableDiffusionPipeline | ||
|
||
|
||
MODEL_PATH = "runwayml/stable-diffusion-v1-5" | ||
OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1"} | ||
|
||
|
||
pipe = OVStableDiffusionPipeline.from_pretrained( | ||
MODEL_PATH, ov_config=OV_CONFIG, compile=True, dynamic_shapes=True, export=True | ||
) | ||
|
||
vae_decoder_clon = pipe.vae_decoder.clone() | ||
unet_clon = pipe.unet.clone() | ||
|
||
prompt1 = [" Zebras in space "] | ||
prompt2 = [" The statue of liberty in New York", " Big Ben in London "] | ||
prompt3 = [" pigs on the grass field", "beach in the storm", "sail yacht on the ocean"] | ||
|
||
prompts = [prompt1, prompt2, prompt3] | ||
|
||
NUM_THREADS = 3 | ||
|
||
threads = [None] * NUM_THREADS | ||
results = [None] * NUM_THREADS | ||
|
||
|
||
def save_response(t, p, r): | ||
print("THREAD", t) | ||
print("PROMPT:", p) | ||
for i in range(len(r)): | ||
print("IMG:", i) | ||
r[i].save("img_" + str(t) + "_" + str(i) + ".png", format="PNG") | ||
|
||
|
||
def gen_thread(prompt, results, i): | ||
start = datetime.datetime.now() | ||
pipe_exec = pipe.clone() | ||
end = datetime.datetime.now() | ||
print("Clonning time [s]", ((end - start).total_seconds())) | ||
text = prompt | ||
images = pipe_exec(text).images | ||
results[i] = images | ||
|
||
|
||
start = datetime.datetime.now() | ||
for i in range(len(threads)): | ||
threads[i] = threading.Thread(target=gen_thread, args=(prompts[i], results, i)) | ||
threads[i].start() | ||
nu_img = 0 | ||
for i in range(len(threads)): | ||
threads[i].join() | ||
nu_img += len(results[i]) | ||
end = datetime.datetime.now() | ||
|
||
for i in range(len(threads)): | ||
save_response(i, prompts[i], results[i]) | ||
|
||
print("Generation time [s]", ((end - start).total_seconds()), "images:", nu_img) | ||
print("Throughput:", nu_img * 60 / ((end - start).total_seconds()), "images/min") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import datetime | ||
import threading | ||
|
||
from transformers import AutoTokenizer, pipeline | ||
|
||
from optimum.intel import OVModelForSeq2SeqLM | ||
|
||
|
||
model_id = "echarlaix/t5-small-openvino" | ||
model = OVModelForSeq2SeqLM.from_pretrained(model_id) | ||
tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer) | ||
|
||
prompt1 = ["I live in Europe"] | ||
prompt2 = ["What is your name?", "The dog is very happy"] | ||
prompt3 = ["It's a beautiful weather today", "Yes", "Good morning"] | ||
prompts = [prompt1, prompt2, prompt3] | ||
|
||
NUM_THREADS = 3 | ||
|
||
threads = [None] * NUM_THREADS | ||
results = [None] * NUM_THREADS | ||
|
||
|
||
def print_response(t, p, r): | ||
print("THREAD", t) | ||
print("PROMPT:", p) | ||
for i in range(len(r)): | ||
print("TRANSLATION", i, ":", r[i]["translation_text"]) | ||
|
||
|
||
def gen_thread(prompt, results, i): | ||
translations = pipe(prompt) | ||
results[i] = translations | ||
|
||
|
||
start = datetime.datetime.now() | ||
for i in range(len(threads)): | ||
threads[i] = threading.Thread(target=gen_thread, args=(prompts[i], results, i)) | ||
threads[i].start() | ||
nu_trans = 0 | ||
for i in range(len(threads)): | ||
threads[i].join() | ||
nu_trans += len(results[i]) | ||
end = datetime.datetime.now() | ||
|
||
for i in range(len(threads)): | ||
print_response(i, prompts[i], results[i]) | ||
|
||
print("Generation time [s]", ((end - start).total_seconds()), "translations:", nu_trans) | ||
print("Throughput:", nu_trans / ((end - start).total_seconds()), "translations/s") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import threading | ||
from datetime import datetime | ||
|
||
from transformers import AutoConfig, AutoTokenizer, set_seed | ||
|
||
from optimum.intel import OVModelForCausalLM | ||
|
||
|
||
set_seed(10) | ||
model_id = "togethercomputer/RedPajama-INCITE-Chat-3B-v1" | ||
tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
tokenizer.pad_token = "[PAD]" | ||
tokenizer.padding_side = "left" | ||
NUM_THREADS = 3 | ||
prompt1 = ["<human>: Question: What is the weather like now? Answer: <bot>"] | ||
prompt2 = ["<human>: Question: What is Openvino?", "<human>: Question: What the the relativity theory? Answer: <bot>"] | ||
prompt3 = [ | ||
"<human>: Question: Are cats smarter that dogs? Answer: <bot>", | ||
"<human>: Question: How big is an elephant? Answer: <bot>", | ||
"<human>: Question: The water in the ocean is much hotter than before? Answer: <bot>", | ||
] | ||
prompts = [prompt1, prompt2, prompt3] | ||
|
||
OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "CACHE_DIR": "", "NUM_STREAMS": "2"} | ||
model = OVModelForCausalLM.from_pretrained( | ||
model_id, | ||
config=AutoConfig.from_pretrained(model_id, trust_remote_code=True), | ||
ov_config=OV_CONFIG, | ||
compile=True, | ||
export=True, | ||
) | ||
|
||
threads = [None] * NUM_THREADS | ||
results = [None] * NUM_THREADS | ||
|
||
|
||
def print_response(t, p, r): | ||
print("THREAD", t) | ||
print("PROMPT:", p) | ||
for answer in r: | ||
print("Answer:") | ||
print(tokenizer.decode(answer, skip_special_tokens=True)) | ||
|
||
|
||
def gen_thread(prompt, results, i): | ||
inputs = tokenizer(prompt, return_tensors="pt", padding=True) | ||
generate_kwargs = { | ||
"input_ids": inputs.input_ids, | ||
"max_new_tokens": 200, | ||
"temperature": 1.0, | ||
"do_sample": True, | ||
"top_p": 1.0, | ||
"top_k": 50, | ||
"num_beams": 5, | ||
"repetition_penalty": 1.1, | ||
} | ||
start = datetime.now() | ||
model_exec = model.clone() | ||
end = datetime.now() | ||
print("cloning model duration", (end - start).total_seconds() * 1000000, "us") | ||
outputs = model_exec.generate(**generate_kwargs) | ||
num_tok = 0 | ||
for x in range(len(prompt)): | ||
num_tok += outputs[x].numel() - inputs.get("input_ids")[x].numel() | ||
results[i] = outputs, num_tok | ||
|
||
|
||
start = datetime.now() | ||
for i in range(len(threads)): | ||
threads[i] = threading.Thread(target=gen_thread, args=(prompts[i], results, i)) | ||
threads[i].start() | ||
|
||
total_tok = 0 | ||
for i in range(len(threads)): | ||
threads[i].join() | ||
total_tok += results[i][1] | ||
end = datetime.now() | ||
|
||
for i in range(len(threads)): | ||
print_response(i, prompts[i], results[i][0]) | ||
|
||
print("Generation time [s]", ((end - start).total_seconds()), "tokens:", total_tok) | ||
print("Throughput:", total_tok * 60 / ((end - start).total_seconds()), "tokens/min") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
optimum-intel[openvino, nncf]"@git+https://github.com/huggingface/optimum-intel.git | ||
transformers | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
transformers is a dependency of optimum-intel