Skip to content

Commit

Permalink
Add masking
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Sep 16, 2024
1 parent 15ec473 commit 9361873
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
24 changes: 20 additions & 4 deletions python/akimbo_ip/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,24 @@ def func1(arr):
return dec(func1, match=match, outtype=outtype, inmode="awkward")


def bitwise_or(arr, other):
if isinstance(other, (str, int)):
other = ak.Array(np.array(list(ipaddress.ip_address("255.0.0.0").packed), dtype="uint8"))
out = (ak.without_parameters(arr) | ak.without_parameters(other)).layout
out.parameters["__array__"] = "bytestring"
out.content.parameters["__array__"] = "byte"
return out


def bitwise_and(arr, other):
if isinstance(other, (str, int)):
other = ak.Array(np.array(list(ipaddress.ip_address("255.0.0.0").packed), dtype="uint8"))
out = (ak.without_parameters(arr) | ak.without_parameters(other)).layout
out.parameters["__array__"] = "bytestring"
out.content.parameters["__array__"] = "byte"
return out


class IPAccessor:
def __init__(self, accessor) -> None:
self.accessor = accessor
Expand All @@ -248,15 +266,13 @@ def __eq__(self, other):
else:
raise ValueError

def bitwise_or(self, other):
raise NotImplemented("Will allow arr[ip] | mask")
bitwise_or = dec(bitwise_or, inmode="ak", match=match_ip)

__or__ = bitwise_or
def __ror__(self, value):
return self.__or__(value)

def bitwise_and(self, other):
raise NotImplemented("Will allow arr[ip] & mask")
bitwise_and = dec(bitwise_and, inmode="ak", match=match_ip)

__and__ = bitwise_and
def __rand__(self, value):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,13 @@ def test_inner_list_hosts():
# does not include gateway/broadcast
[b'\x01\x00\x00\x00', b'\x02\x00\x00\x00', b'\x03\x00\x00\x00', b'\x04\x00\x00\x00', b'\x05\x00\x00\x00', b'\x06\x00\x00\x00']
]


def test_masks():
s = pd.Series(["7.7.7.7", "8.8.8.8"]).ak.ip.parse_address4()
out1 = s.ak.ip | s.ak.array[:1]
assert out1.ak.ip.to_int_list().tolist() == [[7, 7, 7, 7], [15, 15, 15, 15]]

out2 = s.ak.ip | "255.0.0.0"
assert out2.ak.ip.to_int_list().tolist() == [[255, 7, 7, 7], [255, 8, 8, 8]]

0 comments on commit 9361873

Please sign in to comment.