Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implemented variable order field index #354

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions casbin/async_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from casbin.async_management_enforcer import AsyncManagementEnforcer
from casbin.util import join_slice, array_remove_duplicates, set_subtract
from casbin.constant.constants import DOMAIN_INDEX, SUBJECT_INDEX, OBJECT_INDEX


class AsyncEnforcer(AsyncManagementEnforcer):
Expand Down Expand Up @@ -280,8 +281,8 @@ async def get_implicit_users_for_resource(self, resource):
get_implicit_users_for_resource("data1") will return [[alice data1 read]]
Note: only users will be returned, roles (2nd arg in "g") will be excluded."""
permissions = dict()
subject_index = await self.get_field_index("p", "sub")
object_index = await self.get_field_index("p", "obj")
subject_index = await self.get_field_index("p", SUBJECT_INDEX)
object_index = await self.get_field_index("p", OBJECT_INDEX)
rm = self.get_role_manager()
roles = self.get_all_roles()

Expand All @@ -304,9 +305,9 @@ async def get_implicit_users_for_resource_by_domain(self, resource, domain):
"""get implicit user based on resource and domain.
Compared to GetImplicitUsersForResource, domain is supported"""
permissions = dict()
subject_index = await self.get_field_index("p", "sub")
object_index = await self.get_field_index("p", "obj")
dom_index = await self.get_field_index("p", "dom")
subject_index = await self.get_field_index("p", SUBJECT_INDEX)
object_index = await self.get_field_index("p", OBJECT_INDEX)
dom_index = await self.get_field_index("p", DOMAIN_INDEX)
rm = self.get_role_manager()
roles = await self.get_all_roles_by_domain(domain)

Expand Down
9 changes: 0 additions & 9 deletions casbin/async_internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,3 @@ async def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index,
self.watcher.update()

return rule_removed

async def get_field_index(self, ptype, field):
"""gets the index of the field name."""
return self.model.get_field_index(ptype, field)

async def set_field_index(self, ptype, field, index):
"""sets the index of the field name."""
assertion = self.model["p"][ptype]
assertion.field_index_map[field] = index
25 changes: 19 additions & 6 deletions casbin/async_management_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from casbin.async_internal_enforcer import AsyncInternalEnforcer
from casbin.model.policy_op import PolicyOp
from casbin.constant.constants import ACTION_INDEX, SUBJECT_INDEX, OBJECT_INDEX


class AsyncManagementEnforcer(AsyncInternalEnforcer):
Expand All @@ -22,27 +23,30 @@ class AsyncManagementEnforcer(AsyncInternalEnforcer):

def get_all_subjects(self):
"""gets the list of subjects that show up in the current policy."""
return self.get_all_named_subjects("p")
return self.model.get_values_for_field_in_policy_all_types_by_name("p", SUBJECT_INDEX)

def get_all_named_subjects(self, ptype):
"""gets the list of subjects that show up in the current named policy."""
return self.model.get_values_for_field_in_policy("p", ptype, 0)
field_index = self.model.get_field_index(ptype, SUBJECT_INDEX)
return self.model.get_values_for_field_in_policy("p", ptype, field_index)

def get_all_objects(self):
"""gets the list of objects that show up in the current policy."""
return self.get_all_named_objects("p")
return self.model.get_values_for_field_in_policy_all_types_by_name("p", OBJECT_INDEX)

def get_all_named_objects(self, ptype):
"""gets the list of objects that show up in the current named policy."""
return self.model.get_values_for_field_in_policy("p", ptype, 1)
field_index = self.model.get_field_index(ptype, OBJECT_INDEX)
return self.model.get_values_for_field_in_policy("p", ptype, field_index)

def get_all_actions(self):
"""gets the list of actions that show up in the current policy."""
return self.get_all_named_actions("p")
return self.model.get_values_for_field_in_policy_all_types_by_name("p", ACTION_INDEX)

def get_all_named_actions(self, ptype):
"""gets the list of actions that show up in the current named policy."""
return self.model.get_values_for_field_in_policy("p", ptype, 2)
field_index = self.model.get_field_index(ptype, ACTION_INDEX)
return self.model.get_values_for_field_in_policy("p", ptype, field_index)

def get_all_roles(self):
"""gets the list of roles that show up in the current named policy."""
Expand Down Expand Up @@ -302,3 +306,12 @@ async def remove_filtered_named_grouping_policy(self, ptype, field_index, *field
def add_function(self, name, func):
"""adds a customized function."""
self.fm.add_function(name, func)

async def get_field_index(self, ptype, field):
"""gets the index of the field name."""
return self.model.get_field_index(ptype, field)

async def set_field_index(self, ptype, field, index):
"""sets the index of the field name."""
assertion = self.model["p"][ptype]
assertion.field_index_map[field] = index
22 changes: 14 additions & 8 deletions casbin/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from casbin.management_enforcer import ManagementEnforcer
from casbin.util import join_slice, array_remove_duplicates, set_subtract
from casbin.constant.constants import DOMAIN_INDEX, SUBJECT_INDEX, OBJECT_INDEX


class Enforcer(ManagementEnforcer):
Expand Down Expand Up @@ -73,7 +74,8 @@ def delete_user(self, user):
"""
res1 = self.remove_filtered_grouping_policy(0, user)

