Skip to content

Commit

Permalink
Rename grav_index to curve_index (#68)
Browse files Browse the repository at this point in the history
* Fix documentation

* Import `lengths` in init

* Change gravitropism to curvature
  • Loading branch information
eberrigan authored Oct 7, 2023
1 parent fd41164 commit 52938e1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 48 deletions.
3 changes: 2 additions & 1 deletion sleap_roots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sleap_roots.convhull
import sleap_roots.ellipse
import sleap_roots.networklength
import sleap_roots.lengths
import sleap_roots.points
import sleap_roots.scanline
import sleap_roots.series
Expand All @@ -16,4 +17,4 @@

# Define package version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
__version__ = "0.0.4"
__version__ = "0.0.5"
18 changes: 9 additions & 9 deletions sleap_roots/lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ def get_root_lengths_max(pts: np.ndarray) -> np.ndarray:
return max_length


def get_grav_index(
def get_curve_index(
lengths: Union[float, np.ndarray], base_tip_dists: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
"""Calculate the gravitropism index of a root.
"""Calculate the curvature index of a root.
The gravitropism index quantifies the curviness of the root's growth. A higher
gravitropism index indicates a curvier root (less responsive to gravity), while a
The curvature index quantifies the curviness of the root's growth. A higher
curvature index indicates a curvier root (less responsive to gravity), while a
lower index indicates a straighter root (more responsive to gravity). The index is
computed as the difference between the maximum root length and straight-line
distance from the base to the tip of the root, normalized by the root length.
Expand All @@ -129,7 +129,7 @@ def get_grav_index(
root(s). Can be a scalar or a 1D numpy array of shape `(instances,)`.
Returns:
Gravitropism index of the root(s), quantifying its/their curviness. Will be a
Curvature index of the root(s), quantifying its/their curviness. Will be a
scalar if input is scalar, or a 1D numpy array of shape `(instances,)`
otherwise.
"""
Expand All @@ -144,8 +144,8 @@ def get_grav_index(
if lengths.shape != base_tip_dists.shape:
raise ValueError("The shapes of lengths and base_tip_dists must match.")

# Calculate the gravitropism index where possible
grav_index = np.where(
# Calculate the curvature index where possible
curve_index = np.where(
(~np.isnan(lengths))
& (~np.isnan(base_tip_dists))
& (lengths > 0)
Expand All @@ -156,6 +156,6 @@ def get_grav_index(

# Return scalar or array based on the input type
if is_scalar_input:
return grav_index.item()
return curve_index.item()
else:
return grav_index
return curve_index
2 changes: 1 addition & 1 deletion sleap_roots/scanline.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_scanline_last_ind(scanline_intersection_counts: np.ndarray):
Return:
Scalar of count_scanline_interaction index for the last interaction.
"""
# get the first scanline index using scanline_intersection_counts
# get the last scanline index using scanline_intersection_counts
if np.where((scanline_intersection_counts > 0))[0].shape[0] > 0:
scanline_last_ind = np.where((scanline_intersection_counts > 0))[0][-1]
return scanline_last_ind
Expand Down
18 changes: 9 additions & 9 deletions sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
get_ellipse_b,
get_ellipse_ratio,
)
from sleap_roots.lengths import get_grav_index, get_max_length_pts, get_root_lengths
from sleap_roots.lengths import get_curve_index, get_max_length_pts, get_root_lengths
from sleap_roots.networklength import (
get_bbox,
get_network_distribution,
Expand Down Expand Up @@ -811,13 +811,13 @@ def define_traits(self) -> List[TraitDef]:
description="Scalar of base median ratio.",
),
TraitDef(
name="grav_index",
fn=get_grav_index,
name="curve_index",
fn=get_curve_index,
input_traits=["primary_length", "primary_base_tip_dist"],
scalar=True,
include_in_csv=True,
kwargs={},
description="Scalar of primary root gravity index.",
description="Scalar of primary root curvature index.",
),
TraitDef(
name="base_length_ratio",
Expand Down Expand Up @@ -1189,13 +1189,13 @@ def define_traits(self) -> List[TraitDef]:
"tip(s) of the main root(s).",
),
TraitDef(
name="main_grav_indices",
name="main_curve_indices",
fn=get_base_tip_dist,
input_traits=["main_base_pts", "main_tip_pts"],
scalar=False,
include_in_csv=True,
kwargs={},
description="Gravitropism index for each main root.",
description="Curvature index for each main root.",
),
TraitDef(
name="network_solidity",
Expand Down Expand Up @@ -1291,13 +1291,13 @@ def define_traits(self) -> List[TraitDef]:
"convex hull.",
),
TraitDef(
name="grav_index",
fn=get_grav_index,
name="curve_index",
fn=get_curve_index,
input_traits=["primary_length", "primary_base_tip_dist"],
scalar=True,
include_in_csv=True,
kwargs={},
description="Scalar of primary root gravity index.",
description="Scalar of primary root curvature index.",
),
TraitDef(
name="primary_base_tip_dist",
Expand Down
44 changes: 22 additions & 22 deletions tests/test_lengths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sleap_roots.lengths import (
get_grav_index,
get_curve_index,
get_root_lengths,
get_root_lengths_max,
get_max_length_pts,
Expand Down Expand Up @@ -146,8 +146,8 @@ def lengths_all_nan():
return np.array([np.nan, np.nan, np.nan])


# tests for get_grav_index function
def test_get_grav_index_canola(canola_h5):
# tests for get_curve_index function
def test_get_curve_index_canola(canola_h5):
series = Series.load(
canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
)
Expand All @@ -158,22 +158,22 @@ def test_get_grav_index_canola(canola_h5):
bases = get_bases(max_length_pts)
tips = get_tips(max_length_pts)
base_tip_dist = get_base_tip_dist(bases, tips)
grav_index = get_grav_index(primary_length, base_tip_dist)
np.testing.assert_almost_equal(grav_index, 0.08898137324716636)
curve_index = get_curve_index(primary_length, base_tip_dist)
np.testing.assert_almost_equal(curve_index, 0.08898137324716636)


def test_get_grav_index():
def test_get_curve_index():
# Test 1: Scalar inputs where length > base_tip_dist
# Gravitropism index should be (10 - 8) / 10 = 0.2
assert get_grav_index(10, 8) == 0.2
# Curvature index should be (10 - 8) / 10 = 0.2
assert get_curve_index(10, 8) == 0.2

# Test 2: Scalar inputs where length and base_tip_dist are zero
# Should return NaN as length is zero
assert np.isnan(get_grav_index(0, 0))
assert np.isnan(get_curve_index(0, 0))

# Test 3: Scalar inputs where length < base_tip_dist
# Should return NaN as it's an invalid case
assert np.isnan(get_grav_index(5, 10))
assert np.isnan(get_curve_index(5, 10))

# Test 4: Array inputs covering various cases
# Case 1: length > base_tip_dist, should return 0.2
Expand All @@ -183,35 +183,35 @@ def test_get_grav_index():
lengths = np.array([10, 0, 5, 15])
base_tip_dists = np.array([8, 0, 10, 12])
expected = np.array([0.2, np.nan, np.nan, 0.2])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert np.allclose(result, expected, equal_nan=True)

# Test 5: Mismatched shapes between lengths and base_tip_dists
# Should raise a ValueError
with pytest.raises(ValueError):
get_grav_index(np.array([10, 20]), np.array([8]))
get_curve_index(np.array([10, 20]), np.array([8]))

# Test 6: Array inputs with NaN values
# Case 1: length > base_tip_dist, should return 0.2
# Case 2 and 3: either length or base_tip_dist is NaN, should return NaN
lengths = np.array([10, np.nan, np.nan])
base_tip_dists = np.array([8, 8, np.nan])
expected = np.array([0.2, np.nan, np.nan])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert np.allclose(result, expected, equal_nan=True)


def test_get_grav_index_shape():
def test_get_curve_index_shape():
# Check if scalar inputs result in scalar output
result = get_grav_index(10, 8)
result = get_curve_index(10, 8)
assert isinstance(
result, (int, float)
), f"Expected scalar output, got {type(result)}"

# Check if array inputs result in array output
lengths = np.array([10, 15])
base_tip_dists = np.array([8, 12])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert isinstance(
result, np.ndarray
), f"Expected np.ndarray output, got {type(result)}"
Expand All @@ -225,7 +225,7 @@ def test_get_grav_index_shape():
# Check the shape of output for larger array inputs
lengths = np.array([10, 15, 20, 25])
base_tip_dists = np.array([8, 12, 18, 22])
result = get_grav_index(lengths, base_tip_dists)
result = get_curve_index(lengths, base_tip_dists)
assert (
result.shape == lengths.shape
), f"Output shape {result.shape} does not match input shape {lengths.shape}"
Expand All @@ -235,22 +235,22 @@ def test_nan_values():
lengths = np.array([10, np.nan, 30])
base_tip_dists = np.array([8, 16, np.nan])
np.testing.assert_array_equal(
get_grav_index(lengths, base_tip_dists), np.array([0.2, np.nan, np.nan])
get_curve_index(lengths, base_tip_dists), np.array([0.2, np.nan, np.nan])
)


def test_zero_lengths():
lengths = np.array([0, 20, 30])
base_tip_dists = np.array([0, 16, 24])
np.testing.assert_array_equal(
get_grav_index(lengths, base_tip_dists), np.array([np.nan, 0.2, 0.2])
get_curve_index(lengths, base_tip_dists), np.array([np.nan, 0.2, 0.2])
)


def test_invalid_scalar_values():
assert np.isnan(get_grav_index(np.nan, 8))
assert np.isnan(get_grav_index(10, np.nan))
assert np.isnan(get_grav_index(0, 8))
assert np.isnan(get_curve_index(np.nan, 8))
assert np.isnan(get_curve_index(10, np.nan))
assert np.isnan(get_curve_index(0, 8))


# tests for `get_root_lengths`
Expand Down
12 changes: 6 additions & 6 deletions tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder):

# Value range assertions for traits
assert (
rice_traits["grav_index"].fillna(0) >= 0
).all(), "grav_index in rice_traits contains negative values"
rice_traits["curve_index"].fillna(0) >= 0
).all(), "curve_index in rice_traits contains negative values"
assert (
all_traits["grav_index_median"] >= 0
).all(), "grav_index in all_traits contains negative values"
all_traits["curve_index_median"] >= 0
).all(), "curve_index in all_traits contains negative values"
assert (
all_traits["main_grav_indices_mean_median"] >= 0
).all(), "main_grav_indices_mean_median in all_traits contains negative values"
all_traits["main_curve_indices_mean_median"] >= 0
).all(), "main_curve_indices_mean_median in all_traits contains negative values"
assert (
(0 <= rice_traits["main_angles_proximal_p95"])
& (rice_traits["main_angles_proximal_p95"] <= 180)
Expand Down

0 comments on commit 52938e1

Please sign in to comment.