Skip to content

Commit

Permalink
Be more careful with np copies
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeyers314 committed Jun 6, 2024
1 parent 07edb51 commit 2a0af07
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion batoid/coordSys.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def toLocal(self, v):
vv : ndarray of float, shape (n, 3)
Vector in local coordinates.
"""
v = np.array(v, dtype=float)
v = np.array(v, dtype=float, copy=True)
v -= self.origin
return (self.rot.T@v.T).T

Expand Down
4 changes: 2 additions & 2 deletions batoid/coordTransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def applyForwardArray(self, x, y, z):
Unlike applyForward, this method does not transform in-place, but
returns a newly created ndarray.
"""
r = np.array([x, y, z], dtype=float).T
r = np.array([x, y, z], dtype=float, copy=True).T
r -= self.dr
return self.drot.T@r.T

Expand All @@ -107,7 +107,7 @@ def applyReverseArray(self, x, y, z):
Unlike applyReverse, this method does not transform in-place, but
returns a newly created ndarray.
"""
r = np.array([x, y, z], dtype=float)
r = np.array([x, y, z], dtype=float, copy=True)
r = (self.drot@r).T
r += self.dr
return r.T
Expand Down
6 changes: 4 additions & 2 deletions batoid/rayVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _reshape_arrays(arrays, shape, dtype=float):
array = arrays[i]
if not hasattr(array, 'shape') or array.shape != shape:
arrays[i] = np.array(np.broadcast_to(array, shape))
arrays[i] = np.ascontiguousarray(arrays[i], dtype=dtype)
arrays[i] = np.array(arrays[i], dtype=dtype, copy=True, order='C')
return arrays


Expand Down Expand Up @@ -50,6 +50,8 @@ def __init__(
shape = np.broadcast(
x, y, z, vx, vy, vz, t, wavelength, flux, vignetted, failed
).shape
if shape == ():
shape = (1,)
x, y, z, vx, vy, vz, t, wavelength, flux = _reshape_arrays(
[x, y, z, vx, vy, vz, t, wavelength, flux],
shape
Expand Down Expand Up @@ -819,7 +821,7 @@ def _finish(
if isinstance(flux, Real):
flux = np.full(len(x), float(flux))
if source is None:
vv = np.array(dirCos, dtype=float)
vv = np.array(dirCos, dtype=float, copy=True)
vv /= n*np.sqrt(np.dot(vv, vv))
zhat = -n*vv
xhat = np.cross(np.array([1.0, 0.0, 0.0]), zhat)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_RayVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ def test_properties():
)

rv = batoid.RayVector(x, y, z, vx, vy, vz, t, w, fx, vig, fa, cs)
assert x is not rv.x
assert y is not rv.y
assert z is not rv.z
assert vx is not rv.vx
assert vy is not rv.vy
assert vz is not rv.vz
assert t is not rv.t
assert w is not rv.wavelength
assert fx is not rv.flux
assert vig is not rv.vignetted
assert fa is not rv.failed

np.testing.assert_array_equal(rv.x, x)
np.testing.assert_array_equal(rv.y, y)
Expand Down Expand Up @@ -1101,4 +1112,4 @@ def test_fromFieldAngles():
test_factory_optic()
test_getitem()
test_fromStop()
test_fromFieldAngles()
test_fromFieldAngles()

0 comments on commit 2a0af07

Please sign in to comment.