forked from kiri-art/docker-diffusers-api

getPipeline.py
import time
import os, fnmatch

from diffusers import (
    DiffusionPipeline,
    pipelines as diffusers_pipelines,
)

from precision import revision, torch_dtype

# Cache of already-initialized pipelines, keyed by pipeline name.
_pipelines = {}

# Lazily populated list of community pipeline script names.
_availableCommunityPipelines = None


def listAvailablePipelines():
    """Return all usable pipeline names: diffusers pipeline classes plus community scripts."""
    return (
        list(
            filter(
                lambda key: key.endswith("Pipeline"),
                list(diffusers_pipelines.__dict__.keys()),
            )
        )
        + availableCommunityPipelines()
    )


def availableCommunityPipelines():
    """Return the community pipeline scripts in the local diffusers checkout, without the ".py" suffix."""
    global _availableCommunityPipelines
    if not _availableCommunityPipelines:
        _availableCommunityPipelines = list(
            map(
                lambda s: s[0:-3],  # strip the ".py" extension
                fnmatch.filter(os.listdir("diffusers/examples/community"), "*.py"),
            )
        )
    return _availableCommunityPipelines


def clearPipelines():
    """
    Clears the pipeline cache. Important to call this when changing the
    loaded model, as pipelines include references to the model and would
    therefore prevent memory being reclaimed after unloading the previous
    model.
    """
    global _pipelines
    _pipelines = {}
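

# A minimal usage sketch of the intended call order when swapping models,
# assuming a hypothetical `load_model` helper that stands in for however the
# surrounding app loads a diffusers model:
#
#   model = load_model("some/model-id")
#   pipe = getPipelineForModel("StableDiffusionPipeline", model, "some/model-id")
#   ...
#   clearPipelines()  # drop cached pipelines so the old model can be freed
#   del model
#   model = load_model("another/model-id")
#   pipe = getPipelineForModel("StableDiffusionPipeline", model, "another/model-id")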


def getPipelineForModel(pipeline_name: str, model, model_id):
    """
    Initializes a new pipeline, re-using components from a previously loaded
    model. The pipeline is cached, and future calls with the same arguments
    will return the previously initialized instance. Be sure to call
    `clearPipelines()` when loading a new model, to allow the previous model
    to be garbage collected.
    """
    pipeline = _pipelines.get(pipeline_name)
    if pipeline:
        return pipeline

    start = time.time()

    if hasattr(diffusers_pipelines, pipeline_name):
        # Official diffusers pipeline: construct it directly from the loaded
        # model's components rather than loading weights from disk again.
        if hasattr(model, "components"):
            pipeline = getattr(diffusers_pipelines, pipeline_name)(**model.components)
        else:
            pipeline = getattr(diffusers_pipelines, pipeline_name)(
                vae=model.vae,
                text_encoder=model.text_encoder,
                tokenizer=model.tokenizer,
                unet=model.unet,
                scheduler=model.scheduler,
                safety_checker=model.safety_checker,
                feature_extractor=model.feature_extractor,
            )
    elif pipeline_name in availableCommunityPipelines():
        # Community pipeline: load via its script in diffusers/examples/community,
        # still re-using the already-loaded components.
        pipeline = DiffusionPipeline.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py",
            local_files_only=True,
            **model.components,
        )

    if pipeline:
        _pipelines.update({pipeline_name: pipeline})
        diff = round((time.time() - start) * 1000)
        print(f"Initialized {pipeline_name} for {model_id} in {diff}ms")

    return pipeline
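

# Hedged usage sketch: assumes a Stable Diffusion checkpoint id such as
# "runwayml/stable-diffusion-v1-5" is available, then derives an img2img
# pipeline from the already-loaded components instead of reloading weights.
if __name__ == "__main__":
    model_id = "runwayml/stable-diffusion-v1-5"  # assumption: any SD 1.x checkpoint
    model = DiffusionPipeline.from_pretrained(
        model_id, revision=revision, torch_dtype=torch_dtype
    )

    # A second call with the same pipeline name returns the cached instance.
    img2img = getPipelineForModel("StableDiffusionImg2ImgPipeline", model, model_id)
    assert getPipelineForModel("StableDiffusionImg2ImgPipeline", model, model_id) is img2img

    # Before loading a different checkpoint, clear the cache so the current
    # model's weights can be garbage collected.
    clearPipelines()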