Skip to content

Commit

Permalink
use helper function to get intersection points for different geometri…
Browse files Browse the repository at this point in the history
…es (#81)
  • Loading branch information
eberrigan authored May 13, 2024
1 parent 97c69a0 commit fe7cdae
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
19 changes: 11 additions & 8 deletions sleap_roots/convhull.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,18 +548,21 @@ def get_chull_intersection_vectors(

# Get the intersection points
if not intersection.is_empty:
intersect_points = (
np.array([[point.x, point.y] for point in intersection.geoms])
if intersection.geom_type == "MultiPoint"
else np.array([[intersection.x, intersection.y]])
)
intersect_points = extract_points_from_geometry(intersection)
else:
# Return two vectors of NaNs if there is no intersection
return leftmost_vector, rightmost_vector

# Get the leftmost and rightmost intersection points
leftmost_intersect = intersect_points[np.argmin(intersect_points[:, 0])]
rightmost_intersect = intersect_points[np.argmax(intersect_points[:, 0])]
# Convert the list of NumPy arrays to a 2D NumPy array
intersection_points_array = np.vstack(intersect_points)

# Find the leftmost and rightmost intersection points
leftmost_intersect = intersection_points_array[
np.argmin(intersection_points_array[:, 0])
]
rightmost_intersect = intersection_points_array[
np.argmax(intersection_points_array[:, 0])
]

# Make a vector from the leftmost r0 point to the leftmost intersection point
leftmost_vector = (leftmost_intersect - leftmost_r0).reshape(1, -1)
Expand Down
2 changes: 1 addition & 1 deletion sleap_roots/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import List, Optional, Tuple


def extract_points_from_geometry(geometry):
def extract_points_from_geometry(geometry) -> List[np.ndarray]:
"""Extracts coordinates as a list of numpy arrays from any given Shapely geometry object.
This function supports Point, MultiPoint, LineString, and GeometryCollection types.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_convhull.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_basic_functionality(pts_shape_3_6_2):
r0_pts, r1_pts, pts, hull
)

# Assertions depend on the expected outcome, which you'll need to calculate based on your function's logic
# TODO: Add more specific tests as needed
assert not np.isnan(left_vector).any(), "Left vector should not contain NaNs"
assert not np.isnan(right_vector).any(), "Right vector should not contain NaNs"

Expand Down

0 comments on commit fe7cdae

Please sign in to comment.