res2 = self.remove_filtered_policy(0, user)
sub_index = self.get_field_index("p", SUBJECT_INDEX)
res2 = self.remove_filtered_policy(sub_index, user)
return res1 or res2

def delete_role(self, role):
Expand All @@ -83,7 +85,8 @@ def delete_role(self, role):
"""
res1 = self.remove_filtered_grouping_policy(1, role)

res2 = self.remove_filtered_policy(0, role)
sub_index = self.get_field_index("p", SUBJECT_INDEX)
res2 = self.remove_filtered_policy(sub_index, role)
return res1 or res2

def delete_permission(self, *permission):
Expand Down Expand Up @@ -112,7 +115,10 @@ def delete_permissions_for_user(self, user):
deletes permissions for a user or role.
Returns false if the user or role does not have any permissions (aka not affected).
"""
return self.remove_filtered_policy(0, user)
sub_index = self.get_field_index("p", SUBJECT_INDEX)
if sub_index == -1:
return False
return self.remove_filtered_policy(sub_index, user)

def get_permissions_for_user(self, user):
"""
Expand Down Expand Up @@ -289,8 +295,8 @@ def get_implicit_users_for_resource(self, resource):
get_implicit_users_for_resource("data1") will return [[alice data1 read]]
Note: only users will be returned, roles (2nd arg in "g") will be excluded."""
permissions = dict()
subject_index = self.get_field_index("p", "sub")
object_index = self.get_field_index("p", "obj")
subject_index = self.get_field_index("p", SUBJECT_INDEX)
object_index = self.get_field_index("p", OBJECT_INDEX)
rm = self.get_role_manager()
roles = self.get_all_roles()

Expand All @@ -313,9 +319,9 @@ def get_implicit_users_for_resource_by_domain(self, resource, domain):
"""get implicit user based on resource and domain.
Compared to GetImplicitUsersForResource, domain is supported"""
permissions = dict()
subject_index = self.get_field_index("p", "sub")
object_index = self.get_field_index("p", "obj")
dom_index = self.get_field_index("p", "dom")
subject_index = self.get_field_index("p", SUBJECT_INDEX)
object_index = self.get_field_index("p", OBJECT_INDEX)
dom_index = self.get_field_index("p", DOMAIN_INDEX)
rm = self.get_role_manager()
roles = self.get_all_roles_by_domain(domain)

