Skip to content

Commit

Permalink
feat: port the performance improvements of fastbin back to pycasbin
Browse files Browse the repository at this point in the history
Signed-off-by: terry-xuan-gao <[email protected]>
  • Loading branch information
terry-xuan-gao committed Feb 5, 2023
1 parent 353f919 commit 2f41a72
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 8 deletions.
25 changes: 21 additions & 4 deletions casbin/core_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

import logging
import copy
from typing import Sequence

from casbin.effect import Effector, get_effector, effect_to_bool
from casbin.model import Model, FunctionMap
from casbin.model import Model, FastModel, FunctionMap, filter_policy
from casbin.persist import Adapter
from casbin.persist.adapters import FileAdapter
from casbin.rbac import default_role_manager
from casbin.util import generate_g_function, SimpleEval, util

# from casbin.model.policy import filter_policy


class EnforceContext:
"""
Expand Down Expand Up @@ -51,10 +54,13 @@ class CoreEnforcer:
auto_save = False
auto_build_role_links = False

def __init__(self, model=None, adapter=None):
_cache_key_order: Sequence[int] = None

def __init__(self, model=None, adapter=None, cache_key_order: Sequence[int] = None):
self.logger = logging.getLogger(__name__)
# if want to see more detail logs, change log level to info or debug
self.logger.setLevel(logging.WARNING)
CoreEnforcer._cache_key_order = cache_key_order
if isinstance(model, str):
if isinstance(adapter, str):
self.init_with_file(model, adapter)
Expand Down Expand Up @@ -113,7 +119,11 @@ def _initialize(self):
def new_model(path="", text=""):
"""creates a model."""

m = Model()
if CoreEnforcer._cache_key_order == None:
m = Model()
else:
m = FastModel(CoreEnforcer._cache_key_order)

if len(path) > 0:
m.load_model(path)
else:
Expand Down Expand Up @@ -320,7 +330,14 @@ def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
"""
result, _ = self.enforce_ex(*rvals)

if CoreEnforcer._cache_key_order == None:
result, _ = self.enforce_ex(*rvals)
else:
keys = [rvals[x] for x in self._cache_key_order]
with filter_policy(self.model.model["p"]["p"].policy, *keys):
result, _ = self.enforce_ex(*rvals)

return result

def enforce_ex(self, *rvals):
Expand Down
4 changes: 2 additions & 2 deletions casbin/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
# limitations under the License.

from .assertion import Assertion
from .model import Model
from .policy import Policy
from .model import Model, FastModel
from .policy import Policy, FilterablePolicy, filter_policy
from .function import FunctionMap
21 changes: 20 additions & 1 deletion casbin/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from . import Assertion
from casbin import util, config
from .policy import Policy
from .policy import Policy, FilterablePolicy
from typing import Any, Sequence

DEFAULT_DOMAIN = ""
DEFAULT_SEPARATOR = "::"
Expand Down Expand Up @@ -207,3 +208,21 @@ def write_string(sec):
s[-1] = s[-1].strip()

return "".join(s)


class FastModel(Model):
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
super().__init__()
self._cache_key_order = cache_key_order

def add_def(self, sec: str, key: str, value: Any) -> None:
super().add_def(sec, key, value)
if sec == "p" and key == "p":
self.model[sec][key].policy = FilterablePolicy(self._cache_key_order)

def clear_policy(self) -> None:
"""clears all current policy."""
super().clear_policy()
self.model["p"]["p"].policy = FilterablePolicy(self._cache_key_order)
88 changes: 88 additions & 0 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import logging

from contextlib import contextmanager
from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast

DEFAULT_SEP = ","


Expand Down Expand Up @@ -289,3 +292,88 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index):
values.append(value)

return values


def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]:
if keys[0] in cache:
if len(keys) > 1:
return in_cache(cache[keys[-0]], keys[1:])
return cast(Set[Sequence[str]], cache[keys[0]])
else:
return None


class FilterablePolicy(Container[Sequence[str]]):
_cache: Dict[str, Any]
_current_filter: Optional[Set[Sequence[str]]]
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
self._cache = {}
self._current_filter = None
self._cache_key_order = cache_key_order

def __iter__(self) -> Iterator[Sequence[str]]:
yield from self.__get_policy()

def __len__(self) -> int:
return len(list(self.__get_policy()))

