Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: port fastbin to casbin #318

Merged
merged 3 commits into from
Sep 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions casbin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .enforcer import *
from .synced_enforcer import SyncedEnforcer
from .distributed_enforcer import DistributedEnforcer
from .fast_enforcer import FastEnforcer
from .async_enforcer import AsyncEnforcer
from . import util
from .persist import *
Expand Down
6 changes: 1 addition & 5 deletions casbin/core_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy
import logging

from casbin.effect import Effector, get_effector, effect_to_bool
from casbin.model import Model, FunctionMap
Expand Down Expand Up @@ -202,7 +202,6 @@ def load_policy(self):
new_model.clear_policy()

try:

self.adapter.load_policy(new_model)

new_model.sort_policies_by_subject_hierarchy()
Expand All @@ -212,7 +211,6 @@ def load_policy(self):
new_model.print_policy()

if self.auto_build_role_links:

need_to_rebuild = True
for rm in self.rm_map.values():
rm.clear()
Expand All @@ -222,7 +220,6 @@ def load_policy(self):
self.model = new_model

except Exception as e:

if self.auto_build_role_links and need_to_rebuild:
self.build_role_links()

Expand Down Expand Up @@ -315,7 +312,6 @@ def add_named_domain_matching_func(self, ptype, fn):
return False

def new_enforce_context(self, suffix: str) -> EnforceContext:

return EnforceContext(
rtype="r" + suffix,
ptype="p" + suffix,
Expand Down
6 changes: 2 additions & 4 deletions casbin/distributed_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from casbin import SyncedEnforcer
import logging

from casbin.persist import batch_adapter
from casbin.model.policy_op import PolicyOp
from casbin.persist import batch_adapter
from casbin.persist.adapters import update_adapter
from casbin.synced_enforcer import SyncedEnforcer


class DistributedEnforcer(SyncedEnforcer):
Expand Down
41 changes: 41 additions & 0 deletions casbin/fast_enforcer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import logging
from typing import Sequence

from casbin.enforcer import Enforcer
from casbin.model import Model, FastModel, fast_policy_filter, FunctionMap
from casbin.persist.adapters import FileAdapter
from casbin.util.log import configure_logging


class FastEnforcer(Enforcer):
_cache_key_order: Sequence[int] = None

def __init__(self, model=None, adapter=None, enable_log=False, cache_key_order: Sequence[int] = None):
self._cache_key_order = cache_key_order
super().__init__(model, adapter, enable_log)

def new_model(self, path="", text=""):
"""creates a model."""
if self._cache_key_order is None:
m = Model()
else:
m = FastModel(self._cache_key_order)
if len(path) > 0:
m.load_model(path)
else:
m.load_model_from_text(text)

return m

def enforce(self, *rvals):
"""decides whether a "subject" can access a "object" with the operation "action",
input parameters are usually: (sub, obj, act).
"""
if FastEnforcer._cache_key_order is None:
result, _ = self.enforce_ex(*rvals)
else:
keys = [rvals[x] for x in self._cache_key_order]
with fast_policy_filter(self.model.model["p"]["p"].policy, *keys):
result, _ = self.enforce_ex(*rvals)

return result
4 changes: 3 additions & 1 deletion casbin/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

from .assertion import Assertion
from .function import FunctionMap
from .model import Model
from .model_fast import FastModel
from .policy import Policy
from .function import FunctionMap
from .policy_fast import FastPolicy, fast_policy_filter
36 changes: 36 additions & 0 deletions casbin/model/model_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Sequence

from .model import Model
from .policy_fast import FastPolicy


class FastModel(Model):
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
super().__init__()
self._cache_key_order = cache_key_order

def add_def(self, sec: str, key: str, value: Any) -> None:
super().add_def(sec, key, value)
if sec == "p" and key == "p":
self.model[sec][key].policy = FastPolicy(self._cache_key_order)

def clear_policy(self) -> None:
"""clears all current policy."""
super().clear_policy()
self.model["p"]["p"].policy = FastPolicy(self._cache_key_order)
101 changes: 101 additions & 0 deletions casbin/model/policy_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast


def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]:
if keys[0] in cache:
if len(keys) > 1:
return in_cache(cache[keys[-0]], keys[1:])
return cast(Set[Sequence[str]], cache[keys[0]])
else:
return None


