Skip to content

Commit

Permalink
Implement numpy.clip (#1839)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Dec 30, 2024
1 parent 1b25eb7 commit b36142b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
16 changes: 14 additions & 2 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _numpy_full(pv: ProgramVisitor,
"""
if isinstance(shape, Number) or symbolic.issymbolic(shape):
shape = [shape]

is_data = False
if isinstance(fill_value, (Number, np.bool_)):
vtype = dtypes.dtype_to_typeclass(type(fill_value))
Expand Down Expand Up @@ -587,7 +587,7 @@ def _arange(pv: ProgramVisitor,

if any(not isinstance(s, Number) for s in [start, stop, step]):
if step == 1: # Common case where ceiling is not necessary
shape = (stop - start,)
shape = (stop - start, )
else:
shape = (symbolic.int_ceil(stop - start, step), )
else:
Expand Down Expand Up @@ -1064,6 +1064,17 @@ def _min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
identity=dtypes.max_value(sdfg.arrays[a].dtype))


@oprepo.replaces('numpy.clip')
def _clip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a, a_min=None, a_max=None, **kwargs):
if a_min is None and a_max is None:
raise ValueError("clip() requires at least one of `a_min` or `a_max`")
if a_min is None:
return implement_ufunc(pv, None, sdfg, state, 'minimum', [a, a_max], kwargs)[0]
if a_max is None:
return implement_ufunc(pv, None, sdfg, state, 'maximum', [a, a_min], kwargs)[0]
return implement_ufunc(pv, None, sdfg, state, 'clip', [a, a_min, a_max], kwargs)[0]


def _minmax2(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, b: str, ismin=True):
""" Implements the min or max function with 2 scalar arguments. """

Expand Down Expand Up @@ -5321,6 +5332,7 @@ def _vsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str,
############################################################################################################
# Fast Fourier Transform numpy package (numpy.fft)


def _real_to_complex(real_type: dace.typeclass):
if real_type == dace.float32:
return dace.complex64
Expand Down
12 changes: 12 additions & 0 deletions tests/numpy/ufunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,16 @@ def test_ufunc_clip(A: dace.float32[10]):
return np.clip(A, 0.2, 0.5)


@compare_numpy_output()
def test_ufunc_clip_min(A: dace.float32[10]):
return np.clip(A, 0.2, None)


@compare_numpy_output()
def test_ufunc_clip_max(A: dace.float32[10]):
return np.clip(A, None, a_max=0.5)


if __name__ == "__main__":
test_ufunc_add_ff()
test_ufunc_subtract_ff()
Expand Down Expand Up @@ -1523,3 +1533,5 @@ def test_ufunc_clip(A: dace.float32[10]):
test_ufunc_trunc_f()
test_ufunc_trunc_u()
test_ufunc_clip()
test_ufunc_clip_min()
test_ufunc_clip_max()

0 comments on commit b36142b

Please sign in to comment.