-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Emergent Object Detection/Classification Model.v1 #56
Changes from all commits
c4839e1
8e61358
1d388ae
52a05cc
24d4b21
e475428
9e9d0ba
ad94558
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import torch | ||
from vision.common.constants import Image | ||
from typing import Callable, Any | ||
|
||
# You may have to install: | ||
# pandas | ||
# torchvision | ||
# tqdm | ||
# seaborn | ||
|
||
MODEL_PATH = "vision/emergent_object/best.pt" | ||
Comment on lines
+10
to
+11
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this constant is only used in |
||
|
||
|
||
# Function to do detection / classification | ||
def detect_emergent_object(image: Image, model: Callable): | ||
""" | ||
Detects an emergent object within an image | ||
|
||
Parameters | ||
---------- | ||
image | ||
The image being analyzed by the model. | ||
|
||
model | ||
The model which is being used for object detection/classification | ||
|
||
Returns | ||
------- | ||
output | ||
A dataframe containing the xy coordinates of | ||
the detected object within the image | ||
""" | ||
# Convert to RGB | ||
image: Image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
|
||
# Run the model on the image. Both a file path and a numpy image work, but | ||
# we want to use a numpy image | ||
model_prediction = model(image) | ||
|
||
# Retrieve the output from the model | ||
object_location: Any = model_prediction.pandas().xyxy[0] | ||
|
||
return object_location | ||
|
||
|
||
# Load the model from the file | ||
def create_emergent_model(): | ||
""" | ||
Creates the model used for object detection/classification | ||
|
||
Parameters | ||
---------- | ||
None | ||
|
||
Returns | ||
------- | ||
model : callable | ||
The model used for object detection/classification | ||
""" | ||
model: Callable = torch.hub.load("ultralytics/yolov5", "custom", path=MODEL_PATH) | ||
return model | ||
|
||
|
||
if __name__ == "__main__": | ||
import cv2 | ||
|
||
Comment on lines
+65
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
image_path = "vision/emergent_object/people.png" | ||
|
||
Comment on lines
+67
to
+68
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could make this into a command line argument |
||
image = cv2.imread(image_path) | ||
|
||
# Create model | ||
model = create_emergent_model() | ||
|
||
# Use model for detection / classification | ||
output = detect_emergent_object(image, model) | ||
|
||
# Convert the Pandas Dataframe to a dictionary - this will be necessary and | ||
# should eventually be done in `detect_emergent_object()` | ||
output_dict = output.to_dict("index") | ||
|
||
# Draw the bounding boxes to the original image | ||
for row in output_dict.values(): | ||
Comment on lines
+81
to
+82
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Type hint for
|
||
# Get the output ranges | ||
top_left = (int(row["xmin"]), int(row["ymin"])) | ||
bottom_right = (int(row["xmax"]), int(row["ymax"])) | ||
Comment on lines
+83
to
+85
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type hints for these two variables |
||
|
||
# Draw the bounding box | ||
cv2.rectangle(image, top_left, bottom_right, (255, 0, 0), 4) | ||
|
||
# Display the image | ||
cv2.imshow("", image) | ||
Comment on lines
+90
to
+91
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might want something like "Emergent Object" instead of an empty string for the title. |
||
cv2.waitKey(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dependencies should all be taken care of by #58 if you run in a poetry shell. These comments can be removed.