Expand Down
9 changes: 0 additions & 9 deletions casbin/internal_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,3 @@ def _remove_filtered_policy_returns_effects(self, sec, ptype, field_index, *fiel
self.watcher.update()

return rule_removed

def get_field_index(self, ptype, field):
"""gets the index of the field name."""
return self.model.get_field_index(ptype, field)

def set_field_index(self, ptype, field, index):
"""sets the index of the field name."""
assertion = self.model["p"][ptype]
assertion.field_index_map[field] = index
15 changes: 12 additions & 3 deletions casbin/management_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ManagementEnforcer(InternalEnforcer):

def get_all_subjects(self):
"""gets the list of subjects that show up in the current policy."""
return self.get_all_named_subjects("p")
return self.model.get_values_for_field_in_policy_all_types_by_name("p", SUBJECT_INDEX)

def get_all_named_subjects(self, ptype):
"""gets the list of subjects that show up in the current named policy."""
Expand All @@ -33,7 +33,7 @@ def get_all_named_subjects(self, ptype):

def get_all_objects(self):
"""gets the list of objects that show up in the current policy."""
return self.get_all_named_objects("p")
return self.model.get_values_for_field_in_policy_all_types_by_name("p", OBJECT_INDEX)

def get_all_named_objects(self, ptype):
"""gets the list of objects that show up in the current named policy."""
Expand All @@ -42,7 +42,7 @@ def get_all_named_objects(self, ptype):

def get_all_actions(self):
"""gets the list of actions that show up in the current policy."""
return self.get_all_named_actions("p")
return self.model.get_values_for_field_in_policy_all_types_by_name("p", ACTION_INDEX)

def get_all_named_actions(self, ptype):
"""gets the list of actions that show up in the current named policy."""
Expand Down Expand Up @@ -309,3 +309,12 @@ def remove_filtered_named_grouping_policy(self, ptype, field_index, *field_value
def add_function(self, name, func):
"""adds a customized function."""
self.fm.add_function(name, func)

def get_field_index(self, ptype, field):
"""gets the index of the field name."""
return self.model.get_field_index(ptype, field)

def set_field_index(self, ptype, field, index):
"""sets the index of the field name."""
assertion = self.model["p"][ptype]
assertion.field_index_map[field] = index
1 change: 0 additions & 1 deletion casbin/model/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(self):
self.policy = []
self.rm = None
self.cond_rm = None
self.priority_index: int = -1
self.policy_map: dict = {}
self.field_index_map: dict = {}

Expand Down
42 changes: 8 additions & 34 deletions casbin/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from casbin import util, config
from . import Assertion
from .policy import Policy
from casbin.constant.constants import DOMAIN_INDEX, PRIORITY_INDEX, SUBJECT_PRIORITY_EFFECT

DEFAULT_DOMAIN = ""
DEFAULT_SEPARATOR = "::"
Expand Down Expand Up @@ -116,19 +117,16 @@ def print_model(self):

def sort_policies_by_priority(self):
for ptype, assertion in self["p"].items():
for index, token in enumerate(assertion.tokens):
if token == f"{ptype}_priority":
assertion.priority_index = index
break
priority_index = self.get_field_index(ptype, PRIORITY_INDEX)

if assertion.priority_index == -1:
if priority_index == -1:
continue

assertion.policy = sorted(
assertion.policy,
key=lambda x: int(x[assertion.priority_index])
if x[assertion.priority_index].isdigit()
else x[assertion.priority_index],
key=lambda x: int(x[priority_index])
if x[priority_index].isdigit()
else x[priority_index],
)

for i, policy in enumerate(assertion.policy):
Expand All @@ -137,16 +135,12 @@ def sort_policies_by_priority(self):
return None

def sort_policies_by_subject_hierarchy(self):
if self["e"]["e"].value != "subjectPriority(p_eft) || deny":
if self["e"]["e"].value != SUBJECT_PRIORITY_EFFECT:
return

sub_index = 0
domain_index = -1
for ptype, assertion in self["p"].items():
for index, token in enumerate(assertion.tokens):
if token == "{}_dom".format(ptype):
domain_index = index
break
domain_index = self.get_field_index(ptype, DOMAIN_INDEX)

subject_hierarchy_map = self.get_subject_hierarchy_map(self["g"]["g"].policy)

Expand Down Expand Up @@ -230,23 +224,3 @@ def write_string(sec):
s[-1] = s[-1].strip()

return "".join(s)

def get_field_index(self, ptype, field):
"""get_field_index gets the index of the field for a ptype in a policy,
return -1 if the field does not exist."""
assertion = self["p"][ptype]
if field in assertion.field_index_map:
return assertion.field_index_map[field]

pattern = f"{ptype}_{field}"
index = -1
for i, token in enumerate(assertion.tokens):
if token == pattern:
index = i
break

if index == -1:
return index

assertion.field_index_map[field] = index
return index
65 changes: 62 additions & 3 deletions casbin/model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import logging
from casbin.util import util
from casbin.constant.constants import PRIORITY_INDEX

DEFAULT_SEP = ","

Expand Down Expand Up @@ -119,14 +121,17 @@ def add_policy(self, sec, ptype, rule):
else:
return False

if sec == "p" and assertion.priority_index >= 0:
has_priority = False
if assertion.field_index_map.get(PRIORITY_INDEX) is not None:
has_priority = True
if sec == "p" and has_priority:
try:
idx_insert = int(rule[assertion.priority_index])
idx_insert = int(rule[assertion.field_index_map[PRIORITY_INDEX]])

i = len(assertion.policy) - 1
for i in range(i, 0, -1):
try:
idx = int(assertion.policy[i - 1][assertion.priority_index])
idx = int(assertion.policy[i - 1][assertion.field_index_map[PRIORITY_INDEX]])
except Exception as e:
print(e)

Expand Down Expand Up @@ -303,3 +308,57 @@ def get_values_for_field_in_policy(self, sec, ptype, field_index):
values.append(value)

return values

def get_values_for_field_in_policy_all_types(self, sec, field_index):
"""gets all values for a field for all rules in a policy of all ptypes, duplicated values are removed."""
values = []
if sec not in self.keys():
return values

for ptype in self[sec]:
value = self.get_values_for_field_in_policy(sec, ptype, field_index)
values.extend(value)

values = util.array_remove_duplicates(values)

return values

def get_values_for_field_in_policy_all_types_by_name(self, sec, field):
"""gets all values for a field for all rules in a policy of all ptypes, duplicated values are removed."""
values = []
if sec not in self.keys():
return values

for ptype in self[sec]:
index = self.get_field_index(ptype, field)
value = self.get_values_for_field_in_policy(sec, ptype, index)
values.extend(value)

values = util.array_remove_duplicates(values)

return values

def get_field_index(self, ptype, field):
"""get_field_index gets the index of the field for a ptype in a policy,
return -1 if the field does not exist."""
assertion = self["p"][ptype]
if field in assertion.field_index_map:
return assertion.field_index_map[field]

pattern = f"{ptype}_{field}"
index = -1
for i, token in enumerate(assertion.tokens):
if token == pattern:
index = i
break

if index == -1:
return index

assertion.field_index_map[field] = index
return index

def set_field_index(self, ptype, field, index):
"""sets the index of the field name."""
assertion = self["p"][ptype]
assertion.field_index_map[field] = index
Loading