From a7ed0bb9f4992b6d20a44792ff9aee0fc9f53235 Mon Sep 17 00:00:00 2001 From: terry-xuan-gao Date: Sun, 5 Feb 2023 20:01:54 +0800 Subject: [PATCH] add: policy_fast.py, model_fast.py, and their tests Signed-off-by: terry-xuan-gao --- casbin/core_enforcer.py | 4 +- casbin/model/__init__.py | 6 +- casbin/model/model.py | 22 +------ casbin/model/model_fast.py | 35 +++++++++++ casbin/model/policy.py | 88 ---------------------------- casbin/model/policy_fast.py | 101 ++++++++++++++++++++++++++++++++ tests/__init__.py | 1 + tests/model/__init__.py | 1 + tests/model/test_policy.py | 86 +-------------------------- tests/model/test_policy_fast.py | 101 ++++++++++++++++++++++++++++++++ tests/test_enforcer.py | 61 ------------------- tests/test_enforcer_fast.py | 79 +++++++++++++++++++++++++ 12 files changed, 327 insertions(+), 258 deletions(-) create mode 100644 casbin/model/model_fast.py create mode 100644 casbin/model/policy_fast.py create mode 100644 tests/model/test_policy_fast.py create mode 100644 tests/test_enforcer_fast.py diff --git a/casbin/core_enforcer.py b/casbin/core_enforcer.py index 033243ca..13f9efae 100644 --- a/casbin/core_enforcer.py +++ b/casbin/core_enforcer.py @@ -17,7 +17,7 @@ from typing import Sequence from casbin.effect import Effector, get_effector, effect_to_bool -from casbin.model import Model, FastModel, FunctionMap, filter_policy +from casbin.model import Model, FastModel, FunctionMap, fast_policy_filter from casbin.persist import Adapter from casbin.persist.adapters import FileAdapter from casbin.rbac import default_role_manager @@ -333,7 +333,7 @@ def enforce(self, *rvals): result, _ = self.enforce_ex(*rvals) else: keys = [rvals[x] for x in self._cache_key_order] - with filter_policy(self.model.model["p"]["p"].policy, *keys): + with fast_policy_filter(self.model.model["p"]["p"].policy, *keys): result, _ = self.enforce_ex(*rvals) return result diff --git a/casbin/model/__init__.py b/casbin/model/__init__.py index 0606f50e..7667b1c0 100644 --- a/casbin/model/__init__.py +++ b/casbin/model/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. from .assertion import Assertion -from .model import Model, FastModel -from .policy import Policy, FilterablePolicy, filter_policy +from .model import Model +from .model_fast import FastModel +from .policy import Policy +from .policy_fast import FastPolicy, fast_policy_filter from .function import FunctionMap diff --git a/casbin/model/model.py b/casbin/model/model.py index 1ce35e3c..3162b433 100644 --- a/casbin/model/model.py +++ b/casbin/model/model.py @@ -14,8 +14,8 @@ from . import Assertion from casbin import util, config -from .policy import Policy, FilterablePolicy -from typing import Any, Sequence +from .policy import Policy + DEFAULT_DOMAIN = "" DEFAULT_SEPARATOR = "::" @@ -208,21 +208,3 @@ def write_string(sec): s[-1] = s[-1].strip() return "".join(s) - - -class FastModel(Model): - _cache_key_order: Sequence[int] - - def __init__(self, cache_key_order: Sequence[int]) -> None: - super().__init__() - self._cache_key_order = cache_key_order - - def add_def(self, sec: str, key: str, value: Any) -> None: - super().add_def(sec, key, value) - if sec == "p" and key == "p": - self.model[sec][key].policy = FilterablePolicy(self._cache_key_order) - - def clear_policy(self) -> None: - """clears all current policy.""" - super().clear_policy() - self.model["p"]["p"].policy = FilterablePolicy(self._cache_key_order) diff --git a/casbin/model/model_fast.py b/casbin/model/model_fast.py new file mode 100644 index 00000000..7f8c9d5d --- /dev/null +++ b/casbin/model/model_fast.py @@ -0,0 +1,35 @@ +# Copyright 2021 The casbin Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .policy_fast import FastPolicy +from typing import Any, Sequence +from .model import Model + + +class FastModel(Model): + _cache_key_order: Sequence[int] + + def __init__(self, cache_key_order: Sequence[int]) -> None: + super().__init__() + self._cache_key_order = cache_key_order + + def add_def(self, sec: str, key: str, value: Any) -> None: + super().add_def(sec, key, value) + if sec == "p" and key == "p": + self.model[sec][key].policy = FastPolicy(self._cache_key_order) + + def clear_policy(self) -> None: + """clears all current policy.""" + super().clear_policy() + self.model["p"]["p"].policy = FastPolicy(self._cache_key_order) diff --git a/casbin/model/policy.py b/casbin/model/policy.py index 629b27b7..27a6a19d 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -14,9 +14,6 @@ import logging -from contextlib import contextmanager -from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast - DEFAULT_SEP = "," @@ -292,88 +289,3 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index): values.append(value) return values - - -def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]: - if keys[0] in cache: - if len(keys) > 1: - return in_cache(cache[keys[-0]], keys[1:]) - return cast(Set[Sequence[str]], cache[keys[0]]) - else: - return None - - -class FilterablePolicy(Container[Sequence[str]]): - _cache: Dict[str, Any] - _current_filter: Optional[Set[Sequence[str]]] - _cache_key_order: Sequence[int] - - def __init__(self, cache_key_order: Sequence[int]) -> None: - self._cache = {} - self._current_filter = None - self._cache_key_order = cache_key_order - - def __iter__(self) -> Iterator[Sequence[str]]: - yield from self.__get_policy() - - def __len__(self) -> int: - return len(list(self.__get_policy())) - - def __contains__(self, item: object) -> bool: - if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item): - return False - keys = [item[x] for x in self._cache_key_order] - exists = in_cache(self._cache, keys) - if not exists: - return False - return tuple(item) in exists - - def __getitem__(self, item: int) -> Sequence[str]: - for i, entry in enumerate(self): - if i == item: - return entry - raise KeyError("No such value exists") - - def append(self, item: Sequence[str]) -> None: - cache = self._cache - keys = [item[x] for x in self._cache_key_order] - - for key in keys[:-1]: - if key not in cache: - cache[key] = dict() - cache = cache[key] - if keys[-1] not in cache: - cache[keys[-1]] = set() - - cache[keys[-1]].add(tuple(item)) - - def remove(self, policy: Sequence[str]) -> bool: - keys = [policy[x] for x in self._cache_key_order] - exists = in_cache(self._cache, keys) - if not exists: - return True - - exists.remove(tuple(policy)) - return True - - def __get_policy(self) -> Iterable[Sequence[str]]: - if self._current_filter is not None: - return (list(x) for x in self._current_filter) - else: - return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1) - - def apply_filter(self, *keys: str) -> None: - value = in_cache(self._cache, keys) - self._current_filter = value or set() - - def clear_filter(self) -> None: - self._current_filter = None - - -@contextmanager -def filter_policy(policy: FilterablePolicy, *keys: str) -> Iterator[None]: - try: - policy.apply_filter(*keys) - yield - finally: - policy.clear_filter() diff --git a/casbin/model/policy_fast.py b/casbin/model/policy_fast.py new file mode 100644 index 00000000..bdd918ca --- /dev/null +++ b/casbin/model/policy_fast.py @@ -0,0 +1,101 @@ +# Copyright 2021 The casbin Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast + + +def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]: + if keys[0] in cache: + if len(keys) > 1: + return in_cache(cache[keys[-0]], keys[1:]) + return cast(Set[Sequence[str]], cache[keys[0]]) + else: + return None + + +class FastPolicy(Container[Sequence[str]]): + _cache: Dict[str, Any] + _current_filter: Optional[Set[Sequence[str]]] + _cache_key_order: Sequence[int] + + def __init__(self, cache_key_order: Sequence[int]) -> None: + self._cache = {} + self._current_filter = None + self._cache_key_order = cache_key_order + + def __iter__(self) -> Iterator[Sequence[str]]: + yield from self.__get_policy() + + def __len__(self) -> int: + return len(list(self.__get_policy())) + + def __contains__(self, item: object) -> bool: + if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item): + return False + keys = [item[x] for x in self._cache_key_order] + exists = in_cache(self._cache, keys) + if not exists: + return False + return tuple(item) in exists + + def __getitem__(self, item: int) -> Sequence[str]: + for i, entry in enumerate(self): + if i == item: + return entry + raise KeyError("No such value exists") + + def append(self, item: Sequence[str]) -> None: + cache = self._cache + keys = [item[x] for x in self._cache_key_order] + + for key in keys[:-1]: + if key not in cache: + cache[key] = dict() + cache = cache[key] + if keys[-1] not in cache: + cache[keys[-1]] = set() + + cache[keys[-1]].add(tuple(item)) + + def remove(self, policy: Sequence[str]) -> bool: + keys = [policy[x] for x in self._cache_key_order] + exists = in_cache(self._cache, keys) + if not exists: + return True + + exists.remove(tuple(policy)) + return True + + def __get_policy(self) -> Iterable[Sequence[str]]: + if self._current_filter is not None: + return (list(x) for x in self._current_filter) + else: + return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1) + + def apply_filter(self, *keys: str) -> None: + value = in_cache(self._cache, keys) + self._current_filter = value or set() + + def clear_filter(self) -> None: + self._current_filter = None + + +@contextmanager +def fast_policy_filter(policy: FastPolicy, *keys: str) -> Iterator[None]: + try: + policy.apply_filter(*keys) + yield + finally: + policy.clear_filter() diff --git a/tests/__init__.py b/tests/__init__.py index e4493206..e11fc17a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,6 +18,7 @@ from .test_frontend import TestFrontend from .test_management_api import TestManagementApi, TestManagementApiSynced from .test_rbac_api import TestRbacApi, TestRbacApiSynced +from .test_enforcer_fast import TestFastEnforcer from . import benchmarks from . import config from . import model diff --git a/tests/model/__init__.py b/tests/model/__init__.py index 528fed7a..900ee454 100644 --- a/tests/model/__init__.py +++ b/tests/model/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .test_policy import TestPolicy +from .test_policy_fast import TestContextManager, TestFastPolicy diff --git a/tests/model/test_policy.py b/tests/model/test_policy.py index 9ac8d5d0..855eb841 100644 --- a/tests/model/test_policy.py +++ b/tests/model/test_policy.py @@ -14,7 +14,7 @@ from unittest import TestCase -from casbin.model import Model, FilterablePolicy, filter_policy +from casbin.model import Model from tests.test_enforcer import get_examples @@ -172,87 +172,3 @@ def test_remove_filtered_policy(self): res = m.remove_filtered_policy("p", "p", 1, "domain1", "data1") self.assertFalse(res) - - -class TestFilterablePolicy: - def test_able_to_add_rules(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - - assert list(policy) == [["sub", "obj", "read"]] - - def test_does_not_add_duplicates(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - policy.append(["sub", "obj", "read"]) - - assert list(policy) == [["sub", "obj", "read"]] - - def test_can_remove_rules(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - policy.remove(["sub", "obj", "read"]) - - assert list(policy) == [] - - def test_returns_lengtt(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - - assert len(policy) == 1 - - def test_supports_in_keyword(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - - assert ["sub", "obj", "read"] in policy - - def test_supports_filters(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - policy.append(["sub", "obj", "read2"]) - policy.append(["sub", "obj2", "read2"]) - - policy.apply_filter("read2", "obj2") - - assert list(policy) == [["sub", "obj2", "read2"]] - - def test_clears_filters(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - policy.append(["sub", "obj", "read2"]) - policy.append(["sub", "obj2", "read2"]) - - policy.apply_filter("read2", "obj2") - policy.clear_filter() - - assert list(policy) == [ - ["sub", "obj", "read"], - ["sub", "obj", "read2"], - ["sub", "obj2", "read2"], - ] - - -class TestContextManager: - def test_filters_policy(self) -> None: - policy = FilterablePolicy([2, 1]) - - policy.append(["sub", "obj", "read"]) - policy.append(["sub", "obj", "read2"]) - policy.append(["sub", "obj2", "read2"]) - - with filter_policy(policy, "read2", "obj2"): - assert list(policy) == [["sub", "obj2", "read2"]] - - assert list(policy) == [ - ["sub", "obj", "read"], - ["sub", "obj", "read2"], - ["sub", "obj2", "read2"], - ] diff --git a/tests/model/test_policy_fast.py b/tests/model/test_policy_fast.py new file mode 100644 index 00000000..c876f9f9 --- /dev/null +++ b/tests/model/test_policy_fast.py @@ -0,0 +1,101 @@ +# Copyright 2021 The casbin Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import TestCase + +from casbin.model import FastPolicy, fast_policy_filter + + +class TestFastPolicy(TestCase): + def test_able_to_add_rules(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + + assert list(policy) == [["sub", "obj", "read"]] + + def test_does_not_add_duplicates(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + policy.append(["sub", "obj", "read"]) + + assert list(policy) == [["sub", "obj", "read"]] + + def test_can_remove_rules(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + policy.remove(["sub", "obj", "read"]) + + assert list(policy) == [] + + def test_returns_lengtt(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + + assert len(policy) == 1 + + def test_supports_in_keyword(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + + assert ["sub", "obj", "read"] in policy + + def test_supports_filters(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + policy.append(["sub", "obj", "read2"]) + policy.append(["sub", "obj2", "read2"]) + + policy.apply_filter("read2", "obj2") + + assert list(policy) == [["sub", "obj2", "read2"]] + + def test_clears_filters(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + policy.append(["sub", "obj", "read2"]) + policy.append(["sub", "obj2", "read2"]) + + policy.apply_filter("read2", "obj2") + policy.clear_filter() + + assert list(policy) == [ + ["sub", "obj", "read"], + ["sub", "obj", "read2"], + ["sub", "obj2", "read2"], + ] + + +class TestContextManager: + def test_fast_policy_filter(self) -> None: + policy = FastPolicy([2, 1]) + + policy.append(["sub", "obj", "read"]) + policy.append(["sub", "obj", "read2"]) + policy.append(["sub", "obj2", "read2"]) + + with fast_policy_filter(policy, "read2", "obj2"): + assert list(policy) == [["sub", "obj2", "read2"]] + + assert list(policy) == [ + ["sub", "obj", "read"], + ["sub", "obj", "read2"], + ["sub", "obj2", "read2"], + ] diff --git a/tests/test_enforcer.py b/tests/test_enforcer.py index 4ca7c6a2..1dc83100 100644 --- a/tests/test_enforcer.py +++ b/tests/test_enforcer.py @@ -435,64 +435,3 @@ def test_auto_loading_policy(self): # thread needs a moment to exit time.sleep(10 / 1000) self.assertFalse(e.is_auto_loading_running()) - - -class TestFastEnforcer(TestCaseBase): - def test_creates_proper_policy(self) -> None: - e = self.get_enforcer( - get_examples("basic_model.conf"), - get_examples("basic_policy.csv"), - [2, 1], - ) - - assert isinstance(e.model.model["p"]["p"].policy, casbin.model.FilterablePolicy) - - def test_initializes_model(self) -> None: - e = self.get_enforcer( - get_examples("basic_model.conf"), - get_examples("basic_policy.csv"), - [2, 1], - ) - - assert list(e.model.model["p"]["p"].policy) == [ - ["alice", "data1", "read"], - ["bob", "data2", "write"], - ] - - def test_able_to_clear_policy(self) -> None: - e = self.get_enforcer( - get_examples("basic_model.conf"), - get_examples("basic_policy.csv"), - [2, 1], - ) - - e.clear_policy() - - assert isinstance(e.model.model["p"]["p"].policy, casbin.FilterablePolicy) - assert list(e.model.model["p"]["p"].policy) == [] - - def test_able_to_enforce_rule(self) -> None: - e = self.get_enforcer( - get_examples("basic_model.conf"), - get_examples("basic_policy.csv"), - [2, 1], - ) - - assert e.enforce("alice", "data1", "read") - assert not e.enforce("alice2", "data1", "read") - - def test_speed_of_fast_enforcer(self) -> None: - e = self.get_enforcer( - get_examples("basic_model.conf"), - get_examples("basic_policy.csv"), - [2, 1], - ) - - for x in range(100): - for y in range(1000): - e.add_policy(f"/user{x}", f"/obj{y}", "read") - - s = time.time() - # this is the absolute worst case last entry and should require iterating 10M rows and be very slow - a = e.enforce("/user99", "/obj999", "read") - print(a, (time.time() - s) * 1000) diff --git a/tests/test_enforcer_fast.py b/tests/test_enforcer_fast.py new file mode 100644 index 00000000..1a7eae4c --- /dev/null +++ b/tests/test_enforcer_fast.py @@ -0,0 +1,79 @@ +# Copyright 2021 The casbin Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +from unittest import TestCase +from typing import Sequence + +import casbin + + +def get_examples(path): + examples_path = os.path.split(os.path.realpath(__file__))[0] + "/../examples/" + return os.path.abspath(examples_path + path) + + +class TestCaseBase(TestCase): + def get_enforcer(self, model=None, adapter=None, cache_key_order: Sequence[int] = None): + return casbin.Enforcer( + model, + adapter, + cache_key_order, + ) + + +class TestFastEnforcer(TestCaseBase): + def test_creates_proper_policy(self) -> None: + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + [2, 1], + ) + + assert isinstance(e.model.model["p"]["p"].policy, casbin.FastPolicy) + + def test_initializes_model(self) -> None: + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + [2, 1], + ) + + assert list(e.model.model["p"]["p"].policy) == [ + ["alice", "data1", "read"], + ["bob", "data2", "write"], + ] + + def test_able_to_clear_policy(self) -> None: + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + [2, 1], + ) + + e.clear_policy() + + assert isinstance(e.model.model["p"]["p"].policy, casbin.FastPolicy) + assert list(e.model.model["p"]["p"].policy) == [] + + def test_able_to_enforce_rule(self) -> None: + e = self.get_enforcer( + get_examples("basic_model.conf"), + get_examples("basic_policy.csv"), + [2, 1], + ) + + assert e.enforce("alice", "data1", "read") + assert not e.enforce("alice2", "data1", "read")