From d57c4cc53df686ee5a7b1bc399daa2246c8f37f5 Mon Sep 17 00:00:00 2001 From: Sam Scott Date: Thu, 13 Jun 2024 14:55:21 -0500 Subject: [PATCH] Support SQLAlchemy 2.0 with `sqlalchemy-oso` (#1739) --- .github/workflows/test.yml | 14 +- .../sqlalchemy-oso/requirements-test.txt | 3 +- .../python/sqlalchemy-oso/requirements.txt | 4 +- .../sqlalchemy-oso/sqlalchemy_oso/__init__.py | 2 +- .../sqlalchemy-oso/sqlalchemy_oso/compat.py | 2 + .../sqlalchemy-oso/sqlalchemy_oso/flask.py | 11 +- .../sqlalchemy-oso/sqlalchemy_oso/session.py | 1 + .../sqlalchemy_oso/sqlalchemy_utils.py | 182 ++++++++++++------ .../python/sqlalchemy-oso/tests/models.py | 12 +- .../tests/test_advanced_queries_14.py | 20 +- .../python/sqlalchemy-oso/tests/test_flask.py | 7 + .../tests/test_post_relationship.py | 23 ++- .../sqlalchemy-oso/tests/test_sqlalchemy.py | 1 + languages/python/sqlalchemy-oso/tox.ini | 4 +- 14 files changed, 183 insertions(+), 103 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6c0b9482a5..1f6d18de95 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: stable + toolchain: 1.74.1 override: true components: rustfmt, clippy - name: Check Rust formatting @@ -43,10 +43,10 @@ jobs: args: --all-features --all-targets -- -D warnings ## Check Python - - name: Install Python 3.7 + - name: Install Python 3.9 uses: actions/setup-python@v1 with: - python-version: "3.7" + python-version: "3.9" - name: Install Python formatter run: pip install black~=22.10.0 - name: Check Python formatting @@ -172,10 +172,10 @@ jobs: toolchain: stable override: true components: rustfmt, clippy - - name: Install Python 3.7 + - name: Install Python 3.9 uses: actions/setup-python@v1 with: - python-version: "3.7" + python-version: "3.9" - name: Test python run: make python-test - name: Test flask @@ -328,10 +328,10 @@ jobs: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} ${{ runner.os }}-cargo-test- - - name: Install Python 3.7 + - name: Install Python 3.9 uses: actions/setup-python@v1 with: - python-version: "3.7" + python-version: "3.9" - name: install aspell run: sudo apt-get install aspell diff --git a/languages/python/sqlalchemy-oso/requirements-test.txt b/languages/python/sqlalchemy-oso/requirements-test.txt index 8bd0cc37eb..3a6851776c 100644 --- a/languages/python/sqlalchemy-oso/requirements-test.txt +++ b/languages/python/sqlalchemy-oso/requirements-test.txt @@ -1,3 +1,4 @@ pytest==7.0.1 -flask +flask<2.2 +Werkzeug==2.2.2 flask_sqlalchemy<3.0 diff --git a/languages/python/sqlalchemy-oso/requirements.txt b/languages/python/sqlalchemy-oso/requirements.txt index d49971c19b..185f3dde03 100644 --- a/languages/python/sqlalchemy-oso/requirements.txt +++ b/languages/python/sqlalchemy-oso/requirements.txt @@ -1,3 +1,3 @@ -oso~=0.27.1 -SQLAlchemy>=1.3.17,<1.5 +oso~=0.27.0 +SQLAlchemy>=1.3.17,<3.0 packaging>=21.3,<24.0 diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/__init__.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/__init__.py index d48c1d160a..f4d959bb74 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/__init__.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.27.1" +__version__ = "0.27.2" from .auth import register_models diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py index 84b17d12f1..8b1d0b4f72 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/compat.py @@ -3,11 +3,13 @@ Keep us compatible with multiple SQLAlchemy versions by implementing wrappers when needed here. """ + import sqlalchemy from packaging.version import parse version = parse(sqlalchemy.__version__) # type: ignore USING_SQLAlchemy_v1_3 = version >= parse("1.3") and version < parse("1.4") +USING_SQLAlchemy_v2_0 = version >= parse("2.0") and version < parse("3.0") def iterate_model_classes(base_or_registry): diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/flask.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/flask.py index 86300b9233..95e2970e52 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/flask.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/flask.py @@ -33,6 +33,7 @@ import sqlalchemy.orm from oso import Oso +from sqlalchemy_oso.compat import USING_SQLAlchemy_v2_0 from sqlalchemy_oso.session import authorized_sessionmaker, scoped_session from .session import Permissions @@ -48,13 +49,6 @@ class AuthorizedSQLAlchemy(SQLAlchemy): :param get_user: Callable that returns the user to authorize for the current request. :param get_checked_permissions: Callable that returns the permissions to authorize for the current request. - >>> from sqlalchemy_oso.flask import AuthorizedSQLAlchemy - >>> db = AuthorizedSQLAlchemy( - ... get_oso=lambda: flask.current_app.oso, - ... get_user=lambda: flask_login.current_user, - ... get_checked_permissions=lambda: {Post: flask.request.method} - ... ) - .. _flask_sqlalchemy: https://flask-sqlalchemy.palletsprojects.com/en/2.x/ """ @@ -65,6 +59,9 @@ def __init__( get_checked_permissions: Callable[[], Permissions], **kwargs: Any ) -> None: + if USING_SQLAlchemy_v2_0: + raise NotImplementedError("Unsupported on SQLAlchemy >= 2.0") + self._get_oso = get_oso self._get_user = get_user self._get_checked_permissions = get_checked_permissions diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py index 61986d12fb..b0636adcbf 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/session.py @@ -1,4 +1,5 @@ """SQLAlchemy session classes and factories for oso.""" + import logging from typing import Any, Callable, Dict, Optional, Type diff --git a/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py b/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py index c4006e4923..d1c93e5819 100644 --- a/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py +++ b/languages/python/sqlalchemy-oso/sqlalchemy_oso/sqlalchemy_utils.py @@ -5,10 +5,13 @@ We must detect all entities properly to apply authorization. """ + import sqlalchemy from sqlalchemy import inspect from sqlalchemy.orm.util import AliasedClass, AliasedInsp +from .compat import USING_SQLAlchemy_v1_3, USING_SQLAlchemy_v2_0 + def to_class(entity): """Get mapped class from SQLAlchemy entity.""" @@ -20,7 +23,124 @@ def to_class(entity): return entity -try: +if USING_SQLAlchemy_v1_3: + # unsupported for <= 1.3 + def all_entities_in_statement(statement): + raise NotImplementedError("Unsupported on SQLAlchemy < 1.4") + +else: + if USING_SQLAlchemy_v2_0: + + # the structure we're dealing with is essentially: + + # (path, strategy, options) + # where "path" indicates what it is we are loading, + # like (A, A.bs, B, B.cs, C) + # "strategy" is a tuple that keys to one of the loader strategies, + # some of them apply to relationships and others to column attributes + # then "options" is extra stuff like "innerjoin=True" + def get_joinedload_entities(stmt): + """Get extra entities that are loaded from a ``stmt`` due to joinedload + options specified in the statement options. + + These entities will not be returned directly by the query, but will prepopulate + relationships in the returned data. + + For example:: + + get_joinedload_entities(query(A).options(joinedload(A.bs))) == {A, B} + """ + # there are two kinds of options that both represent the same information, + # just in different ways. This is largely a product of legacy options + # that have things like strings, i.e. joinedload("addresses"). note we + # aren't covering that here, which is legacy form. you can if you want + # raise an exception if you detect that form here. + + entities = set() + + for opt in stmt._with_options: + if hasattr(opt, "_to_bind"): + # these options are called _UnboundLoad + for b in opt._to_bind: + if ("lazy", "joined") in b.strategy: + # the "path" is a tuple showing the entity/relationships + # being targeted + + # TODO check for wild card. + # TODO: Check whether entity is a string. + entities.add(b.path[-1].entity) + elif hasattr(opt, "context"): + # these options are called Load + for loadopt in opt.context: + if ("lazy", "joined") in loadopt.strategy: + # the "path" is a tuple showing the entity/relationships + # being targeted + + # TODO: Check for of_type. + # TODO: Check whether entity is a string, unsupported. + # TODO check for wild card. + entities.add(loadopt.path[-1].entity) + + return entities + + else: + # Start POC code from @zzzeek (Mike Bayer) + # TODO: Still needs to be generalized & support other options. + + # the structure we're dealing with is essentially: + + # (path, strategy, options) + # where "path" indicates what it is we are loading, + # like (A, A.bs, B, B.cs, C) + # "strategy" is a tuple that keys to one of the loader strategies, + # some of them apply to relationships and others to column attributes + # then "options" is extra stuff like "innerjoin=True" + def get_joinedload_entities(stmt): + """Get extra entities that are loaded from a ``stmt`` due to joinedload + options specified in the statement options. + + These entities will not be returned directly by the query, but will prepopulate + relationships in the returned data. + + For example:: + + get_joinedload_entities(query(A).options(joinedload(A.bs))) == {A, B} + """ + # there are two kinds of options that both represent the same information, + # just in different ways. This is largely a product of legacy options + # that have things like strings, i.e. joinedload("addresses"). note we + # aren't covering that here, which is legacy form. you can if you want + # raise an exception if you detect that form here. + + entities = set() + + for opt in stmt._with_options: + if hasattr(opt, "_to_bind"): + # these options are called _UnboundLoad + for b in opt._to_bind: + if ("lazy", "joined") in b.strategy: + # the "path" is a tuple showing the entity/relationships + # being targeted + + # TODO check for wild card. + # TODO: Check whether entity is a string. + entities.add(b.path[-1].entity) + elif hasattr(opt, "context"): + # these options are called Load + for key, loadopt in opt.context.items(): + if ( + key[0] == "loader" + and ("lazy", "joined") in loadopt.strategy + ): + # the "path" is a tuple showing the entity/relationships + # being targeted + + # TODO: Check for of_type. + # TODO: Check whether entity is a string, unsupported. + # TODO check for wild card. + entities.add(key[1][-1].entity) + + return entities def all_entities_in_statement(statement): """ @@ -106,63 +226,3 @@ class A(Base): default_entities.add(rel.mapper) return default_entities - - # Start POC code from @zzzeek (Mike Bayer) - # TODO: Still needs to be generalized & support other options. - - # the structure we're dealing with is essentially: - - # (path, strategy, options) - # where "path" indicates what it is we are loading, - # like (A, A.bs, B, B.cs, C) - # "strategy" is a tuple that keys to one of the loader strategies, - # some of them apply to relationships and others to column attributes - # then "options" is extra stuff like "innerjoin=True" - def get_joinedload_entities(stmt): - """Get extra entities that are loaded from a ``stmt`` due to joinedload - options specified in the statement options. - - These entities will not be returned directly by the query, but will prepopulate - relationships in the returned data. - - For example:: - - get_joinedload_entities(query(A).options(joinedload(A.bs))) == {A, B} - """ - # there are two kinds of options that both represent the same information, - # just in different ways. This is largely a product of legacy options - # that have things like strings, i.e. joinedload("addresses"). note we - # aren't covering that here, which is legacy form. you can if you want - # raise an exception if you detect that form here. - - entities = set() - - for opt in stmt._with_options: - if hasattr(opt, "_to_bind"): - # these options are called _UnboundLoad - for b in opt._to_bind: - if ("lazy", "joined") in b.strategy: - # the "path" is a tuple showing the entity/relationships - # being targeted - - # TODO check for wild card. - # TODO: Check whether entity is a string. - entities.add(b.path[-1].entity) - elif hasattr(opt, "context"): - # these options are called Load - for key, loadopt in opt.context.items(): - if key[0] == "loader" and ("lazy", "joined") in loadopt.strategy: - # the "path" is a tuple showing the entity/relationships - # being targeted - - # TODO: Check for of_type. - # TODO: Check whether entity is a string, unsupported. - # TODO check for wild card. - entities.add(key[1][-1].entity) - - return entities - -except ImportError: - # This code should not be called for SQLAlchemy 1.4. - def all_entities_in_statement(statement): - raise NotImplementedError("Unsupported on SQLAlchemy < 1.4") diff --git a/languages/python/sqlalchemy-oso/tests/models.py b/languages/python/sqlalchemy-oso/tests/models.py index 9ecf3138cc..6e577362ae 100644 --- a/languages/python/sqlalchemy-oso/tests/models.py +++ b/languages/python/sqlalchemy-oso/tests/models.py @@ -1,8 +1,18 @@ from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship + +from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3 + +if USING_SQLAlchemy_v1_3: + # mypy: ignore-errors + from sqlalchemy.ext.declarative import declarative_base +else: + # mypy: ignore-errors + from sqlalchemy.orm import declarative_base + from sqlalchemy.schema import Table +# mypy: ignore-errors ModelBase = declarative_base(name="ModelBase") diff --git a/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py b/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py index d12b41c9b4..16092c3895 100644 --- a/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py +++ b/languages/python/sqlalchemy-oso/tests/test_advanced_queries_14.py @@ -26,7 +26,7 @@ with_loader_criteria, ) -from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3 +from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3, USING_SQLAlchemy_v2_0 from sqlalchemy_oso.session import AuthorizedSession from sqlalchemy_oso.sqlalchemy_utils import ( all_entities_in_statement, @@ -118,7 +118,6 @@ def test_get_column_entities(stmt, o): (select(A), set()), (select(A).options(joinedload(A.bs)), {B}), (select(A).options(joinedload(A.bs).joinedload(B.cs)), {B, C}), - (select(A).options(Load(A).joinedload("bs")), {B}), pytest.param( select(A).options(Load(A).joinedload("*")), set(), @@ -131,21 +130,10 @@ def test_get_column_entities(stmt, o): ), ), ) -def test_get_joinedload_entities(stmt, o): - assert set(map(to_class, get_joinedload_entities(stmt))) == o - - -@pytest.mark.parametrize( - "stmt,o", - ( - pytest.param( - select(A).options(joinedload("A.bs")), - {B}, - marks=pytest.mark.xfail(reason="String doesn't work"), - ), - ), +@pytest.mark.skipif( + USING_SQLAlchemy_v2_0, reason="flask sqlalchemy does not support 2.0" ) -def test_get_joinedload_entities_str(stmt, o): +def test_get_joinedload_entities(stmt, o): assert set(map(to_class, get_joinedload_entities(stmt))) == o diff --git a/languages/python/sqlalchemy-oso/tests/test_flask.py b/languages/python/sqlalchemy-oso/tests/test_flask.py index 32e9ac266a..48b01c4258 100644 --- a/languages/python/sqlalchemy-oso/tests/test_flask.py +++ b/languages/python/sqlalchemy-oso/tests/test_flask.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, Integer from sqlalchemy.orm import Session +from sqlalchemy_oso.compat import USING_SQLAlchemy_v2_0 from sqlalchemy_oso.flask import AuthorizedSQLAlchemy from sqlalchemy_oso.session import Permissions @@ -46,6 +47,9 @@ def sqlalchemy(flask_app, oso): return sqlalchemy +@pytest.mark.skipif( + USING_SQLAlchemy_v2_0, reason="flask sqlalchemy does not support 2.0" +) def test_authorized_sqlalchemy(ctx, oso, sqlalchemy, post_fixtures): global checked_permissions checked_permissions = {Post: "read"} @@ -68,6 +72,9 @@ def test_authorized_sqlalchemy(ctx, oso, sqlalchemy, post_fixtures): assert sqlalchemy.session.query(Post).count() == 1 +@pytest.mark.skipif( + USING_SQLAlchemy_v2_0, reason="flask sqlalchemy does not support 2.0" +) def test_flask_model(ctx, oso, sqlalchemy): class TestModel(sqlalchemy.Model): id = Column(Integer, primary_key=True) diff --git a/languages/python/sqlalchemy-oso/tests/test_post_relationship.py b/languages/python/sqlalchemy-oso/tests/test_post_relationship.py index f450cf5c38..a1b0950821 100644 --- a/languages/python/sqlalchemy-oso/tests/test_post_relationship.py +++ b/languages/python/sqlalchemy-oso/tests/test_post_relationship.py @@ -3,10 +3,11 @@ Tests come from the relationship document & operations laid out there https://www.notion.so/osohq/Relationships-621b884edbc6423f93d29e6066e58d16. """ + import pytest from sqlalchemy_oso.auth import authorize_model -from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3 +from sqlalchemy_oso.compat import USING_SQLAlchemy_v1_3, USING_SQLAlchemy_v2_0 from .conftest import print_query from .models import Category, Post, Tag, User @@ -183,7 +184,7 @@ def tag_test_fixture(session): # HACK! objects = {} - for (name, local) in locals().items(): + for name, local in locals().items(): if name != "session" and name != "objects": session.add(local) @@ -259,7 +260,7 @@ def tag_nested_test_fixture(session): # HACK! objects = {} - for (name, local) in locals().items(): + for name, local in locals().items(): if name != "session" and name != "objects": session.add(local) @@ -357,7 +358,7 @@ def tag_nested_many_many_test_fixture(session): # HACK! objects = {} - for (name, local) in locals().items(): + for name, local in locals().items(): if name != "session" and name != "objects": session.add(local) @@ -620,13 +621,18 @@ def test_empty_constraints_in(session, oso, tag_nested_many_many_test_fixture): # SQLAlchemy 1.4. true_clause = " AND 1 = 1" + if USING_SQLAlchemy_v2_0: + tables = "tags, post_tags" + else: + tables = "post_tags, tags" + assert str(posts) == ( "SELECT posts.id AS posts_id, posts.contents AS posts_contents, posts.title AS" + " posts_title, posts.access_level AS posts_access_level," + " posts.created_by_id AS posts_created_by_id, posts.needs_moderation AS posts_needs_moderation" + " \nFROM posts" + " \nWHERE EXISTS (SELECT 1" - + " \nFROM post_tags, tags" + + f" \nFROM {tables}" + f" \nWHERE posts.id = post_tags.post_id AND tags.name = post_tags.tag_id{true_clause})" ) posts = posts.all() @@ -648,13 +654,18 @@ def test_in_with_constraints_but_no_matching_objects( posts = session.query(Post).filter( authorize_model(oso, user, "read", session, Post) ) + + if USING_SQLAlchemy_v2_0: + tables = "tags, post_tags" + else: + tables = "post_tags, tags" assert str(posts) == ( "SELECT posts.id AS posts_id, posts.contents AS posts_contents, posts.title AS posts_title," + " posts.access_level AS posts_access_level," + " posts.created_by_id AS posts_created_by_id, posts.needs_moderation AS posts_needs_moderation" + " \nFROM posts" + " \nWHERE EXISTS (SELECT 1" - + " \nFROM post_tags, tags" + + f" \nFROM {tables}" + " \nWHERE posts.id = post_tags.post_id AND tags.name = post_tags.tag_id AND tags.name = ?)" ) posts = posts.all() diff --git a/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py b/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py index 132b9bf1a6..e5de07a39e 100644 --- a/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py +++ b/languages/python/sqlalchemy-oso/tests/test_sqlalchemy.py @@ -1,4 +1,5 @@ """Test hooks & SQLAlchemy API integrations.""" + import pytest from sqlalchemy.orm import aliased, joinedload diff --git a/languages/python/sqlalchemy-oso/tox.ini b/languages/python/sqlalchemy-oso/tox.ini index fea7bc9c9f..f8de6a418e 100644 --- a/languages/python/sqlalchemy-oso/tox.ini +++ b/languages/python/sqlalchemy-oso/tox.ini @@ -1,6 +1,6 @@ [tox] skip_missing_interpreters=true -envlist = {py3,pypy3}-sqlalchemy{13,14}-{earliest,latest} +envlist = {py3,pypy3}-sqlalchemy{13,14,20}-{earliest,latest} [testenv] passenv = CIBUILDWHEEL @@ -11,6 +11,8 @@ deps = sqlalchemy13-latest: SQLAlchemy~=1.3.17 sqlalchemy14-earliest: SQLAlchemy==1.4.0 sqlalchemy14-latest: SQLAlchemy~=1.4.0 + sqlalchemy20-earliest: SQLAlchemy==2.0.0 + sqlalchemy20-latest: SQLAlchemy~=2.0.0 commands = pytest allowlist_externals = bash commands_pre =