import os
import sys

import folder_paths

# Make the bundled easyanimate package importable before anything from it is
# imported below.
comfy_path = os.path.dirname(folder_paths.__file__)
easyanimate_path = os.path.join(comfy_path, 'custom_nodes', 'ComfyUI-EasyAnimate')
sys.path.insert(0, easyanimate_path)
import torch
from diffusers import (AutoencoderKL, DDIMScheduler,
                       DPMSolverMultistepScheduler,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       PNDMScheduler)
from einops import rearrange
from omegaconf import OmegaConf

from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
from easyanimate.models.transformer3d import Transformer3DModel
from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
from easyanimate.utils.lora_utils import merge_lora
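
# Dropdown choices for the optional checkpoint inputs; a leading "None" entry
# lets the user skip loading that component.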
checkpoints=['None']
checkpoints.extend(folder_paths.get_filename_list("checkpoints"))
vaes=['None']
vaes.extend(folder_paths.get_filename_list("vae"))
class EasyAnimateLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"pixart_path": (os.listdir(folder_paths.get_folder_paths("diffusers")[0]), {"default": "PixArt-XL-2-512x512"}),
"motion_ckpt": (folder_paths.get_filename_list("checkpoints"), {"default": "easyanimate_v1_mm.safetensors"}),
"sampler_name": (["Euler","Euler A","DPM++","PNDM","DDIM"],{"default":"DPM++"}),
"device":(["cuda","cpu"],{"default":"cuda"}),
},
"optional": {
"transformer_ckpt": (checkpoints, {"default": 'None'}),
"lora_ckpt": (checkpoints, {"default": 'None'}),
"vae_ckpt": (vaes, {"default": 'None'}),
"lora_weight": ("FLOAT", {"default": 0.55, "min": 0, "max": 1, "step": 0.01}),
}
}
RETURN_TYPES = ("EasyAnimateModel",)
FUNCTION = "run"
CATEGORY = "EasyAnimate"
    def run(self, pixart_path, motion_ckpt, sampler_name, device, transformer_ckpt='None', lora_ckpt='None', vae_ckpt='None', lora_weight=0.55):
        pixart_path = os.path.join(folder_paths.get_folder_paths("diffusers")[0], pixart_path)
        # Config and model path
        config_path = f"{easyanimate_path}/config/easyanimate_video_motion_module_v1.yaml"
        model_name = pixart_path
        # The sampler comes from the node input; valid choices are
        # "Euler", "Euler A", "DPM++", "PNDM" and "DDIM".
# Load pretrained model if need
transformer_path = None
if transformer_ckpt!='None':
transformer_path = folder_paths.get_full_path("checkpoints", transformer_ckpt)
        motion_module_path = folder_paths.get_full_path("checkpoints", motion_ckpt)
vae_path = None
if vae_ckpt!='None':
vae_path = folder_paths.get_full_path("vae", vae_ckpt)
lora_path = None
if lora_ckpt!='None':
lora_path = folder_paths.get_full_path("checkpoints", lora_ckpt)
        weight_dtype = torch.float16
config = OmegaConf.load(config_path)
# Get Transformer
transformer = Transformer3DModel.from_pretrained_2d(
model_name,
subfolder="transformer",
transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs'])
).to(weight_dtype)
if transformer_path is not None:
print(f"From checkpoint: {transformer_path}")
if transformer_path.endswith("safetensors"):
                from safetensors.torch import load_file
state_dict = load_file(transformer_path)
else:
state_dict = torch.load(transformer_path, map_location="cpu")
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
m, u = transformer.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
if motion_module_path is not None:
print(f"From Motion Module: {motion_module_path}")
if motion_module_path.endswith("safetensors"):
                from safetensors.torch import load_file
state_dict = load_file(motion_module_path)
else:
state_dict = torch.load(motion_module_path, map_location="cpu")
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
m, u = transformer.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(m)}, unexpected keys: {len(u)}, {u}")
        # Get VAE: use the MagViT autoencoder when the config enables it,
        # otherwise the standard KL autoencoder.
        if OmegaConf.to_container(config['vae_kwargs'])['enable_magvit']:
            Chosen_AutoencoderKL = AutoencoderKLMagvit
        else:
            Chosen_AutoencoderKL = AutoencoderKL
        vae = Chosen_AutoencoderKL.from_pretrained(
            model_name,
            subfolder="vae",
            torch_dtype=weight_dtype
        )
if vae_path is not None:
print(f"From checkpoint: {vae_path}")
if vae_path.endswith("safetensors"):
                from safetensors.torch import load_file
state_dict = load_file(vae_path)
else:
state_dict = torch.load(vae_path, map_location="cpu")
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
m, u = vae.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        # Get Scheduler: map the sampler name from the UI onto the matching
        # diffusers scheduler class.
        Chosen_Scheduler = {
            "Euler": EulerDiscreteScheduler,
            "Euler A": EulerAncestralDiscreteScheduler,
            "DPM++": DPMSolverMultistepScheduler,
            "PNDM": PNDMScheduler,
            "DDIM": DDIMScheduler,
        }[sampler_name]
        scheduler = Chosen_Scheduler(**OmegaConf.to_container(config['noise_scheduler_kwargs']))
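        # Assemble the pipeline from the base model directory plus the
        # components loaded above.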
pipeline = EasyAnimatePipeline.from_pretrained(
model_name,
vae=vae,
transformer=transformer,
scheduler=scheduler,
torch_dtype=weight_dtype
)
        # Keep VRAM usage low: enable model CPU offload, pin the transformer on
        # the requested device, and leave the text encoder and VAE on CPU.
        pipeline.enable_model_cpu_offload()
        pipeline.transformer.to(device)
        pipeline.text_encoder.to('cpu')
        pipeline.vae.to('cpu')
if lora_path is not None:
pipeline = merge_lora(pipeline, lora_path, lora_weight)
return (pipeline,)
class EasyAnimateRun:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model":("EasyAnimateModel",),
"prompt":("STRING",{"multiline": True, "default":"A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road."}),
"negative_prompt":("STRING",{"multiline": True, "default":"Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly"}),
"video_length":("INT",{"default":80}),
"num_inference_steps":("INT",{"default":30}),
"width":("INT",{"default":512}),
"height":("INT",{"default":512}),
"guidance_scale":("FLOAT",{"default":6.0}),
"seed":("INT",{"default":1234}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
CATEGORY = "EasyAnimate"
    def run(self, model, prompt, negative_prompt, video_length, num_inference_steps, width, height, guidance_scale, seed):
        generator = torch.Generator(device='cuda').manual_seed(seed)
        with torch.no_grad():
            videos = model(
                prompt,
                video_length=video_length,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                generator=generator,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
            ).videos
        # ComfyUI expects IMAGE outputs as (frames, height, width, channels)
        # tensors, and node results must be returned as a tuple.
        videos = rearrange(videos, "b c t h w -> (b t) h w c")
        return (videos,)
NODE_CLASS_MAPPINGS = {
"EasyAnimateLoader":EasyAnimateLoader,
"EasyAnimateRun":EasyAnimateRun,
}
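
# Optional, a minimal sketch: ComfyUI also picks up NODE_DISPLAY_NAME_MAPPINGS
# when it is defined, to show friendlier labels in the node menu. The display
# strings below are illustrative, not taken from the upstream repository.
NODE_DISPLAY_NAME_MAPPINGS = {
    "EasyAnimateLoader": "EasyAnimate Loader",
    "EasyAnimateRun": "EasyAnimate Run",
}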