Brandon/flux model loading #6739

Merged · 113 commits · Aug 27, 2024

Commits
5256f26
Bump diffusers version to include FLUX support.
RyanJDick Aug 6, 2024
562c2cc
Update imports for compatibility with bumped diffusers version.
RyanJDick Aug 6, 2024
b617631
Update HF download logic to work for black-forest-labs/FLUX.1-schnell.
RyanJDick Aug 6, 2024
6a068cc
First draft of FluxTextToImageInvocation.
RyanJDick Aug 6, 2024
5149a3e
Add sentencepiece dependency for the T5 tokenizer.
RyanJDick Aug 7, 2024
761e49f
Use the FluxPipeline.encode_prompt() api rather than trying to run th…
RyanJDick Aug 7, 2024
5e1b3e9
Got FLUX schnell working with 8-bit quantization. Still lots of rough…
RyanJDick Aug 7, 2024
e71c7d0
Minor improvements to FLUX workflow.
RyanJDick Aug 7, 2024
f7753be
Make 8-bit quantization save/reload work for the FLUX transformer. Re…
RyanJDick Aug 8, 2024
28654ec
Add support for 8-bit quantization of the FLUX T5XXL text encoder.
RyanJDick Aug 8, 2024
40e9a4e
Make float16 inference work with FLUX on 24GB GPU.
RyanJDick Aug 8, 2024
df1ac07
WIP - experimentation
RyanJDick Aug 9, 2024
11279ab
Make quantized loading fast.
RyanJDick Aug 9, 2024
bcb7f8e
Make quantized loading fast for both T5XXL and FLUX transformer.
RyanJDick Aug 9, 2024
3f1fbc6
Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocat…
RyanJDick Aug 12, 2024
06d35c3
wip
RyanJDick Aug 14, 2024
b1cf2f5
NF4 loading working... I think.
RyanJDick Aug 14, 2024
31c8d76
NF4 inference working
RyanJDick Aug 14, 2024
17f5952
Clean up NF4 implementation.
RyanJDick Aug 15, 2024
52ff3c7
LLM.int8() quantization is working, but still some rough edges to solve.
RyanJDick Aug 15, 2024
307a130
More improvements for LLM.int8() - not fully tested.
RyanJDick Aug 15, 2024
27f33bb
WIP on moving from diffusers to FLUX
RyanJDick Aug 16, 2024
e157ff3
Setup flux model loading in the UI
brandonrising Aug 12, 2024
6779e03
Remove changes to v1 workflow
brandonrising Aug 12, 2024
4bb3438
Manage quantization of models within the loader
brandonrising Aug 12, 2024
6b0f5f4
Run Ruff
brandonrising Aug 14, 2024
3814cd7
Run Ruff
brandonrising Aug 15, 2024
a6ad70e
Some UI cleanup, regenerate schema
brandonrising Aug 15, 2024
f3096a8
Add backend functions and classes for Flux implementation, Update the…
brandonrising Aug 16, 2024
5fc6c28
Run ruff, setup initial text to image node
brandonrising Aug 19, 2024
68d28db
Add nf4 bnb quantized format
brandonrising Aug 19, 2024
f3ebbe1
Remove unused param on _run_vae_decoding in flux text to image
brandonrising Aug 19, 2024
efab4a3
Working inference node with quantized bnb nf4 checkpoint
brandonrising Aug 19, 2024
95a2d97
Install sub directories with folders correctly, ensure consistent dty…
brandonrising Aug 19, 2024
98151ce
Select dev/schnell based on state dict, use correct max seq len based…
brandonrising Aug 19, 2024
5d7e154
Fix FLUX output image clamping. And a few other minor fixes to make i…
RyanJDick Aug 20, 2024
870ecd3
Add tqdm progress bar to FLUX denoising.
RyanJDick Aug 20, 2024
24829b9
Fix support for 8b quantized t5 encoders, update exception messages i…
brandonrising Aug 20, 2024
f36c6d0
Fix styling/lint
brandonrising Aug 20, 2024
bebc6d3
Add t5 encoders and clip embeds to the model manager
brandonrising Aug 20, 2024
3f845d9
Some cleanup of the tags and description of flux nodes
brandonrising Aug 20, 2024
8b3e386
exclude flux models from main model dropdown
Aug 21, 2024
35c263a
add default workflow for flux t2i
Aug 21, 2024
ec360ee
Rename t5Encoder -> t5_encoder.
RyanJDick Aug 21, 2024
c822c3d
Address minor review comments.
RyanJDick Aug 21, 2024
19238ed
Update doc string for import_local_model and remove access_token sinc…
brandonrising Aug 21, 2024
ede26a7
Switch inheritance class of flux model loaders
brandonrising Aug 21, 2024
f0408bb
Various styling and exception type updates
brandonrising Aug 21, 2024
9e888b1
More flux loader cleanup
brandonrising Aug 21, 2024
6afb113
Remove duplicate log_time(...) function.
RyanJDick Aug 21, 2024
519bf71
Add docs to the requantize(...) function explaining why it was copied…
RyanJDick Aug 21, 2024
c549a49
Move requantize.py to the quantization/ dir.
RyanJDick Aug 21, 2024
41fb09b
update flux_model_loader node to take a T5 encoder from node field in…
maryhipp Aug 21, 2024
c66ccad
add case for clip embed models in probe
maryhipp Aug 21, 2024
9020a8a
add FLUX schnell starter models and submodels as dependencies or adho…
maryhipp Aug 21, 2024
7264920
fix(ui): only exclude flux main models from linear UI dropdown, not m…
Aug 21, 2024
3fe9582
fix(ui): pass base/type when installing models, add flux formats to M…
Aug 21, 2024
24831d4
feat(ui): create new field for t5 encoder models in nodes
Aug 21, 2024
6899762
tsc and lint fix
Aug 21, 2024
f67c4da
fix schema
Aug 21, 2024
a04d479
update default workflow
maryhipp Aug 21, 2024
192eda7
fix(worker) fix T5 type
maryhipp Aug 21, 2024
3c861fd
add better workflow description
maryhipp Aug 21, 2024
dcfdc00
add better workflow name
maryhipp Aug 21, 2024
f51dd36
Fix bug in InvokeInt8Params that was causing it to use double the nec…
RyanJDick Aug 21, 2024
3c9811f
Update load_flux_model_bnb_llm_int8.py to work with a single-file FLU…
RyanJDick Aug 21, 2024
9982bc2
Add docs to the quantization scripts.
RyanJDick Aug 21, 2024
c5c60f5
Fix max_seq_len field description.
RyanJDick Aug 21, 2024
5f3e325
Remove automatic install of models during flux model loader, remove n…
brandonrising Aug 21, 2024
bc6e1ba
Run ruff
brandonrising Aug 21, 2024
407796c
Undo changes to the v2 dir of frontend types
brandonrising Aug 21, 2024
d4ec434
added FLUX dev to starter models
maryhipp Aug 21, 2024
374dc82
Don't install bitsandbytes on macOS
brandonrising Aug 21, 2024
80a46d2
Attribute black-forest-labs/flux for much of the flux code
brandonrising Aug 21, 2024
5406a2f
Mark FLUX nodes as prototypes.
RyanJDick Aug 22, 2024
c9c4e47
Make FLUX get_noise(...) consistent across devices/dtypes.
RyanJDick Aug 22, 2024
22a3b3d
Tidy is_schnell detection logic.
RyanJDick Aug 22, 2024
5307a6f
Add comment about incorrect T5 Tokenizer size calculation.
RyanJDick Aug 22, 2024
5e9ef4b
Rename field positive_prompt -> prompt.
RyanJDick Aug 22, 2024
0e9f6f7
Move prepare_latent_image_patches(...) to sampling.py with all of the…
RyanJDick Aug 22, 2024
b5c937e
Run FLUX VAE decoding in the user's preferred dtype rather than float…
RyanJDick Aug 22, 2024
f34a923
Update macos test vm to macOS-14
brandonrising Aug 22, 2024
b8d4630
Load and unload clip/t5 encoders and run inference separately in text…
brandonrising Aug 23, 2024
a31c02b
Only import bnb quantize file if bitsandbytes is installed
brandonrising Aug 23, 2024
9899e42
Switch flux to using its own conditioning field
brandonrising Aug 23, 2024
cfcd860
Add script for quantizing a T5 model.
RyanJDick Aug 23, 2024
4089ff2
Fixes to the T5XXL quantization script.
RyanJDick Aug 23, 2024
54c48c3
Update the T5 8-bit quantized starter model to use the BnB LLM.int8()…
RyanJDick Aug 23, 2024
5af214b
Remove all references to optimum-quanto and downgrade diffusers.
RyanJDick Aug 23, 2024
f4612a9
Update docs for T5 quantization script.
RyanJDick Aug 23, 2024
18c0ec3
Move quantization scripts to a scripts/ subdir.
RyanJDick Aug 23, 2024
098db5c
Downgrade revert torch version after removing optimum-quanto, and othe…
RyanJDick Aug 23, 2024
dbdd851
Update t5 encoder formats to accurately reflect the quantization stra…
brandonrising Aug 23, 2024
1d6c83b
Switch the CLIP-L start model to use our hosted version - which is mu…
RyanJDick Aug 23, 2024
1413ff9
Replace swish() with torch.nn.functional.silu(h). They are functional…
RyanJDick Aug 23, 2024
bd1b37d
Setup scaffolding for in progress images and add ability to cancel th…
brandonrising Aug 24, 2024
d159fe6
Remove dependency on flux config files
brandonrising Aug 25, 2024
877b88e
ruff
RyanJDick Aug 26, 2024
ae94e48
Remove flux repo dependency
RyanJDick Aug 26, 2024
f046a38
Downgrade accelerate and huggingface-hub deps to original versions.
RyanJDick Aug 26, 2024
9f6f404
ruff format
RyanJDick Aug 26, 2024
9a530c7
Remove outdated TODO.
RyanJDick Aug 26, 2024
bb80697
Only install starter models if not already installed
brandonrising Aug 26, 2024
642a953
Remove in progress images until we're able to make them valuable
brandonrising Aug 26, 2024
a90d098
Remove no longer used code in the flux denoise function
brandonrising Aug 26, 2024
40a3fa5
Fix type error in tsc
brandonrising Aug 26, 2024
b9238b6
Run ruff
brandonrising Aug 26, 2024
5a5ca10
Rename params for flux and flux vae, add comments explaining use of t…
brandonrising Aug 26, 2024
bf59ab3
update default workflow for flux
Aug 26, 2024
bd2692b
remove prompt
Aug 26, 2024
5d42e67
Run ruff
brandonrising Aug 26, 2024
c510234
default workflow: add steps to exposed fields, add more notes
Aug 26, 2024
3b29bad
Update starter model size estimates.
RyanJDick Aug 26, 2024
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
@@ -60,7 +60,7 @@ jobs:
            extra-index-url: 'https://download.pytorch.org/whl/cpu'
            github-env: $GITHUB_ENV
          - platform: macos-default
-           os: macOS-12
+           os: macOS-14
            github-env: $GITHUB_ENV
          - platform: windows-cpu
            os: windows-2022
11 changes: 11 additions & 0 deletions invokeai/app/invocations/fields.py
@@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):

     # region Model Field Types
     MainModel = "MainModelField"
+    FluxMainModel = "FluxMainModelField"
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
@@ -48,6 +49,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
+    T5EncoderModel = "T5EncoderModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion

@@ -125,13 +127,16 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5_encoder = "T5 tokenizer and text encoder"
     unet = "UNet (scheduler, LoRAs)"
+    transformer = "Transformer"
     vae = "VAE"
     cond = "Conditioning tensor"
     controlnet_model = "ControlNet model to load"
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
+    flux_model = "Flux model (Transformer) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -231,6 +236,12 @@ def tuple(self) -> Tuple[int, int, int, int]:
         return (self.r, self.g, self.b, self.a)


+class FluxConditioningField(BaseModel):
+    """A conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
86 changes: 86 additions & 0 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -0,0 +1,86 @@
from typing import Literal

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


@invocation(
    "flux_text_encoder",
    title="FLUX Text Encoding",
    tags=["prompt", "conditioning", "flux"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a flux image."""

clip: CLIPField = InputField(
title="CLIP",
description=FieldDescriptions.clip,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(description="Text prompt to encode.")

@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
)

conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)

def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# Load CLIP.
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)

# Load T5.
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

prompt = [self.prompt]

with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)

t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

prompt_embeds = t5_encoder(prompt)

with (
clip_text_encoder_info as clip_text_encoder,
clip_tokenizer_info as clip_tokenizer,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)

clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)

pooled_prompt_embeds = clip_encoder(prompt)

assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
return prompt_embeds, pooled_prompt_embeds
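
HFEncoder is imported from invokeai.backend.flux.modules.conditioner, which is not part of this diff. As a reading aid, here is a hedged sketch of what it plausibly does, modeled on the black-forest-labs/flux reference code that the PR attributes; the class body below is illustrative, not the repo's. The boolean third argument selects CLIP's pooled output versus T5's per-token hidden states, which is why the node passes True with max length 77 for CLIP and False with t5_max_seq_len for T5.

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

class HFEncoderSketch(torch.nn.Module):
    """Illustrative stand-in for invokeai.backend.flux.modules.conditioner.HFEncoder."""

    def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.encoder = encoder
        self.tokenizer = tokenizer
        self.max_length = max_length
        # CLIP yields one pooled embedding per prompt; T5 yields one embedding per token.
        self.output_key = "pooler_output" if is_clip else "last_hidden_state"

    @torch.no_grad()
    def forward(self, prompts: list[str]) -> torch.Tensor:
        batch = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.encoder(
            input_ids=batch["input_ids"].to(self.encoder.device),
            output_hidden_states=False,
        )
        return outputs[self.output_key]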
172 changes: 172 additions & 0 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -0,0 +1,172 @@
import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image", "flux"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Connection,
        title="Transformer",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    positive_text_conditioning: FluxConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the conditioning data.
        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1
        flux_conditioning = cond_data.conditionings[0]
        assert isinstance(flux_conditioning, FLUXConditioningInfo)

        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
        image = self._run_vae_decoding(context, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)

    def _run_diffusion(
        self,
        context: InvocationContext,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ):
        transformer_info = context.models.load(self.transformer.transformer)
        inference_dtype = torch.bfloat16

        # Prepare input noise.
        x = get_noise(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=TorchDevice.choose_torch_device(),
            dtype=inference_dtype,
            seed=self.seed,
        )

        img, img_ids = prepare_latent_img_patches(x)

        is_schnell = "schnell" in transformer_info.config.config_path

        timesteps = get_schedule(
            num_steps=self.num_steps,
            image_seq_len=img.shape[1],
            shift=not is_schnell,
        )

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        with transformer_info as transformer:
            assert isinstance(transformer, Flux)

            def step_callback() -> None:
                if context.util.is_canceled():
                    raise CanceledException

                # TODO: Make this look like the image before re-enabling
                # latent_image = unpack(img.float(), self.height, self.width)
                # latent_image = latent_image.squeeze()  # Remove unnecessary dimensions
                # flattened_tensor = latent_image.reshape(-1)  # Flatten to shape [48*128*128]

                # # Create a new tensor of the required shape [255, 255, 3]
                # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3)  # Reshape to RGB format

                # # Convert to a NumPy array and then to a PIL Image
                # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8))

                # (width, height) = image.size
                # width *= 8
                # height *= 8

                # dataURL = image_to_dataURL(image, image_format="JPEG")

                # # TODO: move this whole function to invocation context to properly reference these variables
                # context._services.events.emit_invocation_denoise_progress(
                #     context._data.queue_item,
                #     context._data.invocation,
                #     state,
                #     ProgressImage(dataURL=dataURL, width=width, height=height),
                # )

            x = denoise(
                model=transformer,
                img=img,
                img_ids=img_ids,
                txt=t5_embeddings,
                txt_ids=txt_ids,
                vec=clip_embeddings,
                timesteps=timesteps,
                step_callback=step_callback,
                guidance=self.guidance,
            )

        x = unpack(x.float(), self.height, self.width)

        return x

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        latents: torch.Tensor,
    ) -> Image.Image:
        vae_info = context.models.load(self.vae.vae)
        with vae_info as vae:
            assert isinstance(vae, AutoEncoder)
            latents = latents.to(dtype=TorchDevice.choose_torch_dtype())
            img = vae.decode(latents)

        img = img.clamp(-1, 1)
        img = rearrange(img[0], "c h w -> h w c")
        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        return img_pil
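
prepare_latent_img_patches and unpack come from invokeai.backend.flux.sampling and are not shown in this diff. Here is a hedged sketch of the packing step, following the standard FLUX convention and assuming the usual 16-channel, 1/8-scale VAE latent; the function body is illustrative. Each 2x2 spatial patch of the latent is folded into the channel dimension, which is why the transformer's sequence length img.shape[1] equals (height/16) * (width/16) and is passed to get_schedule as image_seq_len above.

import torch
from einops import rearrange, repeat

def prepare_latent_img_patches_sketch(latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Pack a (b, c, h, w) latent into (b, (h/2)*(w/2), c*4) patch tokens plus position ids."""
    b, _, h, w = latents.shape

    # Fold each 2x2 spatial patch into the channel dimension.
    img = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    # Per-patch (t, y, x) position ids; the t component stays 0 for text-to-image.
    img_ids = torch.zeros(h // 2, w // 2, 3, device=latents.device)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=latents.device)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=latents.device)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=b)

    return img, img_ids

unpack performs the inverse of this packing. In _run_vae_decoding above, the decoded image is clamped to [-1, 1] and mapped to 8-bit with 127.5 * (img + 1.0), so -1 becomes 0 and +1 becomes 255.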