Emergent Object Detection/Classification Model.v1 #56

Closed
wants to merge 8 commits into from
Binary file added vision/emergent_object/best.pt
Binary file not shown.
92 changes: 92 additions & 0 deletions vision/emergent_object/emergent_object.py
@@ -0,0 +1,92 @@
import torch
from vision.common.constants import Image
from typing import Callable, Any

# You may have to install:
# pandas
# torchvision
# tqdm
# seaborn
Comment on lines +5 to +9
Member

Dependencies should all be taken care of by #58 if you run in a poetry shell. These comments can be removed.


MODEL_PATH = "vision/emergent_object/best.pt"
Comment on lines +10 to +11
Member

Since this constant is only used in create_emergent_model(), MODEL_PATH should be moved into that function.



# Function to do detection / classification
def detect_emergent_object(image: Image, model: Callable):
    """
    Detects an emergent object within an image

    Parameters
    ----------
    image
        The image being analyzed by the model.

    model
        The model which is being used for object detection/classification

    Returns
    -------
    output
        A dataframe containing the xy coordinates of
        the detected object within the image
    """
    # Convert to RGB
    image: Image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Run the model on the image. Both a file path and a numpy image work, but
    # we want to use a numpy image
    model_prediction = model(image)

    # Retrieve the output from the model
    object_location: Any = model_prediction.pandas().xyxy[0]

    return object_location


# Load the model from the file
def create_emergent_model():
    """
    Creates the model used for object detection/classification

    Parameters
    ----------
    None

    Returns
    -------
    model : callable
        The model used for object detection/classification
    """
    model: Callable = torch.hub.load("ultralytics/yolov5", "custom", path=MODEL_PATH)
    return model


if __name__ == "__main__":
    import cv2

Comment on lines +65 to +66
Member

The cv2 import should be at the top of the file, since cv2 is used in detect_emergent_object().

    image_path = "vision/emergent_object/people.png"

Comment on lines +67 to +68
Member

Could make this into a command line argument
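
A minimal sketch of that suggestion, assuming a hypothetical `parse_args` helper and `--image` flag (neither is part of this PR). The default keeps the path currently hard-coded in the diff, so running the script without arguments would behave as before:

```python
import argparse
from typing import Optional, Sequence


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the image path from the command line (hypothetical helper)."""
    parser = argparse.ArgumentParser(
        description="Run emergent object detection on one image."
    )
    # "--image" is an assumed flag name; the default is the path used in the diff.
    parser.add_argument(
        "--image",
        default="vision/emergent_object/people.png",
        help="Path to the image to analyze.",
    )
    # Passing argv=None makes argparse read sys.argv[1:].
    return parser.parse_args(argv)
```

The `__main__` block could then use `image_path = parse_args().image` in place of the literal string.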

    image = cv2.imread(image_path)

    # Create model
    model = create_emergent_model()

    # Use model for detection / classification
    output = detect_emergent_object(image, model)

    # Convert the Pandas Dataframe to a dictionary - this will be necessary and
    # should eventually be done in `detect_emergent_object()`
    output_dict = output.to_dict("index")

    # Draw the bounding boxes to the original image
    for row in output_dict.values():
Comment on lines +81 to +82
Member

Add a type hint for row. Loop variables are type hinted like this:

row: Type
for row ...

        # Get the output ranges
        top_left = (int(row["xmin"]), int(row["ymin"]))
        bottom_right = (int(row["xmax"]), int(row["ymax"]))
Comment on lines +83 to +85
Member

Add type hints for these two variables.
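
One way the hints requested for this loop might look, as a sketch only. The `Point` alias, the `box_corners` helper, and the sample row values are hypothetical and stand in for the real `output.to_dict("index")` result:

```python
from typing import Any, Dict, Tuple

Point = Tuple[int, int]  # hypothetical alias, not defined in the PR


def box_corners(row: Dict[str, Any]) -> Tuple[Point, Point]:
    """Return the (top_left, bottom_right) pixel corners for one detection row."""
    top_left: Point = (int(row["xmin"]), int(row["ymin"]))
    bottom_right: Point = (int(row["xmax"]), int(row["ymax"]))
    return top_left, bottom_right


# Stand-in for output.to_dict("index"); the values are made up for illustration.
output_dict: Dict[int, Dict[str, Any]] = {
    0: {"xmin": 10.7, "ymin": 20.2, "xmax": 110.9, "ymax": 220.0},
}

row: Dict[str, Any]  # loop variable annotated before the loop
for row in output_dict.values():
    top_left, bottom_right = box_corners(row)
```

Note that `int()` truncates toward zero, so fractional box coordinates are rounded down for positive values.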


        # Draw the bounding box
        cv2.rectangle(image, top_left, bottom_right, (255, 0, 0), 4)

    # Display the image
    cv2.imshow("", image)
Comment on lines +90 to +91
Member

Might want something like "Emergent Object" instead of an empty string for the title.

    cv2.waitKey(0)