Skip to content

Commit

Permalink
Rearanged some code to speed up pipline, now it does not excecute the…
Browse files Browse the repository at this point in the history
… base image feature extraction everytime.
  • Loading branch information
tim-rehnstrom committed May 31, 2024
1 parent e026c90 commit 644fb3a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 30 deletions.
43 changes: 18 additions & 25 deletions analyst/workspace/gather_training_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
"from IPython.display import display, clear_output\n",
"from jupyter_ui_poll import ui_events\n",
"\n",
"import torch\n",
"from lightglue import SuperPoint\n",
"from lightglue.utils import load_image\n",
"\n",
"import scripts.query_image_of_bag as query_of_bag\n",
"import scripts.query_image as query_all\n",
"import scripts.save_patch_scripts as save_patch_scripts\n",
Expand Down Expand Up @@ -161,7 +165,6 @@
" clicked_points = []\n",
" ax.clear()\n",
" ax.imshow(image)\n",
" input_widget.value = ''\n",
" fig.canvas.draw()"
]
},
Expand All @@ -174,7 +177,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2037d3d63bd744adb666f5e6a1983851",
"model_id": "6ea3d3022f38425883bff2fa8ff89a92",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -188,7 +191,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b49c850df8104fd4a1b3c72dbc7d302f",
"model_id": "a83575b7ca204e318e6600ef1d17f5dd",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -202,7 +205,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2be2bd5c7ea14ab7834fdc53d9b5c045",
"model_id": "4f30d161abb9496d975039ff7fe7c8c7",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -302,23 +305,7 @@
"Connected to isaac database\n",
"From database got 96 matches\n",
"From first filtering got 43 matches\n",
"Query successful, got 43 matches\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading: \"https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth\" to /root/.cache/torch/hub/checkpoints/superpoint_v1.pth\n",
"100%|██████████| 4.96M/4.96M [00:00<00:00, 22.9MB/s]\n",
"Downloading: \"https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_lightglue.pth\" to /root/.cache/torch/hub/checkpoints/superpoint_lightglue_v0-1_arxiv.pth\n",
"100%|██████████| 45.3M/45.3M [00:02<00:00, 21.9MB/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Query successful, got 43 matches\n",
"Saved 43 images\n",
"Done extracting and saving image patches from bag\n",
"\n",
Expand Down Expand Up @@ -406,6 +393,12 @@
],
"source": [
"corners = points_and_polygons.identify_corners(clicked_points)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # Check if Nvidia CUDA is supported by the gpu otherwise set device to cpu\n",
"base_image = load_image(base_image_path)\n",
"extractor = SuperPoint(max_num_keypoints=2048).eval().to(device) # Load Superpoint as the extractor\n",
"feats_base_image = extractor.extract(base_image.to(device))\n",
"\n",
"bags = os.listdir(bag_directory_path) # Load a list of all the bags in the bag directory\n",
"total_images_saved = 0 # Initialize the counter for the amount of total images saved\n",
"for bag in bags: # Go though the bags one by one\n",
Expand All @@ -418,7 +411,7 @@
" result = query_of_bag.query_image(target_position, target_attitude, ros_topic_pose, ros_topic_image, max_distance, min_distance, max_angle, target_size_y, target_size_z, bag) \n",
" \n",
" if result is not None: # Check if the query found any images or if the result is empty\n",
" total_images_saved = save_patch_scripts.save_images(result, image_path, base_image_path, corners, save_path, bag, total_images_saved, print_info) # If the query was sucessful, save the images\n",
" total_images_saved = save_patch_scripts.save_images(result, image_path, base_image_path, corners, save_path, bag, total_images_saved, feats_base_image, print_info) # If the query was sucessful, save the images\n",
" \n",
" else:\n",
" print('No images found\\n') # Otherwise inform the user that no images was found\n",
Expand Down Expand Up @@ -473,8 +466,8 @@
"Data copying and organization completed.\n",
"\n",
"Amount of total data is: 205\n",
"Amount of training data is: 155\n",
"Amount of test data is: 50\n"
"Amount of training data is: 159\n",
"Amount of test data is: 46\n"
]
}
],
Expand All @@ -499,7 +492,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Notebook executed in: 2078.4343552589417 seconds\n"
"Notebook executed in: 1528.259447813034 seconds\n"
]
}
],
Expand Down
7 changes: 3 additions & 4 deletions analyst/workspace/scripts/save_patch_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# This function matches a query image to the perspective of the base image. It does this using superpoint feature extractor and lightglue feature matcher
def match_images_and_transform(base_image_path, query_image_path):
def match_images_and_transform(base_image_path, query_image_path, feats_base_image):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Check if Nvidia CUDA is supported by the gpu otherwise set device to cpu


Expand All @@ -25,7 +25,6 @@ def match_images_and_transform(base_image_path, query_image_path):
query_image = load_image(query_image_path)

# Extract features in both images and store them on the device
feats_base_image = extractor.extract(base_image.to(device))
feats_query_image = extractor.extract(query_image.to(device))

# Match the features found. Also remove the batch dimention for further processing
Expand Down Expand Up @@ -75,13 +74,13 @@ def extract_image(image, corners): # Takes the image and the corners in a numpy



def save_images(result, image_path, base_image_path, corners, save_path, bag, total_images_saved, print_info):
def save_images(result, image_path, base_image_path, corners, save_path, bag, total_images_saved, feats_base_image, print_info):
amount_images = 0 # Variable for counting the amount of images that have been saved from this bag.
for idx, element in enumerate(result): # Go though the result one by one
image = cv.imread(image_path + element['img']) #Load the image file with openCV

# Note: The corners are as follows; C1: bottom right, C2: top right, C3: top left, C4: bottom left
transformed_image = match_images_and_transform(base_image_path, image_path + element['img'])
transformed_image = match_images_and_transform(base_image_path, image_path + element['img'], feats_base_image)

extracted_image = extract_image(transformed_image, [corners['A'], corners['B'], corners['D'], corners['C']]) # Extract the image patches from the image

Expand Down
2 changes: 1 addition & 1 deletion scripts/docker/analyst.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/*

RUN set -e \
pip3 install pyArango \
&& pip3 install pyArango \
&& pip3 install jupyterlab jupyterhub nbconvert Pygments==2.6.1 jupyros \
&& pip3 install networkx==3.1 \
&& pip3 install matplotlib opencv-python numpy-quaternion pandas \
Expand Down

0 comments on commit 644fb3a

Please sign in to comment.