[training] add ds support to lora sd3. #10378

Open · wants to merge 2 commits into main
49 changes: 33 additions & 16 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -29,7 +29,7 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
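The only import change is DistributedType, which the hooks below use to detect the DeepSpeed backend at runtime. A minimal sketch of that check, assuming an Accelerator has already been constructed as in the script:

from accelerate import Accelerator, DistributedType

accelerator = Accelerator()

# True when the script is launched through `accelerate launch` with a
# DeepSpeed config; False for plain DDP or single-GPU runs.
use_deepspeed = accelerator.distributed_type == DistributedType.DEEPSPEED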
@@ -1292,11 +1292,13 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
@@ -1305,7 +1307,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
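Why the save hook changes: under DeepSpeed the objects handed to the save-state hook are wrapped engine modules, so isinstance(model, type(unwrap_model(transformer))) no longer matches; unwrapping both sides first restores the comparison. DeepSpeed also manages the weights list itself, so weights can arrive empty and the unconditional weights.pop() would raise IndexError; the new guard only pops when there is something to pop. A condensed sketch of the pattern (the helper name is illustrative, not part of the PR):

from peft.utils import get_peft_model_state_dict


def collect_lora_layers(model, reference, unwrap_model):
    """Return the LoRA (PEFT) state dict if `model` has the same class as `reference`.

    Both sides are unwrapped first so the check still works when DeepSpeed
    (or DDP) has wrapped the module.
    """
    unwrapped = unwrap_model(model)
    if isinstance(unwrapped, type(unwrap_model(reference))):
        return get_peft_model_state_dict(unwrapped)
    return None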
@@ -1319,17 +1322,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")

+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )
+
         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

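Why the load hook changes: when a run is resumed under DeepSpeed, the models list passed to the load-state hook does not carry usable unwrapped modules (DeepSpeed restores the engine state itself), so the non-DeepSpeed while loop is skipped. Instead the hook rebuilds the base transformer (and, with --train_text_encoder, the two text encoders) from args.pretrained_model_name_or_path, re-attaches the LoRA adapter config, and then loads the LoRA weights from the checkpoint directory as before. For context, these hooks are attached to the accelerator elsewhere in the script; a sketch of that wiring using the standard accelerate API (stub hooks, not a verbatim excerpt):

from accelerate import Accelerator

accelerator = Accelerator()


def save_model_hook(models, weights, output_dir):
    # In the real script this extracts the LoRA layers (see the hunk above).
    pass


def load_model_hook(models, input_dir):
    # In the real script this restores the LoRA layers into the live models.
    pass


# accelerator.save_state(ckpt_dir) invokes save_model_hook, and resuming via
# accelerator.load_state(ckpt_dir) invokes load_model_hook.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)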
@@ -1829,7 +1846,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
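Why the checkpointing condition changes: with DeepSpeed, accelerator.save_state() has to run on every process, because each rank writes its own shard of the (potentially ZeRO-partitioned) optimizer and model state; gating it on is_main_process alone would leave the other ranks out of the collective save and produce an incomplete checkpoint. A sketch of the guarded call, wrapped in a hypothetical helper for readability (the function name is not from the script):

import os

from accelerate import Accelerator, DistributedType


def maybe_checkpoint(accelerator: Accelerator, output_dir: str, global_step: int, every: int) -> None:
    """Call accelerator.save_state() with the DeepSpeed-aware gating used in this diff."""
    should_save = (
        accelerator.is_main_process
        or accelerator.distributed_type == DistributedType.DEEPSPEED
    )
    if should_save and global_step % every == 0:
        # Under DeepSpeed every rank must reach save_state so each process can
        # write its shard of the partitioned optimizer/model state.
        accelerator.save_state(os.path.join(output_dir, f"checkpoint-{global_step}"))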