Skip to content

Commit

Permalink
Add more complex v4 methods and converters
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Sep 16, 2024
1 parent 86aa8da commit 03e65ad
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 24 deletions.
148 changes: 141 additions & 7 deletions python/akimbo_ip/accessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import functools
from types import UnionType

import awkward as ak
import numpy as np
Expand All @@ -22,6 +23,11 @@ def match_ip6(arr):
return arr.is_regular and arr.size == 16 and arr.content.is_leaf and arr.content.dtype.itemsize == 1


def match_ip(arr):
"""matches either v4 or v6 IPs"""
return match_ip4(arr) or match_ip6(arr)


def match_prefix(arr):
"""A network prefix is always one byte"""
return arr.is_leaf and arr.dtype.itemsize == 1
Expand All @@ -38,12 +44,21 @@ def match_net4(arr, address="address", prefix="prefix"):


def match_net6(arr, address="address", prefix="prefix"):
"""Matches a record with IP6 field and prefix field (u8)"""
return (
arr.is_record
and {address, prefix}.issubset(arr.fields)
and match_ip6(arr[address])
and match_prefix(arr[prefix])
)


def match_list_net4(arr, address="address", prefix="prefix"):
"""Matches lists of ip4 network records"""
if arr.is_list:
cont = arr.content.content if arr.content.is_option else arr.content
return match_net4(cont)
return False


def match_stringlike(arr):
Expand Down Expand Up @@ -101,18 +116,103 @@ def contains4(nets, other, address="address", prefix="prefix"):


def hosts4(nets, address="address", prefix="prefix"):
arr = nets[address]
if arr.is_leaf:
arr = arr.data.astype("uint32")
else:
# fixed bytestring or 4 * uint8 regular
arr = arr.content.data.view("uint32")
arr, = to_ip4(nets[address])
ips, offsets = lib.hosts4(arr, nets[prefix].data.astype("uint8"))
return ak.contents.ListOffsetArray(
ak.index.Index64(offsets),
utils.u8_to_ip4(ips)
)

def network4(nets, address="address", prefix="prefix"):
arr, = to_ip4(nets[address])
out = lib.network4(arr, nets[prefix].data.astype("uint8"))
return utils.u8_to_ip4(out)


def broadcast4(nets, address="address", prefix="prefix"):
arr, = to_ip4(nets[address])
out = lib.broadcast4(arr, nets[prefix].data.astype("uint8"))
return utils.u8_to_ip4(out)


def hostmask4(nets, address="address", prefix="prefix"):
out = lib.hostmask4(nets[prefix].data.astype("uint8"))
return utils.u8_to_ip4(out)


def netmask4(nets, address="address", prefix="prefix"):
out = lib.netmask4(nets[prefix].data.astype("uint8"))
return utils.u8_to_ip4(out)


def trunc4(nets, address="address", prefix="prefix"):
arr, = to_ip4(nets[address])
out = lib.trunc4(arr, nets[prefix].data.astype("uint8"))
return ak.contents.RecordArray(
[utils.u8_to_ip4(out), nets[prefix]],
fields=[address, prefix]
)


def supernet4(nets, address="address", prefix="prefix"):
arr, = to_ip4(nets[address])
out = lib.supernet4(arr, nets[prefix].data.astype("uint8"))
return ak.contents.RecordArray(
[utils.u8_to_ip4(out), ak.contents.NumpyArray(nets[prefix].data - 1)],
fields=[address, prefix]
)


def subnets4(nets, new_prefix, address="address", prefix="prefix"):
arr, = to_ip4(nets[address])
out, offsets = lib.subnets4(arr, nets[prefix].data.astype("uint8"), new_prefix)
addr = utils.u8_to_ip4(out)
return ak.contents.ListOffsetArray(
ak.index.Index64(offsets),
ak.contents.RecordArray(
[addr,
ak.contents.NumpyArray(np.full((len(addr), ), new_prefix, dtype="uint8"))],
fields=[address, prefix]
),
)


def aggregate4(net_lists, address="address", prefix="prefix"):
offsets = net_lists.offsets.data.astype("uint64")
cont = net_lists.content.content if net_lists.content.is_option else net_lists.content
arr, = to_ip4(cont[address])
out_addr, out_pref, counts = lib.aggregate4(arr, offsets, cont[prefix].data)
# TODO: reassemble optional if input net_lists was list[optional[networks]]
return ak.contents.ListOffsetArray(
ak.index.Index64(counts),
ak.contents.RecordArray(
[utils.u8_to_ip4(out_addr), ak.contents.NumpyArray(out_pref)],
fields=[address, prefix]
)
)


def to_int_list(arr):
if (arr.is_leaf and arr.dtype.itemsize == 4):
out = ak.contents.RegularArray(
ak.contents.NumpyArray(arr.data.view('uint8')),
size=4
)
else:
out = ak.copy(arr)
out.parameters.pop('__array__')
return out


def to_bytestring(arr):
if (arr.is_leaf and arr.dtype.itemsize == 4):
out = utils.u8_to_ip4(arr)
else:
out = ak.copy(arr)
out.parameters['__array__'] = "bytestring"
out.content.parameters["__array__"] = "byte"
return out


def to_ip4(arr):
if arr.is_leaf:
Expand All @@ -121,7 +221,6 @@ def to_ip4(arr):
# bytestring or 4 * uint8 regular
return arr.content.data.view("uint32"),


def to_ip6(arr):
# always pass as bytes, and assume length is mod 16 in rust
return arr.content.data.view("uint8"),
Expand All @@ -139,6 +238,33 @@ class IPAccessor:
def __init__(self, accessor) -> None:
self.accessor = accessor

# TODO: bitwise_or and bitwise_and methods and their overrides
def __eq__(self, other):
arr = self.accessor.array
if isinstance(other, (str, int)):
arr2 = ak.Array([ipaddress.ip_address(other).packed])

return self.accessor.to_output(arr == arr2)
else:
raise ValueError

def bitwise_or(self, other):
raise NotImplemented("Will allow arr[ip] | mask")

__or__ = bitwise_or
def __ror__(self, value):
return self.__or__(value)

def bitwise_and(self, other):
raise NotImplemented("Will allow arr[ip] & mask")

__and__ = bitwise_and
def __rand__(self, value):
return self.__and__(value)

to_int_list = dec(to_int_list, inmode="ak", match=match_ip)
to_bytestring = dec(to_bytestring, inmode="ak", match=match_ip)

is_unspecified4 = dec_ip(lib.is_unspecified4)
is_broadcast4 = dec_ip(lib.is_broadcast4)
is_global4 = dec_ip(lib.is_global4)
Expand All @@ -155,6 +281,14 @@ def __init__(self, accessor) -> None:

parse_address4 = dec(parse_address4, inmode="ak", match=match_stringlike)
parse_net4 = dec(parse_net4, inmode="ak", match=match_stringlike)
network4 = dec(network4, inmode="ak", match=match_net4)
hostmask4 = dec(hostmask4, inmode="ak", match=match_net4)
netmask4 = dec(netmask4, inmode="ak", match=match_net4)
broadcast4 = dec(broadcast4, inmode="ak", match=match_net4)
trunc4 = dec(trunc4, inmode="ak", match=match_net4)
supernet4 = dec(supernet4, inmode="ak", match=match_net4)
subnets4 = dec(subnets4, inmode="ak", match=match_net4)
aggregate4 = dec(aggregate4, inmode="ak", match=match_list_net4)

contains4 = dec(contains4, inmode="ak", match=match_net4)

Expand Down
Loading

0 comments on commit 03e65ad

Please sign in to comment.