Skip to content

Commit

Permalink
Support SQLAlchemy 2.0 with sqlalchemy-oso (#1739)
Browse files Browse the repository at this point in the history
  • Loading branch information
samscott89 authored Jun 13, 2024
1 parent efc3407 commit d57c4cc
Show file tree
Hide file tree
Showing 14 changed files with 183 additions and 103 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
toolchain: 1.74.1
override: true
components: rustfmt, clippy
- name: Check Rust formatting
Expand All @@ -43,10 +43,10 @@ jobs:
args: --all-features --all-targets -- -D warnings

## Check Python
- name: Install Python 3.7
- name: Install Python 3.9
uses: actions/setup-python@v1
with:
python-version: "3.7"
python-version: "3.9"
- name: Install Python formatter
run: pip install black~=22.10.0
- name: Check Python formatting
Expand Down Expand Up @@ -172,10 +172,10 @@ jobs:
toolchain: stable
override: true
components: rustfmt, clippy
- name: Install Python 3.7
- name: Install Python 3.9
uses: actions/setup-python@v1
with:
python-version: "3.7"
python-version: "3.9"
- name: Test python
run: make python-test
- name: Test flask
Expand Down Expand Up @@ -328,10 +328,10 @@ jobs:
${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }}
${{ runner.os }}-cargo-test-
- name: Install Python 3.7
- name: Install Python 3.9
uses: actions/setup-python@v1
with:
python-version: "3.7"
python-version: "3.9"

- name: install aspell
run: sudo apt-get install aspell
Expand Down
3 changes: 2 additions & 1 deletion languages/python/sqlalchemy-oso/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest==7.0.1
flask
flask<2.2
Werkzeug==2.2.2
flask_sqlalchemy<3.0
4 changes: 2 additions & 2 deletions languages/python/sqlalchemy-oso/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
oso~=0.27.1
SQLAlchemy>=1.3.17,<1.5
oso~=0.27.0
SQLAlchemy>=1.3.17,<3.0
packaging>=21.3,<24.0
2 changes: 1 addition & 1 deletion languages/python/sqlalchemy-oso/sqlalchemy_oso/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.27.1"
__version__ = "0.27.2"


from .auth import register_models
Expand Down
2 changes: 2 additions & 0 deletions languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
Keep us compatible with multiple SQLAlchemy versions by implementing wrappers
when needed here.
"""

import sqlalchemy
from packaging.version import parse

version = parse(sqlalchemy.__version__) # type: ignore
USING_SQLAlchemy_v1_3 = version >= parse("1.3") and version < parse("1.4")
USING_SQLAlchemy_v2_0 = version >= parse("2.0") and version < parse("3.0")


def iterate_model_classes(base_or_registry):
Expand Down
11 changes: 4 additions & 7 deletions languages/python/sqlalchemy-oso/sqlalchemy_oso/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import sqlalchemy.orm
from oso import Oso

from sqlalchemy_oso.compat import USING_SQLAlchemy_v2_0
from sqlalchemy_oso.session import authorized_sessionmaker, scoped_session

from .session import Permissions
Expand All @@ -48,13 +49,6 @@ class AuthorizedSQLAlchemy(SQLAlchemy):
:param get_user: Callable that returns the user to authorize for the current request.
:param get_checked_permissions: Callable that returns the permissions to authorize for the current request.
>>> from sqlalchemy_oso.flask import AuthorizedSQLAlchemy
>>> db = AuthorizedSQLAlchemy(
... get_oso=lambda: flask.current_app.oso,
... get_user=lambda: flask_login.current_user,
... get_checked_permissions=lambda: {Post: flask.request.method}
... )
.. _flask_sqlalchemy: https://flask-sqlalchemy.palletsprojects.com/en/2.x/
"""

Expand All @@ -65,6 +59,9 @@ def __init__(
get_checked_permissions: Callable[[], Permissions],
**kwargs: Any
) -> None:
if USING_SQLAlchemy_v2_0:
raise NotImplementedError("Unsupported on SQLAlchemy >= 2.0")

