Skip to content

Commit

Permalink
explicitly cast to intp
Browse files Browse the repository at this point in the history
  • Loading branch information
waltsims committed Nov 19, 2024
1 parent 2e138d1 commit 4437fe2
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion kwave/utils/matlab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Tuple, Union, Optional, List

from beartype import beartype as typechecker
import numpy as np


Expand Down Expand Up @@ -52,6 +52,7 @@ def matlab_assign(matrix: np.ndarray, indices: Union[int, np.ndarray], values: U
return matrix.reshape(original_shape, order="F")


@typechecker
def matlab_find(arr: Union[List[int], np.ndarray], val: int = 0, mode: str = "neq") -> np.ndarray:
"""
Finds the indices of elements in an array that satisfy a given condition.
Expand All @@ -75,6 +76,7 @@ def matlab_find(arr: Union[List[int], np.ndarray], val: int = 0, mode: str = "ne
return np.expand_dims(arr, -1) # compatibility, n => [n, 1]


@typechecker
def matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -> np.ndarray:
"""
Applies a mask to an array and returns the masked elements.
Expand All @@ -89,6 +91,8 @@ def matlab_mask(arr: np.ndarray, mask: np.ndarray, diff: Optional[int] = None) -
"""

# mask is np.intp type
mask = mask.astype(np.intp)
if diff is None:
return np.expand_dims(arr.ravel(order="F")[mask.ravel(order="F")], axis=-1) # compatibility, n => [n, 1]
else:
Expand Down

0 comments on commit 4437fe2

Please sign in to comment.