def __contains__(self, item: object) -> bool:
if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item):
return False
keys = [item[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return False
return tuple(item) in exists

def __getitem__(self, item: int) -> Sequence[str]:
for i, entry in enumerate(self):
if i == item:
return entry
raise KeyError("No such value exists")

def append(self, item: Sequence[str]) -> None:
cache = self._cache
keys = [item[x] for x in self._cache_key_order]

for key in keys[:-1]:
if key not in cache:
cache[key] = dict()
cache = cache[key]
if keys[-1] not in cache:
cache[keys[-1]] = set()

cache[keys[-1]].add(tuple(item))

def remove(self, policy: Sequence[str]) -> bool:
keys = [policy[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return True

exists.remove(tuple(policy))
return True

def __get_policy(self) -> Iterable[Sequence[str]]:
if self._current_filter is not None:
return (list(x) for x in self._current_filter)
else:
return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1)

def apply_filter(self, *keys: str) -> None:
value = in_cache(self._cache, keys)
self._current_filter = value or set()

def clear_filter(self) -> None:
self._current_filter = None


@contextmanager
def filter_policy(policy: FilterablePolicy, *keys: str) -> Iterator[None]:
try:
policy.apply_filter(*keys)
yield
finally:
policy.clear_filter()
85 changes: 85 additions & 0 deletions tests/model/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from casbin.model import Model
from tests.test_enforcer import get_examples

from casbin.model.policy import FilterablePolicy, filter_policy

class TestPolicy(TestCase):
def test_get_policy(self):
Expand Down Expand Up @@ -172,3 +173,87 @@ def test_remove_filtered_policy(self):

res = m.remove_filtered_policy("p", "p", 1, "domain1", "data1")
self.assertFalse(res)


class TestFilterablePolicy:
def test_able_to_add_rules(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_does_not_add_duplicates(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_can_remove_rules(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.remove(["sub", "obj", "read"])

assert list(policy) == []

def test_returns_lengtt(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert len(policy) == 1

def test_supports_in_keyword(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert ["sub", "obj", "read"] in policy

def test_supports_filters(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")

assert list(policy) == [["sub", "obj2", "read2"]]

def test_clears_filters(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")
policy.clear_filter()

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]


class TestContextManager:
def test_filters_policy(self) -> None:
policy = FilterablePolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

with filter_policy(policy, "read2", "obj2"):
assert list(policy) == [["sub", "obj2", "read2"]]

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]
65 changes: 64 additions & 1 deletion tests/test_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import time
from unittest import TestCase
from typing import Sequence

import casbin

Expand All @@ -31,10 +32,11 @@ def __init__(self, name, age):


class TestCaseBase(TestCase):
def get_enforcer(self, model=None, adapter=None):
def get_enforcer(self, model=None, adapter=None, cache_key_order: Sequence[int] = None):
return casbin.Enforcer(
model,
adapter,
cache_key_order,
)


Expand Down Expand Up @@ -433,3 +435,64 @@ def test_auto_loading_policy(self):
# thread needs a moment to exit
time.sleep(10 / 1000)
self.assertFalse(e.is_auto_loading_running())


class TestFastEnforcer(TestCaseBase):
def test_creates_proper_policy(self) -> None:
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
[2, 1],
)

assert isinstance(e.model.model["p"]["p"].policy, casbin.model.FilterablePolicy)

def test_initializes_model(self) -> None:
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
[2, 1],
)

assert list(e.model.model["p"]["p"].policy) == [
["alice", "data1", "read"],
["bob", "data2", "write"],
]

def test_able_to_clear_policy(self) -> None:
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
[2, 1],
)

e.clear_policy()

assert isinstance(e.model.model["p"]["p"].policy, casbin.FilterablePolicy)
assert list(e.model.model["p"]["p"].policy) == []

def test_able_to_enforce_rule(self) -> None:
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
[2, 1],
)

assert e.enforce("alice", "data1", "read")
assert not e.enforce("alice2", "data1", "read")

def test_speed_of_fast_enforcer(self) -> None:
e = self.get_enforcer(
get_examples("basic_model.conf"),
get_examples("basic_policy.csv"),
[2, 1],
)

for x in range(100):
for y in range(1000):
e.add_policy(f"/user{x}", f"/obj{y}", "read")

s = time.time()
# this is the absolute worst case last entry and should require iterating 10M rows and be very slow
a = e.enforce("/user99", "/obj999", "read")
print(a, (time.time() - s) * 1000)

0 comments on commit 2f41a72

Please sign in to comment.