self._get_oso = get_oso
self._get_user = get_user
self._get_checked_permissions = get_checked_permissions
Expand Down
1 change: 1 addition & 0 deletions languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""SQLAlchemy session classes and factories for oso."""

import logging
from typing import Any, Callable, Dict, Optional, Type

Expand Down
182 changes: 121 additions & 61 deletions languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
We must detect all entities properly to apply authorization.
"""

import sqlalchemy
from sqlalchemy import inspect
from sqlalchemy.orm.util import AliasedClass, AliasedInsp

from .compat import USING_SQLAlchemy_v1_3, USING_SQLAlchemy_v2_0


def to_class(entity):
"""Get mapped class from SQLAlchemy entity."""
Expand All @@ -20,7 +23,124 @@ def to_class(entity):
return entity


try:
if USING_SQLAlchemy_v1_3:
# unsupported for <= 1.3
def all_entities_in_statement(statement):
raise NotImplementedError("Unsupported on SQLAlchemy < 1.4")

else:
if USING_SQLAlchemy_v2_0:

# the structure we're dealing with is essentially:

# (path, strategy, options)
# where "path" indicates what it is we are loading,
# like (A, A.bs, B, B.cs, C)
# "strategy" is a tuple that keys to one of the loader strategies,
# some of them apply to relationships and others to column attributes
# then "options" is extra stuff like "innerjoin=True"
def get_joinedload_entities(stmt):
"""Get extra entities that are loaded from a ``stmt`` due to joinedload
options specified in the statement options.
These entities will not be returned directly by the query, but will prepopulate
relationships in the returned data.
For example::
get_joinedload_entities(query(A).options(joinedload(A.bs))) == {A, B}
"""
# there are two kinds of options that both represent the same information,
# just in different ways. This is largely a product of legacy options
# that have things like strings, i.e. joinedload("addresses"). note we
# aren't covering that here, which is legacy form. you can if you want
# raise an exception if you detect that form here.

entities = set()

for opt in stmt._with_options:
if hasattr(opt, "_to_bind"):
# these options are called _UnboundLoad
for b in opt._to_bind:
if ("lazy", "joined") in b.strategy:
# the "path" is a tuple showing the entity/relationships
# being targeted

# TODO check for wild card.
# TODO: Check whether entity is a string.
entities.add(b.path[-1].entity)
elif hasattr(opt, "context"):
# these options are called Load
for loadopt in opt.context:
if ("lazy", "joined") in loadopt.strategy:
# the "path" is a tuple showing the entity/relationships
# being targeted

# TODO: Check for of_type.
# TODO: Check whether entity is a string, unsupported.
# TODO check for wild card.
entities.add(loadopt.path[-1].entity)

return entities

else:
# Start POC code from @zzzeek (Mike Bayer)
# TODO: Still needs to be generalized & support other options.

# the structure we're dealing with is essentially:

# (path, strategy, options)
# where "path" indicates what it is we are loading,
# like (A, A.bs, B, B.cs, C)
# "strategy" is a tuple that keys to one of the loader strategies,
# some of them apply to relationships and others to column attributes
# then "options" is extra stuff like "innerjoin=True"
def get_joinedload_entities(stmt):
"""Get extra entities that are loaded from a ``stmt`` due to joinedload
options specified in the statement options.
These entities will not be returned directly by the query, but will prepopulate
relationships in the returned data.
For example::
get_joinedload_entities(query(A).options(joinedload(A.bs))) == {A, B}
"""
# there are two kinds of options that both represent the same information,
# just in different ways. This is largely a product of legacy options
# that have things like strings, i.e. joinedload("addresses"). note we
# aren't covering that here, which is legacy form. you can if you want
# raise an exception if you detect that form here.

entities = set()

for opt in stmt._with_options:
if hasattr(opt, "_to_bind"):
# these options are called _UnboundLoad
for b in opt._to_bind:
if ("lazy", "joined") in b.strategy:
# the "path" is a tuple showing the entity/relationships
# being targeted

