feat: Update Tile Pre-Processor to support more modes
blessedcoolant committed Jun 29, 2024
1 parent 10076fb commit 6414d2d
Showing 3 changed files with 493 additions and 166 deletions.
127 changes: 83 additions & 44 deletions invokeai/app/invocations/controlnet_image_processors.py
@@ -1,51 +1,47 @@
 # Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
+import random
 from builtins import bool, float
 from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import Any, Dict, List, Literal, Union

 import cv2
 import numpy as np
-from controlnet_aux import (
-    ContentShuffleDetector,
-    LeresDetector,
-    MediapipeFaceDetector,
-    MidasDetector,
-    MLSDdetector,
-    NormalBaeDetector,
-    PidiNetDetector,
-    SamDetector,
-    ZoeDetector,
-)
+from controlnet_aux import (ContentShuffleDetector, LeresDetector,
+                            MediapipeFaceDetector, MidasDetector, MLSDdetector,
+                            NormalBaeDetector, PidiNetDetector, SamDetector,
+                            ZoeDetector)
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator

-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    ImageField,
-    InputField,
-    OutputField,
-    UIType,
-    WithBoard,
-    WithMetadata,
-)
+from invokeai.app.invocations.fields import (FieldDescriptions, ImageField,
+                                             InputField, OutputField, UIType,
+                                             WithBoard, WithMetadata)
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageOutput
-from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
+from invokeai.app.invocations.util import (validate_begin_end_step,
+                                           validate_weights)
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
+from invokeai.app.util.controlnet_utils import (CONTROLNET_MODE_VALUES,
+                                                CONTROLNET_RESIZE_VALUES,
+                                                heuristic_resize)
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
+from invokeai.backend.image_util.depth_anything import (DEPTH_ANYTHING_MODELS,
+                                                        DepthAnythingDetector)
+from invokeai.backend.image_util.dw_openpose import (DWPOSE_MODELS,
+                                                     DWOpenposeDetector)
+from invokeai.backend.image_util.fast_guided_filter.fast_guided_filter import \
+    FastGuidedFilter
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
 from invokeai.backend.util.devices import TorchDevice

-from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
+from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
+                             Classification, invocation, invocation_output)


 class ControlField(BaseModel):

Check failure on line 47 in invokeai/app/invocations/controlnet_image_processors.py (GitHub Actions / python-checks): Ruff (I001) at invokeai/app/invocations/controlnet_image_processors.py:4:1: Import block is un-sorted or un-formatted
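Note: the commit itself triggers this failure by condensing the import blocks into the wrapped style shown in this diff, which no longer matches the isort layout Ruff enforces for the project. As an inference, not something stated in the commit, running `ruff check --fix` would restore the exploded one-import-per-line layout this diff removes, e.g.:

from controlnet_aux import (
    ContentShuffleDetector,
    LeresDetector,
    MediapipeFaceDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    PidiNetDetector,
    SamDetector,
    ZoeDetector,
)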
@@ -483,30 +479,73 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):

     # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
     down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
+    mode: Literal["regular", "blur", "var", "super"] = InputField(
+        default="regular", description="The controlnet tile model being used"
+    )
+
+    def apply_gaussian_blur(self, image_np: np.ndarray[Any, Any], ksize: int = 5, sigmaX: float = 1.0):
+        if ksize % 2 == 0:
+            ksize += 1  # ksize must be odd
+        blurred_image = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX)
+        return blurred_image
+
+    def apply_guided_filter(self, image_np: np.ndarray[Any, Any], radius: int, eps: float, scale: int):
+        filter = FastGuidedFilter(image_np, radius, eps, scale)
+        return filter.filter(image_np)
+
+    # based off https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
+    def tile_resample(self, np_img: np.ndarray[Any, Any]):
+        height, width, _ = np_img.shape
-    # tile_resample copied from sd-webui-controlnet/scripts/processor.py
-    def tile_resample(
-        self,
-        np_img: np.ndarray,
-        res=512,  # never used?
-        down_sampling_rate=1.0,
-    ):
-        np_img = HWC3(np_img)
-        if down_sampling_rate < 1.1:
-            return np_img
-        H, W, C = np_img.shape
-        H = int(float(H) / float(down_sampling_rate))
-        W = int(float(W) / float(down_sampling_rate))
-        np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
-        return np_img
+
+        if self.mode == "regular":
+            np_img = HWC3(np_img)
+            if self.down_sampling_rate < 1.1:
+                return np_img
+
+            height = int(float(height) / float(self.down_sampling_rate))
+            width = int(float(width) / float(self.down_sampling_rate))
+            np_img = cv2.resize(np_img, (width, height), interpolation=cv2.INTER_AREA)
+            return np_img
+
+        ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
+
+        resize_w, resize_h = int(width * ratio), int(height * ratio)
+
+        if self.mode == "super":
+            resize_w, resize_h = int(width * ratio) // 48 * 48, int(height * ratio) // 48 * 48
+
+        np_img = cv2.resize(np_img, (resize_w, resize_h))
+
+        if self.mode == "blur":
+            blur_strength = random.sample([i / 10.0 for i in range(10, 201, 2)], k=1)[0]
+            radius = random.sample([i for i in range(1, 40, 2)], k=1)[0]
Check failure on line 521 in invokeai/app/invocations/controlnet_image_processors.py (GitHub Actions / python-checks): Ruff (C416) at invokeai/app/invocations/controlnet_image_processors.py:521:36: Unnecessary `list` comprehension (rewrite using `list()`)
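For reference, a minimal sketch of the rewrite Ruff's C416 suggests (not part of this commit):

radius = random.sample(list(range(1, 40, 2)), k=1)[0]

Since only a single value is drawn, `random.choice(range(1, 40, 2))` would be an equivalent further simplification.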
+            eps = random.sample([i / 1000.0 for i in range(1, 101, 2)], k=1)[0]
+            scale_factor = random.sample([i / 10.0 for i in range(10, 181, 5)], k=1)[0]
+
+            if random.random() > 0.5:
+                np_img = self.apply_gaussian_blur(np_img, ksize=int(blur_strength), sigmaX=blur_strength / 2)
+
+            if random.random() > 0.5:
+                np_img = self.apply_guided_filter(np_img, radius, eps, int(scale_factor))
+
+            np_img = cv2.resize(
+                np_img, (int(resize_w / scale_factor), int(resize_h / scale_factor)), interpolation=cv2.INTER_AREA
+            )
+            np_img = cv2.resize(np_img, (resize_w, resize_h), interpolation=cv2.INTER_CUBIC)
+
+        if self.mode == "var":
+            pass
+
+        if self.mode == "super":
+            pass
+
+        np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+
+        return np_img

     def run_processor(self, image: Image.Image) -> Image.Image:
         np_img = np.array(image, dtype=np.uint8)
-        processed_np_image = self.tile_resample(
-            np_img,
-            # res=self.tile_size,
-            down_sampling_rate=self.down_sampling_rate,
-        )
+        processed_np_image = self.tile_resample(np_img)
         processed_image = Image.fromarray(processed_np_image)
         return processed_image
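For illustration, here is a minimal standalone sketch of what the new "blur" mode does to an image, using only OpenCV and numpy. The guided-filter branch and the coin-flip randomization are omitted (the blur is applied unconditionally here), and the harness, including the function name and test image, is assumed rather than taken from the commit:

import random

import cv2
import numpy as np


def apply_gaussian_blur(image_np, ksize=5, sigmaX=1.0):
    if ksize % 2 == 0:
        ksize += 1  # cv2.GaussianBlur requires an odd kernel size
    return cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX)


def blur_mode_resample(np_img):
    height, width, _ = np_img.shape
    # Normalize the working area to roughly one megapixel, as the commit does.
    ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
    resize_w, resize_h = int(width * ratio), int(height * ratio)
    np_img = cv2.resize(np_img, (resize_w, resize_h))

    # Randomized degradation: blur, then downsample and upsample to lose detail.
    blur_strength = random.choice([i / 10.0 for i in range(10, 201, 2)])
    scale_factor = random.choice([i / 10.0 for i in range(10, 181, 5)])
    np_img = apply_gaussian_blur(np_img, ksize=int(blur_strength), sigmaX=blur_strength / 2)
    np_img = cv2.resize(
        np_img, (int(resize_w / scale_factor), int(resize_h / scale_factor)), interpolation=cv2.INTER_AREA
    )
    np_img = cv2.resize(np_img, (resize_w, resize_h), interpolation=cv2.INTER_CUBIC)

    # The commit returns RGB; OpenCV arrays are BGR by convention.
    return cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)


if __name__ == "__main__":
    test_img = np.zeros((600, 800, 3), dtype=np.uint8)  # placeholder input
    print(blur_mode_resample(test_img).shape)

The "super" mode differs only in snapping the target dimensions down to multiples of 48 (a computed width of 1182 becomes 1182 // 48 * 48 = 1152), and "var" currently falls through to the plain resize plus the BGR-to-RGB conversion.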
(Diffs for the remaining 2 changed files are not shown.)

0 comments on commit 6414d2d