class FastPolicy(Container[Sequence[str]]):
_cache: Dict[str, Any]
_current_filter: Optional[Set[Sequence[str]]]
_cache_key_order: Sequence[int]

def __init__(self, cache_key_order: Sequence[int]) -> None:
self._cache = {}
self._current_filter = None
self._cache_key_order = cache_key_order

def __iter__(self) -> Iterator[Sequence[str]]:
yield from self.__get_policy()

def __len__(self) -> int:
return len(list(self.__get_policy()))

def __contains__(self, item: object) -> bool:
if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item):
return False
keys = [item[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return False
return tuple(item) in exists

def __getitem__(self, item: int) -> Sequence[str]:
for i, entry in enumerate(self):
if i == item:
return entry
raise KeyError("No such value exists")

def append(self, item: Sequence[str]) -> None:
cache = self._cache
keys = [item[x] for x in self._cache_key_order]

for key in keys[:-1]:
if key not in cache:
cache[key] = dict()
cache = cache[key]
if keys[-1] not in cache:
cache[keys[-1]] = set()

cache[keys[-1]].add(tuple(item))

def remove(self, policy: Sequence[str]) -> bool:
keys = [policy[x] for x in self._cache_key_order]
exists = in_cache(self._cache, keys)
if not exists:
return True

exists.remove(tuple(policy))
return True

def __get_policy(self) -> Iterable[Sequence[str]]:
if self._current_filter is not None:
return (list(x) for x in self._current_filter)
else:
return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1)

def apply_filter(self, *keys: str) -> None:
value = in_cache(self._cache, keys)
self._current_filter = value or set()

def clear_filter(self) -> None:
self._current_filter = None


@contextmanager
def fast_policy_filter(policy: FastPolicy, *keys: str) -> Iterator[None]:
try:
policy.apply_filter(*keys)
yield
finally:
policy.clear_filter()
11 changes: 6 additions & 5 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import benchmarks
from . import config
from . import model
from . import rbac
from . import util
from .test_distributed_api import TestDistributedApi
from .test_enforcer import *
from .test_fast_enforcer import TestFastEnforcer
from .test_filter import TestFilteredAdapter
from .test_frontend import TestFrontend
from .test_management_api import TestManagementApi, TestManagementApiSynced
from .test_rbac_api import TestRbacApi, TestRbacApiSynced
from . import benchmarks
from . import config
from . import model
from . import rbac
from . import util
1 change: 1 addition & 0 deletions tests/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
# limitations under the License.

from .test_policy import TestPolicy
from .test_policy_fast import TestContextManager, TestFastPolicy
101 changes: 101 additions & 0 deletions tests/model/test_policy_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2023 The casbin Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import TestCase

from casbin.model import FastPolicy, fast_policy_filter


class TestFastPolicy(TestCase):
def test_able_to_add_rules(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_does_not_add_duplicates(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read"])

assert list(policy) == [["sub", "obj", "read"]]

def test_can_remove_rules(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.remove(["sub", "obj", "read"])

assert list(policy) == []

def test_returns_lengtt(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert len(policy) == 1

def test_supports_in_keyword(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])

assert ["sub", "obj", "read"] in policy

def test_supports_filters(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")

assert list(policy) == [["sub", "obj2", "read2"]]

def test_clears_filters(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

policy.apply_filter("read2", "obj2")
policy.clear_filter()

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]


class TestContextManager:
def test_fast_policy_filter(self) -> None:
policy = FastPolicy([2, 1])

policy.append(["sub", "obj", "read"])
policy.append(["sub", "obj", "read2"])
policy.append(["sub", "obj2", "read2"])

with fast_policy_filter(policy, "read2", "obj2"):
assert list(policy) == [["sub", "obj2", "read2"]]

assert list(policy) == [
["sub", "obj", "read"],
["sub", "obj", "read2"],
["sub", "obj2", "read2"],
]
Loading