# TODO check for wild card.
# TODO: Check whether entity is a string.
entities.add(b.path[-1].entity)
elif hasattr(opt, "context"):
# these options are called Load
for key, loadopt in opt.context.items():
if (
key[0] == "loader"
and ("lazy", "joined") in loadopt.strategy
):
# the "path" is a tuple showing the entity/relationships
# being targeted

# TODO: Check for of_type.
# TODO: Check whether entity is a string, unsupported.
# TODO check for wild card.
entities.add(key[1][-1].entity)

return entities

def all_entities_in_statement(statement):
"""
Expand Down Expand Up @@ -106,63 +226,3 @@ class A(Base):
default_entities.add(rel.mapper)

return default_entities

# Start POC code from @zzzeek (Mike Bayer)
# TODO: Still needs to be generalized & support other options.

# the structure we're dealing with is essentially:

# (path, strategy, options)
# where "path" indicates what it is we are loading,
# like (A, A.bs, B, B.cs, C)
# "strategy" is a tuple that keys to one of the loader strategies,
# some of them apply to relationships and others to column attributes
# then "options" is extra stuff like "innerjoin=True"
def get_joinedload_entities(stmt):
"""Get extra entities that are loaded from a ``stmt`` due to joinedload
options specified in the statement options.
These entities will not be returned directly by the query, but will prepopulate
relationships in the returned data.
For example::
get_joinedload_entities(query(A).options(joinedload(A.bs))) == {A, B}
"""
# there are two kinds of options that both represent the same information,
# just in different ways. This is largely a product of legacy options
# that have things like strings, i.e. joinedload("addresses"). note we
# aren't covering that here, which is legacy form. you can if you want
# raise an exception if you detect that form here.

entities = set()

for opt in stmt._with_options:
if hasattr(opt, "_to_bind"):
# these options are called _UnboundLoad
for b in opt._to_bind:
if ("lazy", "joined") in b.strategy:
# the "path" is a tuple showing the entity/relationships
# being targeted

# TODO check for wild card.
# TODO: Check whether entity is a string.
entities.add(b.path[-1].entity)
elif hasattr(opt, "context"):
# these options are called Load
for key, loadopt in opt.context.items():
if key[0] == "loader" and ("lazy", "joined") in loadopt.strategy:
# the "path" is a tuple showing the entity/relationships
# being targeted

# TODO: Check for of_type.
# TODO: Check whether entity is a string, unsupported.
# TODO check for wild card.
entities.add(key[1][-1].entity)

return entities

except ImportError:
# This code should not be called for SQLAlchemy 1.4.
def all_entities_in_statement(statement):
raise NotImplementedError("Unsupported on SQLAlchemy < 1.4")
12 changes: 11 additions & 1 deletion languages/python/sqlalchemy-oso/tests/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3

if USING_SQLAlchemy_v1_3:
# mypy: ignore-errors
from sqlalchemy.ext.declarative import declarative_base
else:
# mypy: ignore-errors
from sqlalchemy.orm import declarative_base

from sqlalchemy.schema import Table

# mypy: ignore-errors
ModelBase = declarative_base(name="ModelBase")


Expand Down
20 changes: 4 additions & 16 deletions languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
with_loader_criteria,
)

from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3
from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3, USING_SQLAlchemy_v2_0
from sqlalchemy_oso.session import AuthorizedSession
from sqlalchemy_oso.sqlalchemy_utils import (
all_entities_in_statement,
Expand Down Expand Up @@ -118,7 +118,6 @@ def test_get_column_entities(stmt, o):
(select(A), set()),
(select(A).options(joinedload(A.bs)), {B}),
(select(A).options(joinedload(A.bs).joinedload(B.cs)), {B, C}),
(select(A).options(Load(A).joinedload("bs")), {B}),
pytest.param(
select(A).options(Load(A).joinedload("*")),
set(),
Expand All @@ -131,21 +130,10 @@ def test_get_column_entities(stmt, o):
),
),
)
def test_get_joinedload_entities(stmt, o):
assert set(map(to_class, get_joinedload_entities(stmt))) == o


@pytest.mark.parametrize(
"stmt,o",
(
pytest.param(
select(A).options(joinedload("A.bs")),
{B},
marks=pytest.mark.xfail(reason="String doesn't work"),
),
),
@pytest.mark.skipif(
USING_SQLAlchemy_v2_0, reason="flask sqlalchemy does not support 2.0"
)
def test_get_joinedload_entities_str(stmt, o):
def test_get_joinedload_entities(stmt, o):
assert set(map(to_class, get_joinedload_entities(stmt))) == o


Expand Down
Loading

1 comment on commit d57c4cc

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rust Benchmark

Benchmark suite Current: d57c4cc Previous: efc3407 Ratio
rust_get_attribute 29007 ns/iter (± 1681) 28593 ns/iter (± 1129) 1.01
n_plus_one/100 1561224 ns/iter (± 22125) 1532810 ns/iter (± 32779) 1.02
n_plus_one/500 7570890 ns/iter (± 235880) 7413304 ns/iter (± 147108) 1.02
n_plus_one/1000 14999531 ns/iter (± 60196) 14726790 ns/iter (± 99436) 1.02
unify_once 608 ns/iter (± 1259) 612 ns/iter (± 1601) 0.99
unify_twice 1701 ns/iter (± 68) 1722 ns/iter (± 477) 0.99
many_rules 37888 ns/iter (± 1313) 38444 ns/iter (± 1655) 0.99
fib/5 342334 ns/iter (± 5618) 346628 ns/iter (± 5649) 0.99
prime/3 10141 ns/iter (± 559) 10389 ns/iter (± 435) 0.98
prime/23 10123 ns/iter (± 479) 10417 ns/iter (± 13795) 0.97
prime/43 10152 ns/iter (± 416) 10373 ns/iter (± 432) 0.98
prime/83 10143 ns/iter (± 479) 10378 ns/iter (± 505) 0.98
prime/255 9149 ns/iter (± 501) 9292 ns/iter (± 399) 0.98
indexed/100 3411 ns/iter (± 396) 3435 ns/iter (± 419) 0.99
indexed/500 3719 ns/iter (± 554) 3903 ns/iter (± 1041) 0.95
indexed/1000 4011 ns/iter (± 1566) 4413 ns/iter (± 252) 0.91
indexed/10000 7303 ns/iter (± 887) 7873 ns/iter (± 2238) 0.93
not 4023 ns/iter (± 61) 3957 ns/iter (± 59) 1.02
double_not 8522 ns/iter (± 210) 8445 ns/iter (± 116) 1.01
De_Morgan_not 5460 ns/iter (± 105) 5375 ns/iter (± 190) 1.02
load_policy 666208 ns/iter (± 5903) 658634 ns/iter (± 1819) 1.01
partial_and/1 20344 ns/iter (± 915) 20637 ns/iter (± 723) 0.99
partial_and/5 67272 ns/iter (± 2375) 68518 ns/iter (± 2310) 0.98
partial_and/10 126310 ns/iter (± 3491) 129094 ns/iter (± 3590) 0.98
partial_and/20 265945 ns/iter (± 7360) 269045 ns/iter (± 5442) 0.99
partial_and/40 581984 ns/iter (± 7718) 588582 ns/iter (± 7817) 0.99
partial_and/80 1354955 ns/iter (± 9774) 1363808 ns/iter (± 9505) 0.99
partial_and/100 1814713 ns/iter (± 4678) 1817673 ns/iter (± 11495) 1.00
partial_rule_depth/1 61940 ns/iter (± 2230) 62518 ns/iter (± 2317) 0.99
partial_rule_depth/5 212137 ns/iter (± 5253) 214194 ns/iter (± 5979) 0.99
partial_rule_depth/10 483087 ns/iter (± 11449) 485248 ns/iter (± 8924) 1.00
partial_rule_depth/20 1379490 ns/iter (± 25463) 1380553 ns/iter (± 18146) 1.00
partial_rule_depth/40 4964807 ns/iter (± 29388) 4942475 ns/iter (± 26580) 1.00
partial_rule_depth/80 27331603 ns/iter (± 220687) 27557996 ns/iter (± 306430) 0.99
partial_rule_depth/100 49411962 ns/iter (± 459309) 49507033 ns/iter (± 548583) 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.