Crop & Mask updates #6876

Open
wants to merge 8 commits into main
88 changes: 88 additions & 0 deletions invokeai/app/invocations/image.py
@@ -1072,3 +1072,91 @@ def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
            width=image_dto.width,
            height=image_dto.height,
        )


@invocation_output("crop_to_object_output")
class CropToObjectOutput(ImageOutput):
    offset_top: int = OutputField(description="The number of pixels cropped from the top")
    offset_left: int = OutputField(description="The number of pixels cropped from the left")
    offset_right: int = OutputField(description="The number of pixels cropped from the right")
    offset_bottom: int = OutputField(description="The number of pixels cropped from the bottom")


@invocation(
    "crop_to_object",
    title="Crop to Object with Margin",
    tags=["image", "crop"],
    category="image",
    version="1.0.1",
)
class CropToObjectInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Crops an image to a bounding box around the object of the specified color, with an optional margin."""

    image: ImageField = InputField(description="An input mask image with black and white content")
    margin: int = InputField(default=0, ge=0, description="The desired margin around the object, as measured in pixels")
    object_color: Literal["white", "black"] = InputField(
        default="white", description="The color of the object to crop around (either 'white' or 'black')"
    )

Review comment (Collaborator) on object_color:
SAM can output multi-colored masks, right? Maybe we want this to be a ColorField instead, and update the mask extraction accordingly.

A future UI component could let the user click the specific mask they want, sampling its color, and then pass that into this node. So it'd be like a two-stage filter - segment, then choose the mask. (A hedged sketch of this idea appears after this file's diff.)

    def invoke(self, context: InvocationContext) -> CropToObjectOutput:
        # Load the image
        image = context.images.get_pil(self.image.image_name)

        # Convert image to grayscale
        grayscale_image = image.convert("L")

        # Convert to numpy array
        np_image = numpy.array(grayscale_image)

        # Depending on the object color, find the object pixels
        if self.object_color == "white":
            # Find non-black pixels (value > 0)
            object_pixels = numpy.argwhere(np_image > 0)
        else:
            # Find non-white pixels (value < 255)
            object_pixels = numpy.argwhere(np_image < 255)

        # If no object pixels are found, return the original image and zero offsets
        if object_pixels.size == 0:
            image_dto = context.images.save(image=image.copy())
            return CropToObjectOutput(
                image=ImageField(image_name=image_dto.image_name),
                width=image.width,
                height=image.height,
                offset_top=0,
                offset_left=0,
                offset_right=0,
                offset_bottom=0,
            )

        # Get the bounding box of the object pixels
        y_min, x_min = object_pixels.min(axis=0)
        y_max, x_max = object_pixels.max(axis=0)

        # Expand the bounding box by the margin, clamped to the image bounds
        x_min_expanded = max(x_min - self.margin, 0)
        y_min_expanded = max(y_min - self.margin, 0)
        x_max_expanded = min(x_max + self.margin, np_image.shape[1] - 1)
        y_max_expanded = min(y_max + self.margin, np_image.shape[0] - 1)

        # Crop the image (PIL boxes are exclusive of the right/bottom edge)
        cropped_image = image.crop((x_min_expanded, y_min_expanded, x_max_expanded + 1, y_max_expanded + 1))

        # Calculate the offsets from each original edge
        offset_top = y_min_expanded
        offset_left = x_min_expanded
        offset_right = np_image.shape[1] - x_max_expanded - 1
        offset_bottom = np_image.shape[0] - y_max_expanded - 1
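
        # Worked example (editor's illustration, not part of the PR): for a
        # 512x512 mask whose object spans x=100..199 and y=50..149, with
        # margin=10, the expanded box is x=90..209, y=40..159, the crop box is
        # (90, 40, 210, 160), and the offsets are top=40, left=90,
        # right=512-209-1=302, bottom=512-159-1=352. Pasting the crop back at
        # (offset_left, offset_top) restores the original geometry.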

        # Save the cropped image
        image_dto = context.images.save(image=cropped_image)

        return CropToObjectOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=cropped_image.width,
            height=cropped_image.height,
            offset_top=offset_top,
            offset_left=offset_left,
            offset_right=offset_right,
            offset_bottom=offset_bottom,
        )
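
A hedged sketch of the reviewer's ColorField idea, for discussion only. The extract_mask_by_color helper, its tolerance parameter, and the color.r/color.g/color.b access pattern are all hypothetical, not part of this PR or of an existing InvokeAI API:

import numpy
from PIL import Image


def extract_mask_by_color(image: Image.Image, r: int, g: int, b: int, tolerance: int = 8) -> numpy.ndarray:
    """Hypothetical helper: boolean mask of pixels within `tolerance` of the target color."""
    np_image = numpy.array(image.convert("RGB")).astype(numpy.int16)  # int16 avoids uint8 wraparound
    target = numpy.array([r, g, b], dtype=numpy.int16)
    # A pixel belongs to the object if every channel is within the tolerance.
    return numpy.all(numpy.abs(np_image - target) <= tolerance, axis=-1)


# The white/black branches in invoke() would then collapse to something like:
# object_pixels = numpy.argwhere(extract_mask_by_color(image, color.r, color.g, color.b))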
21 changes: 17 additions & 4 deletions invokeai/app/invocations/mask.py
@@ -96,7 +96,7 @@ def invoke(self, context: InvocationContext) -> MaskOutput:
     title="Image Mask to Tensor",
     tags=["conditioning"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
     """Convert a mask image to a tensor. Converts the image to grayscale and uses thresholding at the specified value."""
@@ -106,13 +106,26 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
     invert: bool = InputField(default=False, description="Whether to invert the mask.")

     def invoke(self, context: InvocationContext) -> MaskOutput:
-        image = context.images.get_pil(self.image.image_name, mode="L")
+        image = context.images.get_pil(self.image.image_name)
+        np_image = np.array(image)
+        # Handle different image modes
+        if image.mode == "RGBA":
+            alpha_channel = np_image[:, :, 3]  # Extract alpha channel
+        elif image.mode == "RGB":
+            # For RGB images, treat all non-black pixels as opaque.
+            non_black_mask = np.any(np_image > 0, axis=2)  # True for any non-black pixel
+            alpha_channel = non_black_mask.astype(np.uint8) * 255  # Convert to a mask of 0 or 255
+        elif image.mode == "L":  # Grayscale images
+            alpha_channel = np_image  # Grayscale image, so we use it directly
+        else:
+            raise ValueError(f"Unsupported image mode: {image.mode}")

         mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)

         if self.invert:
-            mask[0] = torch.tensor(np.array(image)[:, :] >= self.cutoff, dtype=torch.bool)
+            mask[0] = torch.tensor(alpha_channel == 0, dtype=torch.bool)  # Transparent where alpha or brightness is 0
         else:
-            mask[0] = torch.tensor(np.array(image)[:, :] < self.cutoff, dtype=torch.bool)
+            mask[0] = torch.tensor(alpha_channel > 0, dtype=torch.bool)  # Opaque where alpha or brightness is > 0

         return MaskOutput(
             mask=TensorField(tensor_name=context.tensors.save(mask)),
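
For reference, the new mode handling can be exercised outside the invocation framework. The image_to_bool_mask helper below is an editor's illustration, not code from the PR. Note that with this change the cutoff field no longer participates in the comparison: any nonzero alpha or brightness counts as opaque.

import numpy as np
import torch
from PIL import Image


def image_to_bool_mask(image: Image.Image, invert: bool = False) -> torch.Tensor:
    """Illustrative stand-in for the updated invoke(): mode-aware mask extraction."""
    np_image = np.array(image)
    if image.mode == "RGBA":
        alpha_channel = np_image[:, :, 3]  # Alpha channel decides coverage
    elif image.mode == "RGB":
        alpha_channel = np.any(np_image > 0, axis=2).astype(np.uint8) * 255  # Non-black pixels are opaque
    elif image.mode == "L":
        alpha_channel = np_image  # Grayscale brightness decides coverage
    else:
        raise ValueError(f"Unsupported image mode: {image.mode}")
    selected = (alpha_channel == 0) if invert else (alpha_channel > 0)
    return torch.tensor(selected, dtype=torch.bool).unsqueeze(0)  # Shape (1, H, W), matching the node


# A fully transparent RGBA image yields an all-False mask:
print(image_to_bool_mask(Image.new("RGBA", (64, 64), (255, 0, 0, 0))).any())  # tensor(False)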