diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 8805989a69b..f22e29e3e9b 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -60,7 +60,7 @@ jobs: extra-index-url: 'https://download.pytorch.org/whl/cpu' github-env: $GITHUB_ENV - platform: macos-default - os: macOS-12 + os: macOS-14 github-env: $GITHUB_ENV - platform: windows-cpu os: windows-2022 diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 9efcf2148f7..3a4e2cbddb1 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum): # region Model Field Types MainModel = "MainModelField" + FluxMainModel = "FluxMainModelField" SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" ONNXModel = "ONNXModelField" @@ -48,6 +49,7 @@ class UIType(str, Enum, metaclass=MetaEnum): ControlNetModel = "ControlNetModelField" IPAdapterModel = "IPAdapterModelField" T2IAdapterModel = "T2IAdapterModelField" + T5EncoderModel = "T5EncoderModelField" SpandrelImageToImageModel = "SpandrelImageToImageModelField" # endregion @@ -125,13 +127,16 @@ class FieldDescriptions: negative_cond = "Negative conditioning tensor" noise = "Noise tensor" clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" + t5_encoder = "T5 tokenizer and text encoder" unet = "UNet (scheduler, LoRAs)" + transformer = "Transformer" vae = "VAE" cond = "Conditioning tensor" controlnet_model = "ControlNet model to load" vae_model = "VAE model to load" lora_model = "LoRA model to load" main_model = "Main model (UNet, VAE, CLIP) to load" + flux_model = "Flux model (Transformer) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" @@ -231,6 +236,12 @@ def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) +class FluxConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + + class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py new file mode 100644 index 00000000000..0e7ebd6d69b --- /dev/null +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -0,0 +1,86 @@ +from typing import Literal + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField +from invokeai.app.invocations.model import CLIPField, T5EncoderField +from invokeai.app.invocations.primitives import FluxConditioningOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.conditioner import HFEncoder +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo + + +@invocation( + "flux_text_encoder", + title="FLUX Text Encoding", + tags=["prompt", "conditioning", "flux"], + category="conditioning", + version="1.0.0", + classification=Classification.Prototype, +) +class FluxTextEncoderInvocation(BaseInvocation): + """Encodes and preps a 
prompt for a flux image.""" + + clip: CLIPField = InputField( + title="CLIP", + description=FieldDescriptions.clip, + input=Input.Connection, + ) + t5_encoder: T5EncoderField = InputField( + title="T5Encoder", + description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) + t5_max_seq_len: Literal[256, 512] = InputField( + description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models." + ) + prompt: str = InputField(description="Text prompt to encode.") + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> FluxConditioningOutput: + t5_embeddings, clip_embeddings = self._encode_prompt(context) + conditioning_data = ConditioningFieldData( + conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)] + ) + + conditioning_name = context.conditioning.save(conditioning_data) + return FluxConditioningOutput.build(conditioning_name) + + def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]: + # Load CLIP. + clip_tokenizer_info = context.models.load(self.clip.tokenizer) + clip_text_encoder_info = context.models.load(self.clip.text_encoder) + + # Load T5. + t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer) + t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder) + + prompt = [self.prompt] + + with ( + t5_text_encoder_info as t5_text_encoder, + t5_tokenizer_info as t5_tokenizer, + ): + assert isinstance(t5_text_encoder, T5EncoderModel) + assert isinstance(t5_tokenizer, T5Tokenizer) + + t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len) + + prompt_embeds = t5_encoder(prompt) + + with ( + clip_text_encoder_info as clip_text_encoder, + clip_tokenizer_info as clip_tokenizer, + ): + assert isinstance(clip_text_encoder, CLIPTextModel) + assert isinstance(clip_tokenizer, CLIPTokenizer) + + clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77) + + pooled_prompt_embeds = clip_encoder(prompt) + + assert isinstance(prompt_embeds, torch.Tensor) + assert isinstance(pooled_prompt_embeds, torch.Tensor) + return prompt_embeds, pooled_prompt_embeds diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py new file mode 100644 index 00000000000..b6ff06c67bf --- /dev/null +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -0,0 +1,172 @@ +import torch +from einops import rearrange +from PIL import Image + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + FluxConditioningField, + Input, + InputField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import TransformerField, VAEField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.session_processor.session_processor_common import CanceledException +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, prepare_latent_img_patches, unpack +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "flux_text_to_image", + title="FLUX Text to Image", + tags=["image", "flux"], + category="image", 
+ version="1.0.0", + classification=Classification.Prototype, +) +class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Text-to-image generation using a FLUX model.""" + + transformer: TransformerField = InputField( + description=FieldDescriptions.flux_model, + input=Input.Connection, + title="Transformer", + ) + vae: VAEField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + positive_text_conditioning: FluxConditioningField = InputField( + description=FieldDescriptions.positive_cond, input=Input.Connection + ) + width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") + height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") + num_steps: int = InputField( + default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50." + ) + guidance: float = InputField( + default=4.0, + description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.", + ) + seed: int = InputField(default=0, description="Randomness seed for reproducibility.") + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + # Load the conditioning data. + cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name) + assert len(cond_data.conditionings) == 1 + flux_conditioning = cond_data.conditionings[0] + assert isinstance(flux_conditioning, FLUXConditioningInfo) + + latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds) + image = self._run_vae_decoding(context, latents) + image_dto = context.images.save(image=image) + return ImageOutput.build(image_dto) + + def _run_diffusion( + self, + context: InvocationContext, + clip_embeddings: torch.Tensor, + t5_embeddings: torch.Tensor, + ): + transformer_info = context.models.load(self.transformer.transformer) + inference_dtype = torch.bfloat16 + + # Prepare input noise. + x = get_noise( + num_samples=1, + height=self.height, + width=self.width, + device=TorchDevice.choose_torch_device(), + dtype=inference_dtype, + seed=self.seed, + ) + + img, img_ids = prepare_latent_img_patches(x) + + is_schnell = "schnell" in transformer_info.config.config_path + + timesteps = get_schedule( + num_steps=self.num_steps, + image_seq_len=img.shape[1], + shift=not is_schnell, + ) + + bs, t5_seq_len, _ = t5_embeddings.shape + txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()) + + # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from + # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems + # if the cache is not empty. 
+ context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) + + with transformer_info as transformer: + assert isinstance(transformer, Flux) + + def step_callback() -> None: + if context.util.is_canceled(): + raise CanceledException + + # TODO: Make this look like the image before re-enabling + # latent_image = unpack(img.float(), self.height, self.width) + # latent_image = latent_image.squeeze() # Remove unnecessary dimensions + # flattened_tensor = latent_image.reshape(-1) # Flatten to shape [48*128*128] + + # # Create a new tensor of the required shape [255, 255, 3] + # latent_image = flattened_tensor[: 255 * 255 * 3].reshape(255, 255, 3) # Reshape to RGB format + + # # Convert to a NumPy array and then to a PIL Image + # image = Image.fromarray(latent_image.cpu().numpy().astype(np.uint8)) + + # (width, height) = image.size + # width *= 8 + # height *= 8 + + # dataURL = image_to_dataURL(image, image_format="JPEG") + + # # TODO: move this whole function to invocation context to properly reference these variables + # context._services.events.emit_invocation_denoise_progress( + # context._data.queue_item, + # context._data.invocation, + # state, + # ProgressImage(dataURL=dataURL, width=width, height=height), + # ) + + x = denoise( + model=transformer, + img=img, + img_ids=img_ids, + txt=t5_embeddings, + txt_ids=txt_ids, + vec=clip_embeddings, + timesteps=timesteps, + step_callback=step_callback, + guidance=self.guidance, + ) + + x = unpack(x.float(), self.height, self.width) + + return x + + def _run_vae_decoding( + self, + context: InvocationContext, + latents: torch.Tensor, + ) -> Image.Image: + vae_info = context.models.load(self.vae.vae) + with vae_info as vae: + assert isinstance(vae, AutoEncoder) + latents = latents.to(dtype=TorchDevice.choose_torch_dtype()) + img = vae.decode(latents) + + img = img.clamp(-1, 1) + img = rearrange(img[0], "c h w -> h w c") + img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) + + return img_pil diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c0d067c0a7a..88874f302a7 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,5 +1,5 @@ import copy -from typing import List, Optional +from typing import List, Literal, Optional from pydantic import BaseModel, Field @@ -13,7 +13,14 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType +from invokeai.backend.flux.util import max_seq_lengths +from invokeai.backend.model_manager.config import ( + AnyModelConfig, + BaseModelType, + CheckpointConfigBase, + ModelType, + SubModelType, +) class ModelIdentifierField(BaseModel): @@ -60,6 +67,15 @@ class CLIPField(BaseModel): loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") +class TransformerField(BaseModel): + transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel") + + +class T5EncoderField(BaseModel): + tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") + text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") + + class VAEField(BaseModel): vae: ModelIdentifierField = Field(description="Info to load vae submodel") seamless_axes: List[str] = 
Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') @@ -122,6 +138,112 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput: return ModelIdentifierOutput(model=self.model) +@invocation_output("flux_model_loader_output") +class FluxModelLoaderOutput(BaseInvocationOutput): + """Flux base model loader output""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP") + t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder") + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + max_seq_len: Literal[256, 512] = OutputField( + description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)", + title="Max Seq Length", + ) + + +@invocation( + "flux_model_loader", + title="Flux Main Model", + tags=["model", "flux"], + category="model", + version="1.0.3", + classification=Classification.Prototype, +) +class FluxModelLoaderInvocation(BaseInvocation): + """Loads a flux base model, outputting its submodels.""" + + model: ModelIdentifierField = InputField( + description=FieldDescriptions.flux_model, + ui_type=UIType.FluxMainModel, + input=Input.Direct, + ) + + t5_encoder: ModelIdentifierField = InputField( + description=FieldDescriptions.t5_encoder, + ui_type=UIType.T5EncoderModel, + input=Input.Direct, + ) + + def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: + model_key = self.model.key + + if not context.models.exists(model_key): + raise ValueError(f"Unknown model: {model_key}") + transformer = self._get_model(context, SubModelType.Transformer) + tokenizer = self._get_model(context, SubModelType.Tokenizer) + tokenizer2 = self._get_model(context, SubModelType.Tokenizer2) + clip_encoder = self._get_model(context, SubModelType.TextEncoder) + t5_encoder = self._get_model(context, SubModelType.TextEncoder2) + vae = self._get_model(context, SubModelType.VAE) + transformer_config = context.models.get_config(transformer) + assert isinstance(transformer_config, CheckpointConfigBase) + + return FluxModelLoaderOutput( + transformer=TransformerField(transformer=transformer), + clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), + t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder), + vae=VAEField(vae=vae), + max_seq_len=max_seq_lengths[transformer_config.config_path], + ) + + def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField: + match submodel: + case SubModelType.Transformer: + return self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) + case SubModelType.VAE: + return self._pull_model_from_mm( + context, + SubModelType.VAE, + "FLUX.1-schnell_ae", + ModelType.VAE, + BaseModelType.Flux, + ) + case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]: + return self._pull_model_from_mm( + context, + submodel, + "clip-vit-large-patch14", + ModelType.CLIPEmbed, + BaseModelType.Any, + ) + case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]: + return self._pull_model_from_mm( + context, + submodel, + self.t5_encoder.name, + ModelType.T5Encoder, + BaseModelType.Any, + ) + case _: + raise Exception(f"{submodel.value} is not a supported submodule for a flux model") + + def _pull_model_from_mm( + self, + 
context: InvocationContext, + submodel: SubModelType, + name: str, + type: ModelType, + base: BaseModelType, + ): + if models := context.models.search_by_attrs(name=name, base=base, type=type): + if len(models) != 1: + raise Exception(f"Multiple models detected for selected model with name {name}") + return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel}) + else: + raise ValueError(f"Please install the {base}:{type} model named {name} via starter models") + + @invocation( "main_model_loader", title="Main Model", diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 3655554f3bf..bb136d62fdd 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -12,6 +12,7 @@ ConditioningField, DenoiseMaskField, FieldDescriptions, + FluxConditioningField, ImageField, Input, InputField, @@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput): height: int = OutputField(description="The height of the mask in pixels.") +@invocation_output("flux_conditioning_output") +class FluxConditioningOutput(BaseInvocationOutput): + """Base class for nodes that output a single conditioning tensor""" + + conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond) + + @classmethod + def build(cls, conditioning_name: str) -> "FluxConditioningOutput": + return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name)) + + @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index e1d784f5bf4..4ff48034385 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -783,8 +783,9 @@ def _multifile_download( # So what we do is to synthesize a folder named "sdxl-turbo_vae" here. if subfolder: top = Path(remote_files[0].path.parts[0]) # e.g. 
"sdxl-turbo/" - path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/ - path_to_add = Path(f"{top}_{subfolder}") + path_to_remove = top / subfolder # sdxl-turbo/vae/ + subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_") + path_to_add = Path(f"{top}_{subfolder_rename}") else: path_to_remove = Path(".") path_to_add = Path(".") diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 46d11d4ddf2..9cc1486a019 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull): type: Optional[ModelType] = Field(description="Type of model", default=None) key: Optional[str] = Field(description="Database ID for this model", default=None) hash: Optional[str] = Field(description="hash of model file", default=None) + format: Optional[str] = Field(description="format of model file", default=None) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field( description="Default settings for this model", default=None diff --git a/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json b/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json new file mode 100644 index 00000000000..783fdeed5e3 --- /dev/null +++ b/invokeai/app/services/workflow_records/default_workflows/Flux Text to Image.json @@ -0,0 +1,266 @@ +{ + "name": "FLUX Text to Image", + "author": "InvokeAI", + "description": "A simple text-to-image workflow using FLUX dev or schnell models. Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.", + "version": "1.0.0", + "contact": "", + "tags": "text2image, flux", + "notes": "Prerequisite model downloads: T5 Encoder, CLIP-L Encoder, and FLUX VAE. Quantized and un-quantized versions can be found in the starter models tab within your Model Manager. 
We recommend 4 steps for FLUX schnell models and 30 steps for FLUX dev models.", + "exposedFields": [ + { + "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "fieldName": "model" + }, + { + "nodeId": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "fieldName": "prompt" + }, + { + "nodeId": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "fieldName": "num_steps" + }, + { + "nodeId": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "fieldName": "t5_encoder" + } + ], + "meta": { + "version": "3.0.0", + "category": "default" + }, + "nodes": [ + { + "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "type": "invocation", + "data": { + "id": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "type": "flux_model_loader", + "version": "1.0.3", + "label": "", + "notes": "", + "isOpen": true, + "isIntermediate": true, + "useCache": false, + "inputs": { + "model": { + "name": "model", + "label": "Model (Starter Models can be found in Model Manager)", + "value": { + "key": "f04a7a2f-c74d-4538-8d5e-879a53501662", + "hash": "random:4875da7a9508444ffa706f61961c260d0c6729f6181a86b31fad06df1277b850", + "name": "FLUX Dev (Quantized)", + "base": "flux", + "type": "main" + } + }, + "t5_encoder": { + "name": "t5_encoder", + "label": "T 5 Encoder (Starter Models can be found in Model Manager)", + "value": { + "key": "20dcd9ec-5fbb-4012-8401-049e707da5e5", + "hash": "random:f986be43ff3502169e4adbdcee158afb0e0a65a1edc4cab16ae59963630cfd8f", + "name": "t5_bnb_int8_quantized_encoder", + "base": "any", + "type": "t5_encoder" + } + } + } + }, + "position": { + "x": 337.09365228062825, + "y": 40.63469521079861 + } + }, + { + "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "type": "invocation", + "data": { + "id": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "type": "flux_text_encoder", + "version": "1.0.0", + "label": "", + "notes": "", + "isOpen": true, + "isIntermediate": true, + "useCache": true, + "inputs": { + "clip": { + "name": "clip", + "label": "" + }, + "t5_encoder": { + "name": "t5_encoder", + "label": "" + }, + "t5_max_seq_len": { + "name": "t5_max_seq_len", + "label": "T5 Max Seq Len", + "value": 256 + }, + "prompt": { + "name": "prompt", + "label": "", + "value": "a cat" + } + } + }, + "position": { + "x": 824.1970602278849, + "y": 146.98251001061735 + } + }, + { + "id": "4754c534-a5f3-4ad0-9382-7887985e668c", + "type": "invocation", + "data": { + "id": "4754c534-a5f3-4ad0-9382-7887985e668c", + "type": "rand_int", + "version": "1.0.1", + "label": "", + "notes": "", + "isOpen": true, + "isIntermediate": true, + "useCache": false, + "inputs": { + "low": { + "name": "low", + "label": "", + "value": 0 + }, + "high": { + "name": "high", + "label": "", + "value": 2147483647 + } + } + }, + "position": { + "x": 822.9899179655476, + "y": 360.9657214885052 + } + }, + { + "id": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "type": "invocation", + "data": { + "id": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "type": "flux_text_to_image", + "version": "1.0.0", + "label": "", + "notes": "", + "isOpen": true, + "isIntermediate": false, + "useCache": true, + "inputs": { + "board": { + "name": "board", + "label": "" + }, + "metadata": { + "name": "metadata", + "label": "" + }, + "transformer": { + "name": "transformer", + "label": "" + }, + "vae": { + "name": "vae", + "label": "" + }, + "positive_text_conditioning": { + "name": "positive_text_conditioning", + "label": "" + }, + "width": { + "name": "width", + "label": "", + "value": 1024 + }, + "height": { + "name": "height", + "label": "", + "value": 1024 + }, + "num_steps": { + "name": "num_steps", + 
"label": "Steps (Recommend 30 for Dev, 4 for Schnell)", + "value": 30 + }, + "guidance": { + "name": "guidance", + "label": "", + "value": 4 + }, + "seed": { + "name": "seed", + "label": "", + "value": 0 + } + } + }, + "position": { + "x": 1216.3900791301849, + "y": 5.500841807102248 + } + } + ], + "edges": [ + { + "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33amax_seq_len-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_max_seq_len", + "type": "default", + "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "sourceHandle": "max_seq_len", + "targetHandle": "t5_max_seq_len" + }, + { + "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33avae-159bdf1b-79e7-4174-b86e-d40e646964c8vae", + "type": "default", + "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "sourceHandle": "vae", + "targetHandle": "vae" + }, + { + "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33atransformer-159bdf1b-79e7-4174-b86e-d40e646964c8transformer", + "type": "default", + "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "sourceHandle": "transformer", + "targetHandle": "transformer" + }, + { + "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33at5_encoder-01f674f8-b3d1-4df1-acac-6cb8e0bfb63ct5_encoder", + "type": "default", + "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "sourceHandle": "t5_encoder", + "targetHandle": "t5_encoder" + }, + { + "id": "reactflow__edge-4f0207c2-ff40-41fd-b047-ad33fbb1c33aclip-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cclip", + "type": "default", + "source": "4f0207c2-ff40-41fd-b047-ad33fbb1c33a", + "target": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "sourceHandle": "clip", + "targetHandle": "clip" + }, + { + "id": "reactflow__edge-01f674f8-b3d1-4df1-acac-6cb8e0bfb63cconditioning-159bdf1b-79e7-4174-b86e-d40e646964c8positive_text_conditioning", + "type": "default", + "source": "01f674f8-b3d1-4df1-acac-6cb8e0bfb63c", + "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "sourceHandle": "conditioning", + "targetHandle": "positive_text_conditioning" + }, + { + "id": "reactflow__edge-4754c534-a5f3-4ad0-9382-7887985e668cvalue-159bdf1b-79e7-4174-b86e-d40e646964c8seed", + "type": "default", + "source": "4754c534-a5f3-4ad0-9382-7887985e668c", + "target": "159bdf1b-79e7-4174-b86e-d40e646964c8", + "sourceHandle": "value", + "targetHandle": "seed" + } + ] +} diff --git a/invokeai/backend/flux/math.py b/invokeai/backend/flux/math.py new file mode 100644 index 00000000000..aa719b7c072 --- /dev/null +++ b/invokeai/backend/flux/math.py @@ -0,0 +1,32 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: 
Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py new file mode 100644 index 00000000000..5358ddf0bc0 --- /dev/null +++ b/invokeai/backend/flux/model.py @@ -0,0 +1,117 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +from dataclasses import dataclass + +import torch +from torch import Tensor, nn + +from invokeai.backend.flux.modules.layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = 
block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py new file mode 100644 index 00000000000..237769aba71 --- /dev/null +++ b/invokeai/backend/flux/modules/autoencoder.py @@ -0,0 +1,310 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = torch.nn.functional.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = torch.nn.functional.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = 
nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = torch.nn.functional.silu(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = 
nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = torch.nn.functional.silu(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py new file mode 100644 index 00000000000..de6d8256c4f --- /dev/null +++ b/invokeai/backend/flux/modules/conditioner.py @@ -0,0 +1,33 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +from torch import Tensor, nn +from transformers import PreTrainedModel, PreTrainedTokenizer + + +class HFEncoder(nn.Module): + def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int): + super().__init__() + self.max_length = max_length + self.is_clip = is_clip + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + self.tokenizer = tokenizer + self.hf_module = encoder + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) 
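        # output_key selects "pooler_output" on the CLIP path (shape (batch, 768) for the
        # clip-vit-large-patch14 text encoder used here) and "last_hidden_state" on the T5
        # path (shape (batch, max_length, d_model)); FluxParams below expects
        # vec_in_dim=768 and context_in_dim=4096, so a T5 encoder with d_model=4096 is assumed.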
+ return outputs[self.output_key] diff --git a/invokeai/backend/flux/modules/layers.py b/invokeai/backend/flux/modules/layers.py new file mode 100644 index 00000000000..23dc2448d3c --- /dev/null +++ b/invokeai/backend/flux/modules/layers.py @@ -0,0 +1,253 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from invokeai.backend.flux.math import attention, rope + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = 
nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/invokeai/backend/flux/sampling.py b/invokeai/backend/flux/sampling.py new file mode 100644 index 00000000000..19de48ae81a --- /dev/null +++ b/invokeai/backend/flux/sampling.py @@ -0,0 +1,176 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor +from tqdm import tqdm + +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.modules.conditioner import HFEncoder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + # We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes. + rand_device = "cpu" + rand_dtype = torch.float16 + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=rand_device, + dtype=rand_dtype, + generator=torch.Generator(device=rand_device).manual_seed(seed), + ).to(device=device, dtype=dtype) + + +def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... 
-> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + step_callback: Callable[[], None], + guidance: float = 4.0, +): + dtype = model.txt_in.bias.dtype + + # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller. + img = img.to(dtype=dtype) + img_ids = img_ids.to(dtype=dtype) + txt = txt.to(dtype=dtype) + txt_ids = txt_ids.to(dtype=dtype) + vec = vec.to(dtype=dtype) + + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + + img = img + (t_prev - t_curr) * pred + step_callback() + + return img + + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + + +def prepare_latent_img_patches(latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert an input image in latent space to patches for diffusion. + + This implementation was extracted from: + https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32 + + Returns: + tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo. + """ + bs, c, h, w = latent_img.shape + + # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches. + img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... 
-> bs ...", bs=bs) + + # Generate patch position ids. + img_ids = torch.zeros(h // 2, w // 2, 3, device=img.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=img.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=img.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + return img, img_ids diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py new file mode 100644 index 00000000000..c81424f8ce4 --- /dev/null +++ b/invokeai/backend/flux/util.py @@ -0,0 +1,71 @@ +# Initially pulled from https://github.com/black-forest-labs/flux + +from dataclasses import dataclass +from typing import Dict, Literal + +from invokeai.backend.flux.model import FluxParams +from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + + +max_seq_lengths: Dict[str, Literal[256, 512]] = { + "flux-dev": 512, + "flux-schnell": 256, +} + + +ae_params = { + "flux": AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) +} + + +params = { + "flux-dev": FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + "flux-schnell": FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), +} diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 332ac6c8faf..66e54d82f3a 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -52,6 +52,7 @@ class BaseModelType(str, Enum): StableDiffusion2 = "sd-2" StableDiffusionXL = "sdxl" StableDiffusionXLRefiner = "sdxl-refiner" + Flux = "flux" # Kandinsky2_1 = "kandinsky-2.1" @@ -66,7 +67,9 @@ class ModelType(str, Enum): TextualInversion = "embedding" IPAdapter = "ip_adapter" CLIPVision = "clip_vision" + CLIPEmbed = "clip_embed" T2IAdapter = "t2i_adapter" + T5Encoder = "t5_encoder" SpandrelImageToImage = "spandrel_image_to_image" @@ -74,6 +77,7 @@ class SubModelType(str, Enum): """Submodel type.""" UNet = "unet" + Transformer = "transformer" TextEncoder = "text_encoder" TextEncoder2 = "text_encoder_2" Tokenizer = "tokenizer" @@ -104,6 +108,9 @@ class ModelFormat(str, Enum): EmbeddingFile = "embedding_file" EmbeddingFolder = "embedding_folder" InvokeAI = "invokeai" + T5Encoder = "t5_encoder" + BnbQuantizedLlmInt8b = "bnb_quantized_int8b" + BnbQuantizednf4b = "bnb_quantized_nf4b" class SchedulerPredictionType(str, Enum): @@ -186,7 +193,9 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> N class CheckpointConfigBase(ModelConfigBase): """Model config for checkpoint-style models.""" - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field( + description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint + ) config_path: str = Field(description="path 
to the checkpoint model config file") converted_at: Optional[float] = Field( description="When this model was last converted to diffusers", default_factory=time.time @@ -205,6 +214,26 @@ class LoRAConfigBase(ModelConfigBase): trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) +class T5EncoderConfigBase(ModelConfigBase): + type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder + + +class T5EncoderConfig(T5EncoderConfigBase): + format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}") + + +class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase): + format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}") + + class LoRALyCORISConfig(LoRAConfigBase): """Model config for LoRA/Lycoris models.""" @@ -229,7 +258,6 @@ class VAECheckpointConfig(CheckpointConfigBase): """Model config for standalone VAE models.""" type: Literal[ModelType.VAE] = ModelType.VAE - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint @staticmethod def get_tag() -> Tag: @@ -268,7 +296,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase) """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = ModelType.ControlNet - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint @staticmethod def get_tag() -> Tag: @@ -317,6 +344,21 @@ def get_tag() -> Tag: return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}") +class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase): + """Model config for main checkpoint models.""" + + prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon + upcast_attention: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.format = ModelFormat.BnbQuantizednf4b + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}") + + class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase): """Model config for main diffusers models.""" @@ -350,6 +392,17 @@ def get_tag() -> Tag: return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}") +class CLIPEmbedDiffusersConfig(DiffusersConfigBase): + """Model config for Clip Embeddings.""" + + type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed + format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}") + + class CLIPVisionDiffusersConfig(DiffusersConfigBase): """Model config for CLIPVision.""" @@ -408,12 +461,15 @@ def get_model_discriminator_value(v: Any) -> str: Union[ Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()], Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()], + Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()], Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()], Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()], Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()], Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()], Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()], 
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], + Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()], + Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()], Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], @@ -421,6 +477,7 @@ def get_model_discriminator_value(v: Any) -> str: Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()], Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], + Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()], ], Discriminator(get_model_discriminator_value), ] diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py new file mode 100644 index 00000000000..0316de60440 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/flux.py @@ -0,0 +1,234 @@ +# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team +"""Class for Flux model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +import accelerate +import torch +from safetensors.torch import load_file +from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + +from invokeai.app.services.config.config_default import get_config +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.flux.util import ae_params, params +from invokeai.backend.model_manager import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.config import ( + CheckpointConfigBase, + CLIPEmbedDiffusersConfig, + MainBnbQuantized4bCheckpointConfig, + MainCheckpointConfig, + T5EncoderBnbQuantizedLlmInt8bConfig, + T5EncoderConfig, + VAECheckpointConfig, +) +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry +from invokeai.backend.util.silence_warnings import SilenceWarnings + +try: + from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 + from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 + + bnb_available = True +except ImportError: + bnb_available = False + +app_config = get_config() + + +@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint) +class FluxVAELoader(ModelLoader): + """Class to load VAE models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, VAECheckpointConfig): + raise ValueError("Only VAECheckpointConfig models are currently supported here.") + model_path = Path(config.path) + + with SilenceWarnings(): + model = AutoEncoder(ae_params[config.config_path]) + sd = load_file(model_path) + model.load_state_dict(sd, assign=True) + model.to(dtype=self._torch_dtype) + + return model + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) +class ClipCheckpointModel(ModelLoader): + """Class to load main models.""" + + def _load_model( + self, + config: AnyModelConfig, + 
submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, CLIPEmbedDiffusersConfig): + raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.") + + match submodel_type: + case SubModelType.Tokenizer: + return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer") + case SubModelType.TextEncoder: + return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder") + + raise ValueError( + f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b) +class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader): + """Class to load main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig): + raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.") + if not bnb_available: + raise ImportError( + "The bnb modules are not available. Please install bitsandbytes if available on your platform." + ) + match submodel_type: + case SubModelType.Tokenizer2: + return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) + case SubModelType.TextEncoder2: + te2_model_path = Path(config.path) / "text_encoder_2" + model_config = AutoConfig.from_pretrained(te2_model_path) + with accelerate.init_empty_weights(): + model = AutoModelForTextEncoding.from_config(model_config) + model = quantize_model_llm_int8(model, modules_to_not_convert=set()) + + state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors" + state_dict = load_file(state_dict_path) + self._load_state_dict_into_t5(model, state_dict) + + return model + + raise ValueError( + f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + @classmethod + def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]): + # There is a shared reference to a single weight tensor in the model. + # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should + # be present in the state_dict. + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) + assert len(unexpected_keys) == 0 + assert set(missing_keys) == {"encoder.embed_tokens.weight"} + # Assert that the layers we expect to be shared are actually shared. + assert model.encoder.embed_tokens.weight is model.shared.weight + + +@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) +class T5EncoderCheckpointModel(ModelLoader): + """Class to load main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, T5EncoderConfig): + raise ValueError("Only T5EncoderConfig models are currently supported here.") + + match submodel_type: + case SubModelType.Tokenizer2: + return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) + case SubModelType.TextEncoder2: + return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") + + raise ValueError( + f"Only Tokenizer and TextEncoder submodels are currently supported. 
Received: {submodel_type.value if submodel_type else 'None'}" + ) + + +@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) +class FluxCheckpointModel(ModelLoader): + """Class to load main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, CheckpointConfigBase): + raise ValueError("Only CheckpointConfigBase models are currently supported here.") + + match submodel_type: + case SubModelType.Transformer: + return self._load_from_singlefile(config) + + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + def _load_from_singlefile( + self, + config: AnyModelConfig, + ) -> AnyModel: + assert isinstance(config, MainCheckpointConfig) + model_path = Path(config.path) + + with SilenceWarnings(): + model = Flux(params[config.config_path]) + sd = load_file(model_path) + model.load_state_dict(sd, assign=True) + return model + + +@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) +class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): + """Class to load main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, CheckpointConfigBase): + raise ValueError("Only CheckpointConfigBase models are currently supported here.") + + match submodel_type: + case SubModelType.Transformer: + return self._load_from_singlefile(config) + + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + def _load_from_singlefile( + self, + config: AnyModelConfig, + ) -> AnyModel: + assert isinstance(config, MainBnbQuantized4bCheckpointConfig) + if not bnb_available: + raise ImportError( + "The bnb modules are not available. Please install bitsandbytes if available on your platform." 
+ ) + model_path = Path(config.path) + + with SilenceWarnings(): + with accelerate.init_empty_weights(): + model = Flux(params[config.config_path]) + model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + sd = load_file(model_path) + model.load_state_dict(sd, assign=True) + return model diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index dfe38aa79c2..f1691ec4d4b 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -78,7 +78,12 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy # TO DO: Add exception handling def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type - if module in ["diffusers", "transformers"]: + if module in [ + "diffusers", + "transformers", + "invokeai.backend.quantization.fast_quantized_transformers_model", + "invokeai.backend.quantization.fast_quantized_diffusion_model", + ]: res_type = sys.modules[module] else: res_type = sys.modules["diffusers"].pipelines diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 33ce4abc4d4..572859dbaee 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -36,8 +36,18 @@ } -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers) +@ModelLoaderRegistry.register( + base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers +) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint) +@ModelLoaderRegistry.register( + base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint +) class StableDiffusionDiffusersModel(GenericDiffusersLoader): """Class to load main models.""" diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index bc612043e34..4b8b5a8dded 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -9,7 +9,7 @@ import torch from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SchedulerMixin -from transformers import CLIPTokenizer +from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline from 
invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline @@ -50,6 +50,17 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: ), ): return model.calc_size() + elif isinstance( + model, + ( + T5TokenizerFast, + T5Tokenizer, + ), + ): + # HACK(ryand): len(model) just returns the vocabulary size, so this is blatantly wrong. It should be small + # relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some + # point. + return len(model) else: # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the # supported model types. diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 1929b3f4fd8..029366e3573 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -95,6 +95,7 @@ class ModelProbe(object): } CLASS2TYPE = { + "FluxPipeline": ModelType.Main, "StableDiffusionPipeline": ModelType.Main, "StableDiffusionInpaintPipeline": ModelType.Main, "StableDiffusionXLPipeline": ModelType.Main, @@ -106,6 +107,7 @@ class ModelProbe(object): "ControlNetModel": ModelType.ControlNet, "CLIPVisionModelWithProjection": ModelType.CLIPVision, "T2IAdapter": ModelType.T2IAdapter, + "CLIPModel": ModelType.CLIPEmbed, } @classmethod @@ -161,7 +163,7 @@ def probe( fields["description"] = ( fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}" ) - fields["format"] = fields.get("format") or probe.get_format() + fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format() fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) fields["default_settings"] = fields.get("default_settings") @@ -176,10 +178,10 @@ def probe( fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() # additional fields needed for main and controlnet models - if ( - fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] - and fields["format"] is ModelFormat.Checkpoint - ): + if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [ + ModelFormat.Checkpoint, + ModelFormat.BnbQuantizednf4b, + ]: ckpt_config_path = cls._get_checkpoint_config_path( model_path, model_type=fields["type"], @@ -222,7 +224,8 @@ def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[C ckpt = ckpt.get("state_dict", ckpt) for key in [str(k) for k in ckpt.keys()]: - if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")): + if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.", "double_blocks.")): + # Keys starting with double_blocks are associated with Flux models return ModelType.Main elif key.startswith(("encoder.conv_in", "decoder.conv_in")): return ModelType.VAE @@ -321,10 +324,27 @@ def _get_checkpoint_config_path( return possible_conf.absolute() if model_type is ModelType.Main: - config_file = LEGACY_CONFIGS[base_type][variant_type] - if isinstance(config_file, dict): # need another tier for sd-2.x models - config_file = config_file[prediction_type] - config_file = f"stable-diffusion/{config_file}" + if base_type == BaseModelType.Flux: + # TODO: Decide between dev/schnell + checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) + state_dict = checkpoint.get("state_dict") or checkpoint + if "guidance_in.out_layer.weight" in 
state_dict: + # For flux, this is a key in invokeai.backend.flux.util.params + # Because model type and format are the discriminators for model configs, this key + # is used rather than adding separate flux model types and formats. + # If that changes in the future, please fix me + config_file = "flux-dev" + else: + # For flux, this is a key in invokeai.backend.flux.util.params + # Because model type and format are the discriminators for model configs, this key + # is used rather than adding separate flux model types and formats. + # If that changes in the future, please fix me + config_file = "flux-schnell" + else: + config_file = LEGACY_CONFIGS[base_type][variant_type] + if isinstance(config_file, dict): # need another tier for sd-2.x models + config_file = config_file[prediction_type] + config_file = f"stable-diffusion/{config_file}" elif model_type is ModelType.ControlNet: config_file = ( "controlnet/cldm_v15.yaml" @@ -333,7 +353,13 @@ def _get_checkpoint_config_path( ) elif model_type is ModelType.VAE: config_file = ( - "stable-diffusion/v1-inference.yaml" + # For flux, this is a key in invokeai.backend.flux.util.ae_params + # Because model type and format are the discriminators for model configs, this key + # is used rather than adding separate flux model types and formats. + # If that changes in the future, please fix me + "flux" + if base_type is BaseModelType.Flux + else "stable-diffusion/v1-inference.yaml" if base_type is BaseModelType.StableDiffusion1 else "stable-diffusion/sd_xl_base.yaml" if base_type is BaseModelType.StableDiffusionXL @@ -416,11 +442,15 @@ def __init__(self, model_path: Path): self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) def get_format(self) -> ModelFormat: + state_dict = self.checkpoint.get("state_dict") or self.checkpoint + if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict: + return ModelFormat.BnbQuantizednf4b return ModelFormat("checkpoint") def get_variant_type(self) -> ModelVariantType: model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint) - if model_type != ModelType.Main: + base_type = self.get_base_type() + if model_type != ModelType.Main or base_type == BaseModelType.Flux: return ModelVariantType.Normal state_dict = self.checkpoint.get("state_dict") or self.checkpoint in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] @@ -440,6 +470,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint state_dict = self.checkpoint.get("state_dict") or checkpoint + if "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: + return BaseModelType.Flux key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 768: return BaseModelType.StableDiffusion1 @@ -482,6 +514,7 @@ def get_base_type(self) -> BaseModelType: (r"xl", BaseModelType.StableDiffusionXL), (r"sd2", BaseModelType.StableDiffusion2), (r"vae", BaseModelType.StableDiffusion1), + (r"FLUX.1-schnell_ae", BaseModelType.Flux), ]: if re.search(regexp, self.model_path.name, re.IGNORECASE): return basetype @@ -713,6 +746,11 @@ def get_base_type(self) -> BaseModelType: return TextualInversionCheckpointProbe(path).get_base_type() + +class T5EncoderFolderProbe(FolderProbeBase): + def get_format(self) -> ModelFormat: + return ModelFormat.T5Encoder + + class ONNXFolderProbe(PipelineFolderProbe): def 
get_base_type(self) -> BaseModelType: # Due to the way the installer is set up, the configuration file for safetensors @@ -805,6 +843,11 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.Any +class CLIPEmbedFolderProbe(FolderProbeBase): + def get_base_type(self) -> BaseModelType: + return BaseModelType.Any + + class SpandrelImageToImageFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: raise NotImplementedError() @@ -835,8 +878,10 @@ def get_base_type(self) -> BaseModelType: ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe) diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index c460a5e86e6..76b91f0d34c 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from invokeai.backend.model_manager.config import BaseModelType, ModelType +from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType class StarterModelWithoutDependencies(BaseModel): @@ -11,6 +11,7 @@ class StarterModelWithoutDependencies(BaseModel): name: str base: BaseModelType type: ModelType + format: Optional[ModelFormat] = None is_installed: bool = False @@ -51,10 +52,76 @@ class StarterModel(StarterModelWithoutDependencies): type=ModelType.TextualInversion, ) +t5_base_encoder = StarterModel( + name="t5_base_encoder", + base=BaseModelType.Any, + source="InvokeAI/t5-v1_1-xxl::bfloat16", + description="T5-XXL text encoder (used in FLUX pipelines). ~8GB", + type=ModelType.T5Encoder, +) + +t5_8b_quantized_encoder = StarterModel( + name="t5_bnb_int8_quantized_encoder", + base=BaseModelType.Any, + source="InvokeAI/t5-v1_1-xxl::bnb_llm_int8", + description="T5-XXL text encoder with bitsandbytes LLM.int8() quantization (used in FLUX pipelines). ~5GB", + type=ModelType.T5Encoder, + format=ModelFormat.BnbQuantizedLlmInt8b, +) + +clip_l_encoder = StarterModel( + name="clip-vit-large-patch14", + base=BaseModelType.Any, + source="InvokeAI/clip-vit-large-patch14-text-encoder::bfloat16", + description="CLIP-L text encoder (used in FLUX pipelines). ~250MB", + type=ModelType.CLIPEmbed, +) + +flux_vae = StarterModel( + name="FLUX.1-schnell_ae", + base=BaseModelType.Flux, + source="black-forest-labs/FLUX.1-schnell::ae.safetensors", + description="FLUX VAE compatible with both schnell and dev variants.", + type=ModelType.VAE, +) + + # List of starter models, displayed on the frontend. # The order/sort of this list is not changed by the frontend - set it how you want it here. 
STARTER_MODELS: list[StarterModel] = [ # region: Main + StarterModel( + name="FLUX Schnell (Quantized)", + base=BaseModelType.Flux, + source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors", + description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB", + type=ModelType.Main, + dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], + ), + StarterModel( + name="FLUX Dev (Quantized)", + base=BaseModelType.Flux, + source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors", + description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB", + type=ModelType.Main, + dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], + ), + StarterModel( + name="FLUX Schnell", + base=BaseModelType.Flux, + source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors", + description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB", + type=ModelType.Main, + dependencies=[t5_base_encoder, flux_vae, clip_l_encoder], + ), + StarterModel( + name="FLUX Dev", + base=BaseModelType.Flux, + source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors", + description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB", + type=ModelType.Main, + dependencies=[t5_base_encoder, flux_vae, clip_l_encoder], + ), StarterModel( name="CyberRealistic v4.1", base=BaseModelType.StableDiffusion1, @@ -125,6 +192,7 @@ class StarterModel(StarterModelWithoutDependencies): # endregion # region VAE sdxl_fp16_vae_fix, + flux_vae, # endregion # region LoRA StarterModel( @@ -450,6 +518,11 @@ class StarterModel(StarterModelWithoutDependencies): type=ModelType.SpandrelImageToImage, ), # endregion + # region TextEncoders + t5_base_encoder, + t5_8b_quantized_encoder, + clip_l_encoder, + # endregion ] assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models" diff --git a/invokeai/backend/model_manager/util/select_hf_files.py b/invokeai/backend/model_manager/util/select_hf_files.py index b0a95514378..b0d33d6efb7 100644 --- a/invokeai/backend/model_manager/util/select_hf_files.py +++ b/invokeai/backend/model_manager/util/select_hf_files.py @@ -54,6 +54,7 @@ def filter_files( "lora_weights.safetensors", "weights.pb", "onnx_data", + "spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`. ) ): paths.append(file) @@ -62,13 +63,13 @@ def filter_files( # downloading random checkpoints that might also be in the repo. However there is no guarantee # that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models # will adhere to this naming convention, so this is an area to be careful of. 
- elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): + elif re.search(r"model.*\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name): paths.append(file) # limit search to subfolder if requested if subfolder: subfolder = root / subfolder - paths = [x for x in paths if x.parent == Path(subfolder)] + paths = [x for x in paths if Path(subfolder) in x.parents] # _filter_by_variant uniquifies the paths and returns a set return sorted(_filter_by_variant(paths, variant)) @@ -97,7 +98,9 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path if variant == ModelRepoVariant.Flax: result.add(path) - elif path.suffix in [".json", ".txt"]: + # Note: '.model' was added to support: + # https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/768d12a373ed5cc9ef9a9dea7504dc09fcc14842/tokenizer_2/spiece.model + elif path.suffix in [".json", ".txt", ".model"]: result.add(path) elif variant in [ @@ -140,6 +143,23 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path continue for candidate_list in subfolder_weights.values(): + # Check if at least one of the files has the explicit fp16 variant. + at_least_one_fp16 = False + for candidate in candidate_list: + if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16": + at_least_one_fp16 = True + break + + if not at_least_one_fp16: + # If none of the candidates in this candidate_list have the explicit fp16 variant label, then this + # candidate_list probably doesn't adhere to the variant naming convention that we expected. In this case, + # we'll simply keep all the candidates. An example of a model that hits this case is + # `black-forest-labs/FLUX.1-schnell` (as of commit 012d2fd). + for candidate in candidate_list: + result.add(candidate.path) + + # The candidate_list seems to have the expected variant naming convention. We'll select the highest scoring + # candidate. highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score) if highest_score_candidate: result.add(highest_score_candidate.path) diff --git a/invokeai/backend/quantization/__init__.py b/invokeai/backend/quantization/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/quantization/bnb_llm_int8.py b/invokeai/backend/quantization/bnb_llm_int8.py new file mode 100644 index 00000000000..b92717cbc57 --- /dev/null +++ b/invokeai/backend/quantization/bnb_llm_int8.py @@ -0,0 +1,125 @@ +import bitsandbytes as bnb +import torch + +# This file contains utils for working with models that use bitsandbytes LLM.int8() quantization. +# The utils in this file are partially inspired by: +# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py + + +# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much +# cleaner by re-implementing bnb.nn.Linear8bitLt with proper use of buffers and less magic. But, for now, we try to +# stick close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes. + + +class InvokeInt8Params(bnb.nn.Int8Params): + """We override cuda() to avoid re-quantizing the weights in the following cases: + - We loaded quantized weights from a state_dict on the cpu, and then moved the model to the gpu. + - We are moving the model back-and-forth between the cpu and gpu. 
+ """ + + def cuda(self, device): + if self.has_fp16_weights: + return super().cuda(device) + elif self.CB is not None and self.SCB is not None: + self.data = self.data.cuda() + self.CB = self.data + self.SCB = self.SCB.cuda() + else: + # we store the 8-bit rows-major weight + # we convert this weight to the turning/ampere weight during the first inference pass + B = self.data.contiguous().half().cuda(device) + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + self.data = CB + self.CB = CB + self.SCB = SCB + + return self + + +class InvokeLinear8bitLt(bnb.nn.Linear8bitLt): + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias", None) + + # See `bnb.nn.Linear8bitLt._save_to_state_dict()` for the serialization logic of SCB and weight_format. + scb = state_dict.pop(prefix + "SCB", None) + # weight_format is unused, but we pop it so we can validate that there are no unexpected keys. + _weight_format = state_dict.pop(prefix + "weight_format", None) + + # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` + # rather than raising an exception to correctly implement this API. + assert len(state_dict) == 0 + + if scb is not None: + # We are loading a pre-quantized state dict. + self.weight = InvokeInt8Params( + data=weight, + requires_grad=self.weight.requires_grad, + has_fp16_weights=False, + # Note: After quantization, CB is the same as weight. + CB=weight, + SCB=scb, + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) + else: + # We are loading a non-quantized state dict. + + # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to + # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta" + # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()` + # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new + # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done. 
+ self.weight = InvokeInt8Params( + data=weight, + requires_grad=self.weight.requires_grad, + has_fp16_weights=False, + CB=None, + SCB=None, + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) + + +def _convert_linear_layers_to_llm_8bit( + module: torch.nn.Module, ignore_modules: set[str], outlier_threshold: float, prefix: str = "" +) -> None: + """Convert all linear layers in the module to bnb.nn.Linear8bitLt layers.""" + for name, child in module.named_children(): + fullname = f"{prefix}.{name}" if prefix else name + if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): + has_bias = child.bias is not None + replacement = InvokeLinear8bitLt( + child.in_features, + child.out_features, + bias=has_bias, + has_fp16_weights=False, + threshold=outlier_threshold, + ) + replacement.weight.data = child.weight.data + if has_bias: + replacement.bias.data = child.bias.data + replacement.requires_grad_(False) + module.__setattr__(name, replacement) + else: + _convert_linear_layers_to_llm_8bit( + child, ignore_modules, outlier_threshold=outlier_threshold, prefix=fullname + ) + + +def quantize_model_llm_int8(model: torch.nn.Module, modules_to_not_convert: set[str], outlier_threshold: float = 6.0): + """Apply bitsandbytes LLM.int8() quantization to the model.""" + _convert_linear_layers_to_llm_8bit( + module=model, ignore_modules=modules_to_not_convert, outlier_threshold=outlier_threshold + ) + + return model diff --git a/invokeai/backend/quantization/bnb_nf4.py b/invokeai/backend/quantization/bnb_nf4.py new file mode 100644 index 00000000000..105bf1474c1 --- /dev/null +++ b/invokeai/backend/quantization/bnb_nf4.py @@ -0,0 +1,156 @@ +import bitsandbytes as bnb +import torch + +# This file contains utils for working with models that use bitsandbytes NF4 quantization. +# The utils in this file are partially inspired by: +# https://github.com/Lightning-AI/pytorch-lightning/blob/1551a16b94f5234a4a78801098f64d0732ef5cb5/src/lightning/fabric/plugins/precision/bitsandbytes.py + +# NOTE(ryand): All of the custom state_dict manipulation logic in this file is pretty hacky. This could be made much +# cleaner by re-implementing bnb.nn.LinearNF4 with proper use of buffers and less magic. But, for now, we try to stick +# close to the bitsandbytes classes to make interoperability easier with other models that might use bitsandbytes. + + +class InvokeLinearNF4(bnb.nn.LinearNF4): + """A class that extends `bnb.nn.LinearNF4` to add the following functionality: + - Ability to load Linear NF4 layers from a pre-quantized state_dict. + - Ability to load Linear NF4 layers from a state_dict when the model is on the "meta" device. + """ + + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """This method is based on the logic in the bitsandbytes serialization unit tests for `Linear4bit`: + https://github.com/bitsandbytes-foundation/bitsandbytes/blob/6d714a5cce3db5bd7f577bc447becc7a92d5ccc7/tests/test_linear4bit.py#L52-L71 + """ + weight = state_dict.pop(prefix + "weight") + bias = state_dict.pop(prefix + "bias", None) + # We expect the remaining keys to be quant_state keys. + quant_state_sd = state_dict + + # During serialization, the quant_state is stored as subkeys of "weight." (See + # `bnb.nn.LinearNF4._save_to_state_dict()`). We validate that they at least have the correct prefix.
+ # TODO(ryand): Technically, we should be using `strict`, `missing_keys`, `unexpected_keys`, and `error_msgs` + # rather than raising an exception to correctly implement this API. + assert all(k.startswith(prefix + "weight.") for k in quant_state_sd.keys()) + + if len(quant_state_sd) > 0: + # We are loading a pre-quantized state dict. + self.weight = bnb.nn.Params4bit.from_prequantized( + data=weight, quantized_stats=quant_state_sd, device=weight.device + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias, requires_grad=False) + else: + # We are loading a non-quantized state dict. + + # We could simply call the `super()._load_from_state_dict()` method here, but then we wouldn't be able to + # load from a state_dict into a model on the "meta" device. Attempting to load into a model on the "meta" + # device requires setting `assign=True`, doing this with the default `super()._load_from_state_dict()` + # implementation causes `Params4Bit` to be replaced by a `torch.nn.Parameter`. By initializing a new + # `Params4bit` object, we work around this issue. It's a bit hacky, but it gets the job done. + self.weight = bnb.nn.Params4bit( + data=weight, + requires_grad=self.weight.requires_grad, + compress_statistics=self.weight.compress_statistics, + quant_type=self.weight.quant_type, + quant_storage=self.weight.quant_storage, + module=self, + ) + self.bias = bias if bias is None else torch.nn.Parameter(bias) + + +def _replace_param( + param: torch.nn.Parameter | bnb.nn.Params4bit, + data: torch.Tensor, +) -> torch.nn.Parameter: + """A helper function to replace the data of a model parameter with new data in a way that allows replacing params on + the "meta" device. + + Supports both `torch.nn.Parameter` and `bnb.nn.Params4bit` parameters. + """ + if param.device.type == "meta": + # Doing `param.data = data` raises a RuntimeError if param.data was on the "meta" device, so we need to + # re-create the param instead of overwriting the data. + if isinstance(param, bnb.nn.Params4bit): + return bnb.nn.Params4bit( + data, + requires_grad=data.requires_grad, + quant_state=param.quant_state, + compress_statistics=param.compress_statistics, + quant_type=param.quant_type, + ) + return torch.nn.Parameter(data, requires_grad=data.requires_grad) + + param.data = data + return param + + +def _convert_linear_layers_to_nf4( + module: torch.nn.Module, + ignore_modules: set[str], + compute_dtype: torch.dtype, + compress_statistics: bool = False, + prefix: str = "", +) -> None: + """Convert all linear layers in the model to NF4 quantized linear layers. + + Args: + module: All linear layers in this module will be converted. + ignore_modules: A set of module prefixes to ignore when converting linear layers. + compute_dtype: The dtype to use for computation in the quantized linear layers. + compress_statistics: Whether to enable nested quantization (aka double quantization) where the quantization + constants from the first quantization are quantized again. + prefix: The prefix of the current module in the model. Used to call this function recursively. 
+ """ + for name, child in module.named_children(): + fullname = f"{prefix}.{name}" if prefix else name + if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): + has_bias = child.bias is not None + replacement = InvokeLinearNF4( + child.in_features, + child.out_features, + bias=has_bias, + compute_dtype=compute_dtype, + compress_statistics=compress_statistics, + ) + if has_bias: + replacement.bias = _replace_param(replacement.bias, child.bias.data) + replacement.weight = _replace_param(replacement.weight, child.weight.data) + replacement.requires_grad_(False) + module.__setattr__(name, replacement) + else: + _convert_linear_layers_to_nf4(child, ignore_modules, compute_dtype=compute_dtype, prefix=fullname) + + +def quantize_model_nf4(model: torch.nn.Module, modules_to_not_convert: set[str], compute_dtype: torch.dtype): + """Apply bitsandbytes nf4 quantization to the model. + + You likely want to call this function inside a `accelerate.init_empty_weights()` context. + + Example usage: + ``` + # Initialize the model from a config on the meta device. + with accelerate.init_empty_weights(): + model = ModelClass.from_config(...) + + # Add NF4 quantization linear layers to the model - still on the meta device. + with accelerate.init_empty_weights(): + model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.float16) + + # Load a state_dict into the model. (Could be either a prequantized or non-quantized state_dict.) + model.load_state_dict(state_dict, strict=True, assign=True) + + # Move the model to the "cuda" device. If the model was non-quantized, this is where the weight quantization takes + # place. + model.to("cuda") + ``` + """ + _convert_linear_layers_to_nf4(module=model, ignore_modules=modules_to_not_convert, compute_dtype=compute_dtype) + + return model diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py new file mode 100644 index 00000000000..804336e0007 --- /dev/null +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py @@ -0,0 +1,79 @@ +from pathlib import Path + +import accelerate +from safetensors.torch import load_file, save_file + +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.util import params +from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 +from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time + + +def main(): + """A script for quantizing a FLUX transformer model using the bitsandbytes LLM.int8() quantization method. + + This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert, + etc.) are hardcoded and would need to be modified for other use cases. + """ + # Load the FLUX transformer model onto the meta device. + model_path = Path( + "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors" + ) + + with log_time("Intialize FLUX transformer on meta device"): + # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. + p = params["flux-schnell"] + + # Initialize the model on the "meta" device. + with accelerate.init_empty_weights(): + model = Flux(p) + + # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). 
See the accelerate + # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. + modules_to_not_convert: set[str] = set() + + model_int8_path = model_path.parent / "bnb_llm_int8.safetensors" + if model_int8_path.exists(): + # The quantized model already exists, load it and return it. + print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...") + + # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device). + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + sd = load_file(model_int8_path) + model.load_state_dict(sd, strict=True, assign=True) + + with log_time("Move model to cuda"): + model = model.to("cuda") + + print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.") + + else: + # The quantized model does not exist, quantize the model and save it. + print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...") + + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + state_dict = load_file(model_path) + # TODO(ryand): Cast the state_dict to the appropriate dtype? + model.load_state_dict(state_dict, strict=True, assign=True) + + with log_time("Move model to cuda and quantize"): + model = model.to("cuda") + + with log_time("Save quantized model"): + model_int8_path.parent.mkdir(parents=True, exist_ok=True) + save_file(model.state_dict(), model_int8_path) + + print(f"Successfully quantized and saved model to '{model_int8_path}'.") + + assert isinstance(model, Flux) + return model + + +if __name__ == "__main__": + main() diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py new file mode 100644 index 00000000000..f1621dbc6dd --- /dev/null +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py @@ -0,0 +1,96 @@ +import time +from contextlib import contextmanager +from pathlib import Path + +import accelerate +import torch +from safetensors.torch import load_file, save_file + +from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.util import params +from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 + + +@contextmanager +def log_time(name: str): + """Helper context manager to log the time taken by a block of code.""" + start = time.time() + try: + yield None + finally: + end = time.time() + print(f"'{name}' took {end - start:.4f} secs") + + +def main(): + """A script for quantizing a FLUX transformer model using the bitsandbytes NF4 quantization method. + + This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert, + etc.) are hardcoded and would need to be modified for other use cases. + """ + model_path = Path( + "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors" + ) + + # inference_dtype = torch.bfloat16 + with log_time("Intialize FLUX transformer on meta device"): + # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. 
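+ # (A possible heuristic, mirroring ModelProbe._get_checkpoint_config_path elsewhere in this PR: the presence of a "guidance_in.out_layer.weight" key in the state dict indicates a dev checkpoint; otherwise assume schnell.)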
+ p = params["flux-schnell"] + + # Initialize the model on the "meta" device. + with accelerate.init_empty_weights(): + model = Flux(p) + + # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate + # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. + modules_to_not_convert: set[str] = set() + + model_nf4_path = model_path.parent / "bnb_nf4.safetensors" + if model_nf4_path.exists(): + # The quantized model already exists, load it and return it. + print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...") + + # Replace the linear layers with NF4 quantized linear layers (still on the meta device). + with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights(): + model = quantize_model_nf4( + model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16 + ) + + with log_time("Load state dict into model"): + state_dict = load_file(model_nf4_path) + model.load_state_dict(state_dict, strict=True, assign=True) + + with log_time("Move model to cuda"): + model = model.to("cuda") + + print(f"Successfully loaded pre-quantized model from '{model_nf4_path}'.") + + else: + # The quantized model does not exist, quantize the model and save it. + print(f"No pre-quantized model found at '{model_nf4_path}'. Quantizing the model...") + + with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights(): + model = quantize_model_nf4( + model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16 + ) + + with log_time("Load state dict into model"): + state_dict = load_file(model_path) + # TODO(ryand): Cast the state_dict to the appropriate dtype? + model.load_state_dict(state_dict, strict=True, assign=True) + + with log_time("Move model to cuda and quantize"): + model = model.to("cuda") + + with log_time("Save quantized model"): + model_nf4_path.parent.mkdir(parents=True, exist_ok=True) + save_file(model.state_dict(), model_nf4_path) + + print(f"Successfully quantized and saved model to '{model_nf4_path}'.") + + assert isinstance(model, Flux) + return model + + +if __name__ == "__main__": + main() diff --git a/invokeai/backend/quantization/scripts/quantize_t5_xxl_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/quantize_t5_xxl_bnb_llm_int8.py new file mode 100644 index 00000000000..fc681e8fc57 --- /dev/null +++ b/invokeai/backend/quantization/scripts/quantize_t5_xxl_bnb_llm_int8.py @@ -0,0 +1,92 @@ +from pathlib import Path + +import accelerate +from safetensors.torch import load_file, save_file +from transformers import AutoConfig, AutoModelForTextEncoding, T5EncoderModel + +from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 +from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time + + +def load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict): + # There is a shared reference to a single weight tensor in the model. + # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should + # be present in the state_dict. + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) + assert len(unexpected_keys) == 0 + assert set(missing_keys) == {"encoder.embed_tokens.weight"} + # Assert that the layers we expect to be shared are actually shared. 
+ assert model.encoder.embed_tokens.weight is model.shared.weight + + +def main(): + """A script for quantizing a T5 text encoder model using the bitsandbytes LLM.int8() quantization method. + + This script is primarily intended for reference. The script params (e.g. the model_path, modules_to_not_convert, + etc.) are hardcoded and would need to be modified for other use cases. + """ + model_path = Path("/data/misc/text_encoder_2") + + with log_time("Intialize T5 on meta device"): + model_config = AutoConfig.from_pretrained(model_path) + with accelerate.init_empty_weights(): + model = AutoModelForTextEncoding.from_config(model_config) + + # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate + # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize. + modules_to_not_convert: set[str] = set() + + model_int8_path = model_path / "bnb_llm_int8.safetensors" + if model_int8_path.exists(): + # The quantized model already exists, load it and return it. + print(f"A pre-quantized model already exists at '{model_int8_path}'. Attempting to load it...") + + # Replace the linear layers with LLM.int8() quantized linear layers (still on the meta device). + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + sd = load_file(model_int8_path) + load_state_dict_into_t5(model, sd) + + with log_time("Move model to cuda"): + model = model.to("cuda") + + print(f"Successfully loaded pre-quantized model from '{model_int8_path}'.") + + else: + # The quantized model does not exist, quantize the model and save it. + print(f"No pre-quantized model found at '{model_int8_path}'. Quantizing the model...") + + with log_time("Replace linear layers with LLM.int8() layers"), accelerate.init_empty_weights(): + model = quantize_model_llm_int8(model, modules_to_not_convert=modules_to_not_convert) + + with log_time("Load state dict into model"): + # Load sharded state dict. + files = list(model_path.glob("*.safetensors")) + state_dict = {} + for file in files: + sd = load_file(file) + state_dict.update(sd) + load_state_dict_into_t5(model, state_dict) + + with log_time("Move model to cuda and quantize"): + model = model.to("cuda") + + with log_time("Save quantized model"): + model_int8_path.parent.mkdir(parents=True, exist_ok=True) + state_dict = model.state_dict() + state_dict.pop("encoder.embed_tokens.weight") + save_file(state_dict, model_int8_path) + # This handling of shared weights could also be achieved with save_model(...), but then we'd lose control + # over which keys are kept. And, the corresponding load_model(...) function does not support assign=True. 
+ # save_model(model, model_int8_path) + + print(f"Successfully quantized and saved model to '{model_int8_path}'.") + + assert isinstance(model, T5EncoderModel) + return model + + +if __name__ == "__main__": + main() diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 5fe1483ebc9..c5fda909c72 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -25,11 +25,6 @@ def to(self, device, dtype=None): return self -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] - - @dataclass class SDXLConditioningInfo(BasicConditioningInfo): """SDXL text conditioning information produced by Compel.""" @@ -43,6 +38,17 @@ def to(self, device, dtype=None): return super().to(device=device, dtype=dtype) +@dataclass +class FLUXConditioningInfo: + clip_embeds: torch.Tensor + t5_embeds: torch.Tensor + + +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo] + + @dataclass class IPAdapterConditioningInfo: cond_image_prompt_embeds: torch.Tensor diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1737bd4f297..a9ece94b969 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -784,6 +784,7 @@ "simpleModelPlaceholder": "URL or path to a local file or diffusers folder", "source": "Source", "starterModels": "Starter Models", + "starterModelsInModelManager": "Starter Models can be found in Model Manager", "syncModels": "Sync Models", "textualInversions": "Textual Inversions", "triggerPhrases": "Trigger Phrases", diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx index a3c9c82d0eb..81913f3e8ee 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StartModelsResultItem.tsx @@ -5,17 +5,33 @@ import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiPlusBold } from 'react-icons/pi'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; type Props = { result: GetStarterModelsResponse[number]; + modelList: AnyModelConfig[]; }; -export const StarterModelsResultItem = memo(({ result }: Props) => { +export const StarterModelsResultItem = memo(({ result, modelList }: Props) => { const { t } = useTranslation(); const allSources = useMemo(() => { - const _allSources = [{ source: result.source, config: { name: result.name, description: result.description } }]; + const _allSources = [ + { + source: result.source, + config: { + name: result.name, + description: result.description, + type: result.type, + base: result.base, + format: result.format, + }, + }, + ]; if (result.dependencies) { for (const d of result.dependencies) { - _allSources.push({ source: d.source, config: { name: d.name, description: d.description } }); + _allSources.push({ + source: d.source, + config: { name: d.name, description: d.description, type: 
d.type, base: d.base, format: d.format }, + }); } } return _allSources; @@ -24,9 +40,12 @@ export const StarterModelsResultItem = memo(({ result }: Props) => { const onClick = useCallback(() => { for (const { config, source } of allSources) { + if (modelList.some((mc) => config.base === mc.base && config.name === mc.name && config.type === mc.type)) { + continue; + } installModel({ config, source }); } - }, [allSources, installModel]); + }, [modelList, allSources, installModel]); return ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm.tsx index 837ef5c63b8..eaf2cb534ef 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsForm.tsx @@ -1,17 +1,31 @@ import { Flex } from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; import { FetchingModelsLoader } from 'features/modelManagerV2/subpanels/ModelManagerPanel/FetchingModelsLoader'; -import { memo } from 'react'; -import { useGetStarterModelsQuery } from 'services/api/endpoints/models'; +import { memo, useMemo } from 'react'; +import { + modelConfigsAdapterSelectors, + useGetModelConfigsQuery, + useGetStarterModelsQuery, +} from 'services/api/endpoints/models'; import { StarterModelsResults } from './StarterModelsResults'; export const StarterModelsForm = memo(() => { const { isLoading, data } = useGetStarterModelsQuery(); + const { data: modelListRes } = useGetModelConfigsQuery(); + + const modelList = useMemo(() => { + if (!modelListRes) { + return EMPTY_ARRAY; + } + + return modelConfigsAdapterSelectors.selectAll(modelListRes); + }, [modelListRes]); return ( {isLoading && } - {data && } + {data && } ); }); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx index e593ee5fc3c..c443171060e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/StarterModels/StarterModelsResults.tsx @@ -5,14 +5,16 @@ import { memo, useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { PiXBold } from 'react-icons/pi'; import type { GetStarterModelsResponse } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; import { StarterModelsResultItem } from './StartModelsResultItem'; type StarterModelsResultsProps = { results: NonNullable; + modelList: AnyModelConfig[]; }; -export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps) => { +export const StarterModelsResults = memo(({ results, modelList }: StarterModelsResultsProps) => { const { t } = useTranslation(); const [searchTerm, setSearchTerm] = useState(''); @@ -72,7 +74,7 @@ export const StarterModelsResults = memo(({ results }: StarterModelsResultsProps {filteredResults.map((result) => ( - + ))} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx 
b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx index bf07bad58cd..3eb0a91d672 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx @@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record = { 'sd-2': 'teal', sdxl: 'invokeBlue', 'sdxl-refiner': 'invokeBlue', + flux: 'gold', }; const ModelBaseBadge = ({ base }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx index a4690662c35..68cd4556463 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx @@ -13,6 +13,9 @@ const FORMAT_NAME_MAP: Record = { invokeai: 'internal', embedding_file: 'embedding', embedding_folder: 'embedding', + t5_encoder: 't5_encoder', + bnb_quantized_int8b: 'bnb_quantized_int8b', + bnb_quantized_nf4b: 'quantized', }; const FORMAT_COLOR_MAP: Record = { @@ -22,6 +25,9 @@ const FORMAT_COLOR_MAP: Record = { invokeai: 'base', embedding_file: 'base', embedding_folder: 'base', + t5_encoder: 'base', + bnb_quantized_int8b: 'base', + bnb_quantized_nf4b: 'base', }; const ModelFormatBadge = ({ format }: Props) => { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 755a6e21fb2..b1c071bed3e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -5,6 +5,7 @@ import type { FilterableModelType } from 'features/modelManagerV2/store/modelMan import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { + useClipEmbedModels, useControlNetModels, useEmbeddingModels, useIPAdapterModels, @@ -13,6 +14,7 @@ import { useRefinerModels, useSpandrelImageToImageModels, useT2IAdapterModels, + useT5EncoderModels, useVAEModels, } from 'services/api/hooks/modelsByType'; import type { AnyModelConfig } from 'services/api/types'; @@ -73,6 +75,18 @@ const ModelList = () => { [vaeModels, searchTerm, filteredModelType] ); + const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels(); + const filteredT5EncoderModels = useMemo( + () => modelsFilter(t5EncoderModels, searchTerm, filteredModelType), + [t5EncoderModels, searchTerm, filteredModelType] + ); + + const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useClipEmbedModels(); + const filteredClipEmbedModels = useMemo( + () => modelsFilter(clipEmbedModels, searchTerm, filteredModelType), + [clipEmbedModels, searchTerm, filteredModelType] + ); + const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] = useSpandrelImageToImageModels(); const filteredSpandrelImageToImageModels = useMemo( @@ -90,7 +104,9 @@ const ModelList = () => { filteredT2IAdapterModels.length + filteredIPAdapterModels.length + filteredVAEModels.length + - filteredSpandrelImageToImageModels.length + filteredSpandrelImageToImageModels.length + + t5EncoderModels.length + + clipEmbedModels.length ); }, [ 
filteredControlNetModels.length, @@ -102,6 +118,8 @@ const ModelList = () => { filteredT2IAdapterModels.length, filteredVAEModels.length, filteredSpandrelImageToImageModels.length, + t5EncoderModels.length, + clipEmbedModels.length, ]); return ( @@ -154,6 +172,16 @@ const ModelList = () => { {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( )} + {/* T5 Encoders List */} + {isLoadingT5EncoderModels && } + {!isLoadingT5EncoderModels && filteredT5EncoderModels.length > 0 && ( + + )} + {/* Clip Embed List */} + {isLoadingClipEmbedModels && } + {!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && ( + + )} {/* Spandrel Image to Image List */} {isLoadingSpandrelImageToImageModels && ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx index 9db3334e89e..91dba7d71ff 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -19,6 +19,8 @@ export const ModelTypeFilter = memo(() => { controlnet: 'ControlNet', vae: 'VAE', t2i_adapter: t('common.t2iAdapter'), + t5_encoder: 'T5Encoder', + clip_embed: 'Clip Embed', ip_adapter: t('common.ipAdapter'), clip_vision: 'Clip Vision', spandrel_image_to_image: 'Image-to-Image', diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index d863def9737..c4e8da6eda7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -14,6 +14,8 @@ import { isEnumFieldInputTemplate, isFloatFieldInputInstance, isFloatFieldInputTemplate, + isFluxMainModelFieldInputInstance, + isFluxMainModelFieldInputTemplate, isImageFieldInputInstance, isImageFieldInputTemplate, isIntegerFieldInputInstance, @@ -38,6 +40,8 @@ import { isStringFieldInputTemplate, isT2IAdapterModelFieldInputInstance, isT2IAdapterModelFieldInputTemplate, + isT5EncoderModelFieldInputInstance, + isT5EncoderModelFieldInputTemplate, isVAEModelFieldInputInstance, isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; @@ -48,6 +52,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; import ColorFieldInputComponent from './inputs/ColorFieldInputComponent'; import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent'; import EnumFieldInputComponent from './inputs/EnumFieldInputComponent'; +import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent'; import ImageFieldInputComponent from './inputs/ImageFieldInputComponent'; import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent'; import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent'; @@ -59,6 +64,7 @@ import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputCo import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import T2IAdapterModelFieldInputComponent from 
'./inputs/T2IAdapterModelFieldInputComponent'; +import T5EncoderModelFieldInputComponent from './inputs/T5EncoderModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; type InputFieldProps = { @@ -113,6 +119,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) { return ; } @@ -145,6 +155,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { return ; } + if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) { + return ; + } if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) { return ; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxMainModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxMainModelFieldInputComponent.tsx new file mode 100644 index 00000000000..3a0ddb211ec --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/FluxMainModelFieldInputComponent.tsx @@ -0,0 +1,55 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useFluxModels } from 'services/api/hooks/modelsByType'; +import type { MainModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const FluxMainModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useFluxModels(); + const _onChange = useCallback( + (value: MainModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldMainModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + ); +}; + +export default memo(FluxMainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx new file mode 100644 index 00000000000..72b60bcee96 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T5EncoderModelFieldInputComponent.tsx @@ -0,0 +1,60 @@ +import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 
'common/hooks/useGroupedModelCombobox'; +import { fieldT5EncoderValueChanged } from 'features/nodes/store/nodesSlice'; +import type { T5EncoderModelFieldInputInstance, T5EncoderModelFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useT5EncoderModels } from 'services/api/hooks/modelsByType'; +import type { T5EncoderBnbQuantizedLlmInt8bModelConfig, T5EncoderModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const T5EncoderModelFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const { t } = useTranslation(); + const disabledTabs = useAppSelector((s) => s.config.disabledTabs); + const dispatch = useAppDispatch(); + const [modelConfigs, { isLoading }] = useT5EncoderModels(); + const _onChange = useCallback( + (value: T5EncoderBnbQuantizedLlmInt8bModelConfig | T5EncoderModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldT5EncoderValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + }); + + return ( + + + + + + + + ); +}; + +export default memo(T5EncoderModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index f9214c15727..6bcd5f276eb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -23,6 +23,7 @@ import type { StatefulFieldValue, StringFieldValue, T2IAdapterModelFieldValue, + T5EncoderModelFieldValue, VAEModelFieldValue, } from 'features/nodes/types/field'; import { @@ -44,6 +45,7 @@ import { zStatefulFieldValue, zStringFieldValue, zT2IAdapterModelFieldValue, + zT5EncoderModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; @@ -341,6 +343,9 @@ export const nodesSlice = createSlice({ ) => { fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue); }, + fieldT5EncoderValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zT5EncoderModelFieldValue); + }, fieldEnumModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zEnumFieldValue); }, @@ -402,6 +407,7 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, + fieldT5EncoderValueChanged, nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, @@ -514,6 +520,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, + fieldT5EncoderValueChanged, nodesChanged, nodeIsIntermediateChanged, nodeIsOpenChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index c84b2dae623..3fafcbce46f 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -61,7 +61,7 @@ export type SchedulerField = z.infer; // #endregion // #region Model-related schemas -const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 
'sdxl-refiner']); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']); const zModelType = z.enum([ 'main', 'vae', @@ -73,9 +73,12 @@ const zModelType = z.enum([ 'onnx', 'clip_vision', 'spandrel_image_to_image', + 't5_encoder', + 'clip_embed', ]); const zSubModelType = z.enum([ 'unet', + 'transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 05697c384c0..100c094c464 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -31,6 +31,7 @@ export const MODEL_TYPES = [ 'ControlNetModelField', 'LoRAModelField', 'MainModelField', + 'FluxMainModelField', 'SDXLMainModelField', 'SDXLRefinerModelField', 'VaeModelField', @@ -38,6 +39,7 @@ export const MODEL_TYPES = [ 'VAEField', 'CLIPField', 'T2IAdapterModelField', + 'T5EncoderField', 'SpandrelImageToImageModelField', ]; @@ -50,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = { CLIPField: 'green.500', ColorField: 'pink.300', ConditioningField: 'cyan.500', + FluxConditioningField: 'cyan.500', ControlField: 'teal.500', ControlNetModelField: 'teal.500', EnumField: 'blue.500', @@ -61,6 +64,7 @@ export const FIELD_COLORS: { [key: string]: string } = { LatentsField: 'pink.500', LoRAModelField: 'teal.500', MainModelField: 'teal.500', + FluxMainModelField: 'teal.500', SDXLMainModelField: 'teal.500', SDXLRefinerModelField: 'teal.500', SpandrelImageToImageModelField: 'teal.500', @@ -68,6 +72,8 @@ export const FIELD_COLORS: { [key: string]: string } = { T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', UNetField: 'red.500', + T5EncoderField: 'green.500', + TransformerField: 'red.500', VAEField: 'blue.500', VAEModelField: 'teal.500', }; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 925bd40b9db..ee0f61a0fea 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -115,6 +115,10 @@ const zSDXLMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLMainModelField'), originalType: zStatelessFieldType.optional(), }); +const zFluxMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('FluxMainModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLRefinerModelField'), originalType: zStatelessFieldType.optional(), @@ -143,6 +147,10 @@ const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({ name: z.literal('SpandrelImageToImageModelField'), originalType: zStatelessFieldType.optional(), }); +const zT5EncoderModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T5EncoderModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSchedulerFieldType = zFieldTypeBase.extend({ name: z.literal('SchedulerField'), originalType: zStatelessFieldType.optional(), @@ -158,6 +166,7 @@ const zStatefulFieldType = z.union([ zModelIdentifierFieldType, zMainModelFieldType, zSDXLMainModelFieldType, + zFluxMainModelFieldType, zSDXLRefinerModelFieldType, zVAEModelFieldType, zLoRAModelFieldType, @@ -165,6 +174,7 @@ const zStatefulFieldType = z.union([ zIPAdapterModelFieldType, zT2IAdapterModelFieldType, zSpandrelImageToImageModelFieldType, + zT5EncoderModelFieldType, zColorFieldType, zSchedulerFieldType, 
]); @@ -447,6 +457,29 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain zSDXLMainModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region FluxMainModelField + +const zFluxMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to FLUX models only. +const zFluxMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zFluxMainModelFieldValue, +}); +const zFluxMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFluxMainModelFieldType, + originalType: zFieldType.optional(), + default: zFluxMainModelFieldValue, +}); +const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFluxMainModelFieldType, +}); +export type FluxMainModelFieldInputInstance = z.infer; +export type FluxMainModelFieldInputTemplate = z.infer; +export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance => + zFluxMainModelFieldInputInstance.safeParse(val).success; +export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate => + zFluxMainModelFieldInputTemplate.safeParse(val).success; + +// #endregion + // #region SDXLRefinerModelField /** @alias */ // tells knip to ignore this duplicate export @@ -613,6 +646,29 @@ export const isSpandrelImageToImageModelFieldInputTemplate = ( zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region T5EncoderModelField + +export const zT5EncoderModelFieldValue = zModelIdentifierField.optional(); +const zT5EncoderModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zT5EncoderModelFieldValue, +}); +const zT5EncoderModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT5EncoderModelFieldType, + originalType: zFieldType.optional(), + default: zT5EncoderModelFieldValue, +}); + +export type T5EncoderModelFieldValue = z.infer; + +export type T5EncoderModelFieldInputInstance = z.infer; +export type T5EncoderModelFieldInputTemplate = z.infer; +export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance => + zT5EncoderModelFieldInputInstance.safeParse(val).success; +export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate => + zT5EncoderModelFieldInputTemplate.safeParse(val).success; + +// #endregion + // #region SchedulerField export const zSchedulerFieldValue = zSchedulerField.optional(); @@ -693,6 +749,7 @@ export const zStatefulFieldValue = z.union([ zModelIdentifierFieldValue, zMainModelFieldValue, zSDXLMainModelFieldValue, + zFluxMainModelFieldValue, zSDXLRefinerModelFieldValue, zVAEModelFieldValue, zLoRAModelFieldValue, @@ -700,6 +757,7 @@ export const zStatefulFieldValue = z.union([ zIPAdapterModelFieldValue, zT2IAdapterModelFieldValue, zSpandrelImageToImageModelFieldValue, + zT5EncoderModelFieldValue, zColorFieldValue, zSchedulerFieldValue, ]); @@ -720,6 +778,7 @@ const zStatefulFieldInputInstance = z.union([ zBoardFieldInputInstance, zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, + zFluxMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, zVAEModelFieldInputInstance, @@ -728,6 +787,7 @@ const zStatefulFieldInputInstance = z.union([ zIPAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance, zSpandrelImageToImageModelFieldInputInstance, + zT5EncoderModelFieldInputInstance, zColorFieldInputInstance, zSchedulerFieldInputInstance, ]); @@ -749,6 +809,7 @@ const
zStatefulFieldInputTemplate = z.union([ zBoardFieldInputTemplate, zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, + zFluxMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, zVAEModelFieldInputTemplate, @@ -757,6 +818,7 @@ const zStatefulFieldInputTemplate = z.union([ zIPAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate, zSpandrelImageToImageModelFieldInputTemplate, + zT5EncoderModelFieldInputTemplate, zColorFieldInputTemplate, zSchedulerFieldInputTemplate, zStatelessFieldInputTemplate, @@ -779,6 +841,7 @@ const zStatefulFieldOutputTemplate = z.union([ zBoardFieldOutputTemplate, zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, + zFluxMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, zVAEModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index a5a2d89f03c..8afda4e2a78 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -15,12 +15,14 @@ const FIELD_VALUE_FALLBACK_MAP: Record = MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, + FluxMainModelField: undefined, SDXLRefinerModelField: undefined, StringField: '', T2IAdapterModelField: undefined, SpandrelImageToImageModelField: undefined, VAEModelField: undefined, ControlNetModelField: undefined, + T5EncoderModelField: undefined, }; export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => { diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 8478415cd14..5149bd4d3a1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -8,6 +8,7 @@ import type { FieldInputTemplate, FieldType, FloatFieldInputTemplate, + FluxMainModelFieldInputTemplate, ImageFieldInputTemplate, IntegerFieldInputTemplate, IPAdapterModelFieldInputTemplate, @@ -22,6 +23,7 @@ import type { StatelessFieldInputTemplate, StringFieldInputTemplate, T2IAdapterModelFieldInputTemplate, + T5EncoderModelFieldInputTemplate, VAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { isStatefulFieldType } from 'features/nodes/types/field'; @@ -180,6 +182,20 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: FluxMainModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -208,6 +224,20 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: T5EncoderModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? 
undefined, + }; + + return template; +}; + const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -386,11 +416,13 @@ export const TEMPLATE_BUILDER_MAP: Record generation.model); @@ -17,7 +17,7 @@ const ParamMainModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const selectedModel = useAppSelector(selectModel); - const [modelConfigs, { isLoading }] = useMainModels(); + const [modelConfigs, { isLoading }] = useSDMainModels(); const tooltipLabel = useMemo(() => { if (!modelConfigs.length || !selectedModel) { return; diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index 678b2b37f37..130c14561dc 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -9,6 +9,7 @@ export const MODEL_TYPE_MAP = { 'sd-2': 'Stable Diffusion 2.x', sdxl: 'Stable Diffusion XL', 'sdxl-refiner': 'Stable Diffusion XL Refiner', + flux: 'Flux', }; /** @@ -20,6 +21,7 @@ export const MODEL_TYPE_SHORT_MAP = { 'sd-2': 'SD2.X', sdxl: 'SDXL', 'sdxl-refiner': 'SDXLR', + flux: 'FLUX', }; /** @@ -46,6 +48,10 @@ export const CLIP_SKIP_MAP = { maxClip: 24, markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], }, + flux: { + maxClip: 0, + markers: [], + }, }; /** diff --git a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts index 0f1a81c26f6..6daddf7c4eb 100644 --- a/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts +++ b/invokeai/frontend/web/src/services/api/hooks/modelsByType.ts @@ -3,16 +3,20 @@ import { useMemo } from 'react'; import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; import type { AnyModelConfig } from 'services/api/types'; import { + isClipEmbedModelConfig, isControlNetModelConfig, isControlNetOrT2IAdapterModelConfig, + isFluxMainModelModelConfig, isIPAdapterModelConfig, isLoRAModelConfig, isNonRefinerMainModelConfig, + isNonRefinerNonFluxMainModelConfig, isNonSDXLMainModelConfig, isRefinerMainModelModelConfig, isSDXLMainModelModelConfig, isSpandrelImageToImageModelConfig, isT2IAdapterModelConfig, + isT5EncoderModelConfig, isTIModelConfig, isVAEModelConfig, } from 'services/api/types'; @@ -32,14 +36,18 @@ const buildModelsHook = return [modelConfigs, result] as const; }; +export const useSDMainModels = buildModelsHook(isNonRefinerNonFluxMainModelConfig); export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig); export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig); export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig); +export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig); export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig); export const useLoRAModels = buildModelsHook(isLoRAModelConfig); export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig); export const useControlNetModels = buildModelsHook(isControlNetModelConfig); export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig); +export const useT5EncoderModels = buildModelsHook(isT5EncoderModelConfig); +export const useClipEmbedModels = buildModelsHook(isClipEmbedModelConfig); export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig); export const useIPAdapterModels = 
buildModelsHook(isIPAdapterModelConfig); export const useEmbeddingModels = buildModelsHook(isTIModelConfig); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index f0b6c0198cd..88c767d1a70 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1607,7 +1607,7 @@ export type components = { * @description Base model type. * @enum {string} */ - BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner"; + BaseModelType: "any" | "sd-1" | "sd-2" | "sdxl" | "sdxl-refiner" | "flux"; /** Batch */ Batch: { /** @@ -2439,6 +2439,72 @@ export type components = { */ bulk_download_item_name: string; }; + /** + * CLIPEmbedDiffusersConfig + * @description Model config for Clip Embeddings. + */ + CLIPEmbedDiffusersConfig: { + /** + * Key + * @description A unique key for this model. + */ + key: string; + /** + * Hash + * @description The hash of the model file(s). + */ + hash: string; + /** + * Path + * @description Path to the model on the filesystem. Relative paths are relative to the Invoke root directory. + */ + path: string; + /** + * Name + * @description Name of the model. + */ + name: string; + /** @description The base model. */ + base: components["schemas"]["BaseModelType"]; + /** + * Description + * @description Model description + */ + description?: string | null; + /** + * Source + * @description The original source of the model (path, URL or repo_id). + */ + source: string; + /** @description The type of source */ + source_type: components["schemas"]["ModelSourceType"]; + /** + * Source Api Response + * @description The original API response from the source, as stringified JSON. + */ + source_api_response?: string | null; + /** + * Cover Image + * @description Url for image to preview model + */ + cover_image?: string | null; + /** + * Format + * @default diffusers + * @constant + * @enum {string} + */ + format: "diffusers"; + /** @default */ + repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + /** + * Type + * @default clip_embed + * @constant + * @enum {string} + */ + type: "clip_embed"; + }; /** CLIPField */ CLIPField: { /** @description Info to load tokenizer submodel */ @@ -3669,11 +3735,11 @@ export type components = { cover_image?: string | null; /** * Format + * @description Format of the provided checkpoint model * @default checkpoint - * @constant * @enum {string} */ - format: "checkpoint"; + format: "checkpoint" | "bnb_quantized_nf4b"; /** * Config Path * @description path to the checkpoint model config file @@ -5654,6 +5720,246 @@ export type components = { */ type: "float_to_int"; }; + /** + * FluxConditioningField + * @description A conditioning tensor primitive value + */ + FluxConditioningField: { + /** + * Conditioning Name + * @description The name of conditioning tensor + */ + conditioning_name: string; + }; + /** + * FluxConditioningOutput + * @description Base class for nodes that output a single conditioning tensor + */ + FluxConditioningOutput: { + /** @description Conditioning tensor */ + conditioning: components["schemas"]["FluxConditioningField"]; + /** + * type + * @default flux_conditioning_output + * @constant + * @enum {string} + */ + type: "flux_conditioning_output"; + }; + /** + * Flux Main Model + * @description Loads a flux base model, outputting its submodels. + */ + FluxModelLoaderInvocation: { + /** + * Id + * @description The id of this instance of an invocation. 
Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** @description Flux model (Transformer) to load */ + model: components["schemas"]["ModelIdentifierField"]; + /** @description T5 tokenizer and text encoder */ + t5_encoder: components["schemas"]["ModelIdentifierField"]; + /** + * type + * @default flux_model_loader + * @constant + * @enum {string} + */ + type: "flux_model_loader"; + }; + /** + * FluxModelLoaderOutput + * @description Flux base model loader output + */ + FluxModelLoaderOutput: { + /** + * Transformer + * @description Transformer + */ + transformer: components["schemas"]["TransformerField"]; + /** + * CLIP + * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + */ + clip: components["schemas"]["CLIPField"]; + /** + * T5 Encoder + * @description T5 tokenizer and text encoder + */ + t5_encoder: components["schemas"]["T5EncoderField"]; + /** + * VAE + * @description VAE + */ + vae: components["schemas"]["VAEField"]; + /** + * Max Seq Length + * @description The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer) + * @enum {integer} + */ + max_seq_len: 256 | 512; + /** + * type + * @default flux_model_loader_output + * @constant + * @enum {string} + */ + type: "flux_model_loader_output"; + }; + /** + * FLUX Text Encoding + * @description Encodes and preps a prompt for a flux image. + */ + FluxTextEncoderInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * CLIP + * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null + */ + clip?: components["schemas"]["CLIPField"]; + /** + * T5Encoder + * @description T5 tokenizer and text encoder + * @default null + */ + t5_encoder?: components["schemas"]["T5EncoderField"]; + /** + * T5 Max Seq Len + * @description Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models. + * @default null + * @enum {integer} + */ + t5_max_seq_len?: 256 | 512; + /** + * Prompt + * @description Text prompt to encode. + * @default null + */ + prompt?: string; + /** + * type + * @default flux_text_encoder + * @constant + * @enum {string} + */ + type: "flux_text_encoder"; + }; + /** + * FLUX Text to Image + * @description Text-to-image generation using a FLUX model. + */ + FluxTextToImageInvocation: { + /** + * @description The board to save the image to + * @default null + */ + board?: components["schemas"]["BoardField"] | null; + /** + * @description Optional metadata to be saved with the image + * @default null + */ + metadata?: components["schemas"]["MetadataField"] | null; + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation.
+ * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Transformer + * @description Flux model (Transformer) to load + * @default null + */ + transformer?: components["schemas"]["TransformerField"]; + /** + * @description VAE + * @default null + */ + vae?: components["schemas"]["VAEField"]; + /** + * @description Positive conditioning tensor + * @default null + */ + positive_text_conditioning?: components["schemas"]["FluxConditioningField"]; + /** + * Width + * @description Width of the generated image. + * @default 1024 + */ + width?: number; + /** + * Height + * @description Height of the generated image. + * @default 1024 + */ + height?: number; + /** + * Num Steps + * @description Number of diffusion steps. Recommend values are schnell: 4, dev: 50. + * @default 4 + */ + num_steps?: number; + /** + * Guidance + * @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell. + * @default 4 + */ + guidance?: number; + /** + * Seed + * @description Randomness seed for reproducibility. + * @default 0 + */ + seed?: number; + /** + * type + * @default flux_text_to_image + * @constant + * @enum {string} + */ + type: "flux_text_to_image"; + }; /** FoundModel */ FoundModel: { /** @@ -5788,7 +6094,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | 
components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | 
components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; + [key: string]: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | 
components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxTextToImageInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | 
components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; }; /** * Edges @@ -5825,7 +6131,7 @@ export type components = { * @description The results of node executions */ results?: { - [key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | 
components["schemas"]["FloatOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; + [key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | 
components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; }; /** * Errors @@ -8210,7 +8516,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | 
components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | 
components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxTextToImageInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] 
| components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | 
components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -8220,7 +8526,7 @@ export type components = { * Result * @description The result of the invocation */ - result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | 
components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; + result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; }; /** * InvocationDenoiseProgressEvent @@ -8256,7 +8562,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | 
components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | 
components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] 
| components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxTextToImageInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] 
| components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | 
components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -8319,7 +8625,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | 
components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | 
components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxTextToImageInvocation"] | 
components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | 
components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; /** * Invocation Source Id * @description The ID of the prepared invocation's source node @@ -8394,6 +8700,9 @@ export type components = { float_math: components["schemas"]["FloatOutput"]; float_range: components["schemas"]["FloatCollectionOutput"]; float_to_int: components["schemas"]["IntegerOutput"]; + flux_model_loader: components["schemas"]["FluxModelLoaderOutput"]; + flux_text_encoder: components["schemas"]["FluxConditioningOutput"]; + flux_text_to_image: components["schemas"]["ImageOutput"]; freeu: components["schemas"]["UNetOutput"]; grounding_dino: components["schemas"]["BoundingBoxCollectionOutput"]; hed_image_processor: components["schemas"]["ImageOutput"]; @@ -8534,7 +8843,7 @@ export type components = { * Invocation * @description The ID of the invocation */ - invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | 
components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | 
components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; + invocation: components["schemas"]["AddInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BoundingBoxInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | 
components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["FluxModelLoaderInvocation"] | components["schemas"]["FluxTextEncoderInvocation"] | components["schemas"]["FluxTextToImageInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["GroundingDinoInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["LaMaInfillInvocation"] | 
components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["MaskTensorToImageInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SegmentAnythingInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["SpandrelImageToImageAutoscaleInvocation"] | components["schemas"]["SpandrelImageToImageInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["TiledMultiDiffusionDenoiseLatents"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"]; /** * 
Invocation Source Id * @description The ID of the prepared invocation's source node @@ -9422,6 +9731,96 @@ export type components = { * @enum {integer} */ LogLevel: 0 | 10 | 20 | 30 | 40 | 50; + /** + * MainBnbQuantized4bCheckpointConfig + * @description Model config for main checkpoint models. + */ + MainBnbQuantized4bCheckpointConfig: { + /** + * Key + * @description A unique key for this model. + */ + key: string; + /** + * Hash + * @description The hash of the model file(s). + */ + hash: string; + /** + * Path + * @description Path to the model on the filesystem. Relative paths are relative to the Invoke root directory. + */ + path: string; + /** + * Name + * @description Name of the model. + */ + name: string; + /** @description The base model. */ + base: components["schemas"]["BaseModelType"]; + /** + * Description + * @description Model description + */ + description?: string | null; + /** + * Source + * @description The original source of the model (path, URL or repo_id). + */ + source: string; + /** @description The type of source */ + source_type: components["schemas"]["ModelSourceType"]; + /** + * Source Api Response + * @description The original API response from the source, as stringified JSON. + */ + source_api_response?: string | null; + /** + * Cover Image + * @description Url for image to preview model + */ + cover_image?: string | null; + /** + * Type + * @default main + * @constant + * @enum {string} + */ + type: "main"; + /** + * Trigger Phrases + * @description Set of trigger phrases for this model + */ + trigger_phrases?: string[] | null; + /** @description Default settings for this model */ + default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; + /** @default normal */ + variant?: components["schemas"]["ModelVariantType"]; + /** + * Format + * @description Format of the provided checkpoint model + * @default checkpoint + * @enum {string} + */ + format: "checkpoint" | "bnb_quantized_nf4b"; + /** + * Config Path + * @description path to the checkpoint model config file + */ + config_path: string; + /** + * Converted At + * @description When this model was last converted to diffusers + */ + converted_at?: number | null; + /** @default epsilon */ + prediction_type?: components["schemas"]["SchedulerPredictionType"]; + /** + * Upcast Attention + * @default false + */ + upcast_attention?: boolean; + }; /** * MainCheckpointConfig * @description Model config for main checkpoint models. @@ -9489,11 +9888,11 @@ export type components = { variant?: components["schemas"]["ModelVariantType"]; /** * Format + * @description Format of the provided checkpoint model * @default checkpoint - * @constant * @enum {string} */ - format: "checkpoint"; + format: "checkpoint" | "bnb_quantized_nf4b"; /** * Config Path * @description path to the checkpoint model config file @@ -10398,7 +10797,7 @@ export type components = { * @description Storage format of model. * @enum {string} */ - ModelFormat: "diffusers" | "checkpoint" | "lycoris" | "onnx" | "olive" | "embedding_file" | "embedding_folder" | "invokeai"; + ModelFormat: "diffusers" | "checkpoint" | "lycoris" | "onnx" | "olive" | "embedding_file" | "embedding_folder" | "invokeai" | "t5_encoder" | "bnb_quantized_int8b" | "bnb_quantized_nf4b"; /** ModelIdentifierField */ ModelIdentifierField: { /** @@ -10698,7 +11097,7 @@ export type components = { * Config Out * @description After successful installation, this will hold the configuration object. 
*/ - config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]) | null; + config_out?: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]) | null; /** * Inplace * @description Leave model in its current location; otherwise install under models directory @@ -10784,7 +11183,7 @@ export type components = { * Config * @description The model's config */ - config: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + config: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | 
components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]; /** * @description The submodel type, if any * @default null @@ -10805,7 +11204,7 @@ export type components = { * Config * @description The model's config */ - config: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + config: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]; /** * @description The submodel type, if any * @default null @@ -10886,6 +11285,11 @@ export type components = { * @description hash of model file */ hash?: string | null; + /** + * Format + * @description format of model file + */ + format?: string | null; /** * Trigger Phrases * @description Set of trigger phrases for this model @@ -10928,7 +11332,7 @@ export type components = { * @description Model type. * @enum {string} */ - ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "t2i_adapter" | "spandrel_image_to_image"; + ModelType: "onnx" | "main" | "vae" | "lora" | "controlnet" | "embedding" | "ip_adapter" | "clip_vision" | "clip_embed" | "t2i_adapter" | "t5_encoder" | "spandrel_image_to_image"; /** * ModelVariantType * @description Variant type. 
@@ -10941,7 +11345,7 @@ export type components = { */ ModelsList: { /** Models */ - models: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"])[]; + models: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"])[]; }; /** * Multiply Integers @@ -13273,6 +13677,7 @@ export type components = { name: string; base: components["schemas"]["BaseModelType"]; type: components["schemas"]["ModelType"]; + format?: components["schemas"]["ModelFormat"] | null; /** * Is Installed * @default false @@ -13291,6 +13696,7 @@ export type components = { name: string; base: components["schemas"]["BaseModelType"]; type: components["schemas"]["ModelType"]; + format?: components["schemas"]["ModelFormat"] | null; /** * Is Installed * @default false @@ -13791,7 +14197,7 @@ export type components = { * @description Submodel type. * @enum {string} */ - SubModelType: "unet" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; + SubModelType: "unet" | "transformer" | "text_encoder" | "text_encoder_2" | "tokenizer" | "tokenizer_2" | "vae" | "vae_decoder" | "vae_encoder" | "scheduler" | "safety_checker"; /** * Subtract Integers * @description Subtracts two numbers @@ -14052,6 +14458,135 @@ export type components = { */ type: "t2i_adapter_output"; }; + /** T5EncoderBnbQuantizedLlmInt8bConfig */ + T5EncoderBnbQuantizedLlmInt8bConfig: { + /** + * Key + * @description A unique key for this model. + */ + key: string; + /** + * Hash + * @description The hash of the model file(s). + */ + hash: string; + /** + * Path + * @description Path to the model on the filesystem. Relative paths are relative to the Invoke root directory. + */ + path: string; + /** + * Name + * @description Name of the model. + */ + name: string; + /** @description The base model. 
*/ + base: components["schemas"]["BaseModelType"]; + /** + * Description + * @description Model description + */ + description?: string | null; + /** + * Source + * @description The original source of the model (path, URL or repo_id). + */ + source: string; + /** @description The type of source */ + source_type: components["schemas"]["ModelSourceType"]; + /** + * Source Api Response + * @description The original API response from the source, as stringified JSON. + */ + source_api_response?: string | null; + /** + * Cover Image + * @description Url for image to preview model + */ + cover_image?: string | null; + /** + * Type + * @default t5_encoder + * @constant + * @enum {string} + */ + type: "t5_encoder"; + /** + * Format + * @default bnb_quantized_int8b + * @constant + * @enum {string} + */ + format: "bnb_quantized_int8b"; + }; + /** T5EncoderConfig */ + T5EncoderConfig: { + /** + * Key + * @description A unique key for this model. + */ + key: string; + /** + * Hash + * @description The hash of the model file(s). + */ + hash: string; + /** + * Path + * @description Path to the model on the filesystem. Relative paths are relative to the Invoke root directory. + */ + path: string; + /** + * Name + * @description Name of the model. + */ + name: string; + /** @description The base model. */ + base: components["schemas"]["BaseModelType"]; + /** + * Description + * @description Model description + */ + description?: string | null; + /** + * Source + * @description The original source of the model (path, URL or repo_id). + */ + source: string; + /** @description The type of source */ + source_type: components["schemas"]["ModelSourceType"]; + /** + * Source Api Response + * @description The original API response from the source, as stringified JSON. + */ + source_api_response?: string | null; + /** + * Cover Image + * @description Url for image to preview model + */ + cover_image?: string | null; + /** + * Type + * @default t5_encoder + * @constant + * @enum {string} + */ + type: "t5_encoder"; + /** + * Format + * @default t5_encoder + * @constant + * @enum {string} + */ + format: "t5_encoder"; + }; + /** T5EncoderField */ + T5EncoderField: { + /** @description Info to load tokenizer submodel */ + tokenizer: components["schemas"]["ModelIdentifierField"]; + /** @description Info to load text_encoder submodel */ + text_encoder: components["schemas"]["ModelIdentifierField"]; + }; /** TBLR */ TBLR: { /** Top */ @@ -14483,6 +15018,11 @@ export type components = { */ type: "tiled_multi_diffusion_denoise_latents"; }; + /** TransformerField */ + TransformerField: { + /** @description Info to load Transformer submodel */ + transformer: components["schemas"]["ModelIdentifierField"]; + }; /** * UIComponent * @description The type of UI component to use for a field, used to override the default components, which are @@ -14557,7 +15097,7 @@ export type components = { * used, and the type will be ignored. They are included here for backwards compatibility. 
* @enum {string} */ - UIType: "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; + UIType: "MainModelField" | "FluxMainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "T5EncoderModelField" | "SpandrelImageToImageModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** UNetField */ UNetField: { /** @description Info to load unet submodel */ @@ -14737,11 +15277,11 @@ export type components = { cover_image?: string | null; /** * Format + * @description Format of the provided checkpoint model * @default checkpoint - * @constant * @enum {string} */ - format: "checkpoint"; + format: "checkpoint" | "bnb_quantized_nf4b"; /** * Config Path * @description path to the checkpoint model config file @@ -15269,7 +15809,7 @@ export interface operations { [name: string]: unknown; }; content: { - 
"application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]; }; }; /** @description Validation Error */ @@ -15301,7 +15841,7 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | 
components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]; }; }; /** @description Bad request */ @@ -15398,7 +15938,7 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]; }; }; /** @description Bad request */ @@ -15898,7 +16438,7 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"]; + "application/json": components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | 
components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPEmbedDiffusersConfig"]; }; }; /** @description Bad request */ diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index bfe42f3f9a1..d7df8967b83 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -51,6 +51,9 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig']; export type T2IAdapterModelConfig = S['T2IAdapterConfig']; +type ClipEmbedModelConfig = S['CLIPEmbedDiffusersConfig']; +export type T5EncoderModelConfig = S['T5EncoderConfig']; +export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig']; export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig']; type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; type DiffusersModelConfig = S['MainDiffusersConfig']; @@ -62,6 +65,9 @@ export type AnyModelConfig = | VAEModelConfig | ControlNetModelConfig | IPAdapterModelConfig + | T5EncoderModelConfig + | T5EncoderBnbQuantizedLlmInt8bModelConfig + | ClipEmbedModelConfig | T2IAdapterModelConfig | SpandrelImageToImageModelConfig | TextualInversionModelConfig @@ -88,6 +94,16 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd return config.type === 't2i_adapter'; }; +export const isT5EncoderModelConfig = ( + config: AnyModelConfig +): config is T5EncoderModelConfig | T5EncoderBnbQuantizedLlmInt8bModelConfig => { + return config.type === 't5_encoder'; +}; + +export const isClipEmbedModelConfig = (config: AnyModelConfig): config is ClipEmbedModelConfig => { + return config.type === 'clip_embed'; +}; + export const isSpandrelImageToImageModelConfig = ( config: AnyModelConfig ): config is SpandrelImageToImageModelConfig => { @@ -110,6 +126,10 @@ export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is M return config.type === 'main' && config.base !== 'sdxl-refiner'; }; +export const isNonRefinerNonFluxMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base !== 'sdxl-refiner' && config.base !== 'flux'; +}; + export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base === 'sdxl-refiner'; }; @@ -118,6 +138,10 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma return config.type === 'main' && config.base === 'sdxl'; }; +export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => { + return config.type === 'main' && config.base === 'flux'; +}; + export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2'); }; diff --git a/pyproject.toml b/pyproject.toml index 
37ff1936edf..2cbd8298570 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ classifiers = [ dependencies = [ # Core generation dependencies, pinned for reproducible builds. "accelerate==0.30.1", + "bitsandbytes==0.43.3; sys_platform!='darwin'", "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==2.0.2", "controlnet-aux==0.0.7", @@ -46,6 +47,8 @@ dependencies = [ "opencv-python==4.9.0.80", "pytorch-lightning==2.1.3", "safetensors==0.4.3", + # sentencepiece is required to load T5TokenizerFast (used by FLUX). + "sentencepiece==0.2.0", "spandrel==0.3.4", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "torch==2.2.2", diff --git a/tests/backend/model_manager/util/test_hf_model_select.py b/tests/backend/model_manager/util/test_hf_model_select.py index a29827e8c43..8b5a395fdbe 100644 --- a/tests/backend/model_manager/util/test_hf_model_select.py +++ b/tests/backend/model_manager/util/test_hf_model_select.py @@ -326,3 +326,80 @@ def test_select_multiple_weights( ) -> None: filtered_files = filter_files(sd15_test_files, variant) assert set(filtered_files) == {Path(f) for f in expected_files} + + +@pytest.fixture +def flux_schnell_test_files() -> list[Path]: + return [ + Path(f) + for f in [ + "FLUX.1-schnell/.gitattributes", + "FLUX.1-schnell/README.md", + "FLUX.1-schnell/ae.safetensors", + "FLUX.1-schnell/flux1-schnell.safetensors", + "FLUX.1-schnell/model_index.json", + "FLUX.1-schnell/scheduler/scheduler_config.json", + "FLUX.1-schnell/schnell_grid.jpeg", + "FLUX.1-schnell/text_encoder/config.json", + "FLUX.1-schnell/text_encoder/model.safetensors", + "FLUX.1-schnell/text_encoder_2/config.json", + "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json", + "FLUX.1-schnell/tokenizer/merges.txt", + "FLUX.1-schnell/tokenizer/special_tokens_map.json", + "FLUX.1-schnell/tokenizer/tokenizer_config.json", + "FLUX.1-schnell/tokenizer/vocab.json", + "FLUX.1-schnell/tokenizer_2/special_tokens_map.json", + "FLUX.1-schnell/tokenizer_2/spiece.model", + "FLUX.1-schnell/tokenizer_2/tokenizer.json", + "FLUX.1-schnell/tokenizer_2/tokenizer_config.json", + "FLUX.1-schnell/transformer/config.json", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json", + "FLUX.1-schnell/vae/config.json", + "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors", + ] + ] + + +@pytest.mark.parametrize( + ["variant", "expected_files"], + [ + ( + ModelRepoVariant.Default, + [ + "FLUX.1-schnell/model_index.json", + "FLUX.1-schnell/scheduler/scheduler_config.json", + "FLUX.1-schnell/text_encoder/config.json", + "FLUX.1-schnell/text_encoder/model.safetensors", + "FLUX.1-schnell/text_encoder_2/config.json", + "FLUX.1-schnell/text_encoder_2/model-00001-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model-00002-of-00002.safetensors", + "FLUX.1-schnell/text_encoder_2/model.safetensors.index.json", + "FLUX.1-schnell/tokenizer/merges.txt", + "FLUX.1-schnell/tokenizer/special_tokens_map.json", + "FLUX.1-schnell/tokenizer/tokenizer_config.json", + 
"FLUX.1-schnell/tokenizer/vocab.json", + "FLUX.1-schnell/tokenizer_2/special_tokens_map.json", + "FLUX.1-schnell/tokenizer_2/spiece.model", + "FLUX.1-schnell/tokenizer_2/tokenizer.json", + "FLUX.1-schnell/tokenizer_2/tokenizer_config.json", + "FLUX.1-schnell/transformer/config.json", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00002-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model-00003-of-00003.safetensors", + "FLUX.1-schnell/transformer/diffusion_pytorch_model.safetensors.index.json", + "FLUX.1-schnell/vae/config.json", + "FLUX.1-schnell/vae/diffusion_pytorch_model.safetensors", + ], + ), + ], +) +def test_select_flux_schnell_files( + flux_schnell_test_files: list[Path], variant: ModelRepoVariant, expected_files: list[str] +) -> None: + filtered_files = filter_files(flux_schnell_test_files, variant) + assert set(filtered_files) == {Path(f) for f in expected_files}