Skip to content

Commit

Permalink
Merge branch 'main' into fix/enforce-keyword-argument-in-make-query-fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx authored Nov 15, 2024
2 parents 42f7199 + 60f4f6d commit 53b58d3
Show file tree
Hide file tree
Showing 15 changed files with 346 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/acceptance.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
git fetch origin $GITHUB_HEAD_REF:$GITHUB_HEAD_REF
- name: Run integration tests
uses: databrickslabs/sandbox/acceptance@acceptance/v0.3.1
uses: databrickslabs/sandbox/acceptance@acceptance/v0.4.2
with:
vault_uri: ${{ secrets.VAULT_URI }}
env:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/downstreams.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
run: |
pip install hatch==1.9.4
- name: Downstreams
uses: databrickslabs/sandbox/downstreams@acceptance/v0.3.1
uses: databrickslabs/sandbox/downstreams@acceptance/v0.4.2
with:
repo: ${{ matrix.downstream.name }}
org: databrickslabs
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
run: pip install hatch==1.9.4

- name: Run nightly tests
uses: databrickslabs/sandbox/acceptance@acceptance/v0.3.1
uses: databrickslabs/sandbox/acceptance@acceptance/v0.4.2
with:
vault_uri: ${{ secrets.VAULT_URI }}
create_issues: true
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
run: hatch run test

- name: Publish test coverage
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

Expand Down
114 changes: 92 additions & 22 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion scripts/gen-readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def discover_fixtures() -> list[Fixture]:
upstreams = []
sig = inspect.signature(fn)
for param in sig.parameters.values():
if param.name in {'fresh_local_wheel_file', 'monkeypatch', 'log_workspace_link'}:
if param.name in {'fresh_local_wheel_file', 'monkeypatch', 'log_workspace_link', 'request'}:
continue
upstreams.append(param.name)
see_also[param.name].add(fixture)
Expand Down
12 changes: 12 additions & 0 deletions src/databricks/labs/pytester/fixtures/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,15 @@ def inner(name: str, path: str, *, anchor: bool = True):
_LOG.info(f'Created {name}: {url}')

return inner


@fixture
def log_account_link(acc):
    """Returns a function to log an account link."""

    def inner(name: str, path: str, *, anchor: bool = False):
        # Accounts console anchors live behind a '#' fragment; plain paths do not.
        prefix = '#' if anchor else ''
        link = f'https://{acc.config.hostname}/{prefix}{path}'
        _LOG.info(f'Created {name}: {link}')

    return inner
1 change: 1 addition & 0 deletions src/databricks/labs/pytester/fixtures/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
logger = logging.getLogger(__name__)


# TODO: replace with LSQL implementation
def escape_sql_identifier(path: str, *, maxsplit: int = 2) -> str:
"""
Escapes the path components to make them SQL safe.
Expand Down
195 changes: 192 additions & 3 deletions src/databricks/labs/pytester/fixtures/iam.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import logging
import warnings
from collections.abc import Callable, Generator
from collections.abc import Callable, Generator, Iterable
from datetime import timedelta

import pytest
from pytest import fixture
from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient
from databricks.sdk.credentials_provider import OAuthCredentialsProvider, OauthCredentialsStrategy
from databricks.sdk.oauth import ClientCredentials, Token
from databricks.sdk.service.oauth2 import CreateServicePrincipalSecretResponse
from databricks.labs.lsql import Row
from databricks.labs.lsql.backends import StatementExecutionBackend, SqlBackend
from databricks.sdk import AccountGroupsAPI, GroupsAPI, WorkspaceClient, AccountClient
from databricks.sdk.config import Config
from databricks.sdk.errors import ResourceConflict, NotFound
from databricks.sdk.retries import retried
from databricks.sdk.service import iam
from databricks.sdk.service.iam import User, Group
from databricks.sdk.service.iam import (
User,
Group,
ServicePrincipal,
Patch,
PatchOp,
ComplexValue,
PatchSchema,
WorkspacePermission,
)

from databricks.labs.pytester.fixtures.baseline import factory

Expand Down Expand Up @@ -183,3 +198,177 @@ def create(
return group

yield from factory(name, create, lambda item: interface.delete(item.id))


class RunAs:
    """Context for acting as an ephemeral service principal.

    Created by the `make_run_as` fixture. Wraps the service principal's
    `WorkspaceClient` and exposes SQL helpers backed by a statement-execution
    warehouse (resolved lazily from `DATABRICKS_WAREHOUSE_ID`).
    """

    def __init__(self, service_principal: ServicePrincipal, workspace_client: WorkspaceClient, env_or_skip):
        self._service_principal = service_principal
        self._workspace_client = workspace_client
        self._env_or_skip = env_or_skip

    @property
    def ws(self) -> WorkspaceClient:
        """Workspace client authenticated as the ephemeral service principal."""
        return self._workspace_client

    @property
    def sql_backend(self) -> SqlBackend:
        # TODO: Switch to `__getattr__` + `SubRequest` to get a generic way of initializing all workspace fixtures.
        # This will allow us to remove the `sql_backend` fixture and make the `RunAs` class more generic.
        # It turns out to be more complicated than it first appears, because we don't get these at pytest.collect phase.
        warehouse_id = self._env_or_skip("DATABRICKS_WAREHOUSE_ID")
        return StatementExecutionBackend(self._workspace_client, warehouse_id)

    def sql_exec(self, statement: str) -> None:
        """Execute a SQL statement on behalf of the ephemeral service principal."""
        return self.sql_backend.execute(statement)

    def sql_fetch_all(self, statement: str) -> Iterable[Row]:
        """Fetch all rows produced by a SQL statement on behalf of the ephemeral service principal."""
        return self.sql_backend.fetch(statement)

    def __getattr__(self, item: str):
        # `__getattr__` is only invoked after normal attribute lookup fails, so
        # `item` can never be found in `self.__dict__` here. The previous
        # implementation read `self._request`, which is never assigned, so the
        # failed `_request` lookup re-entered `__getattr__` and recursed forever.
        # Guard through `__dict__` (which does NOT trigger `__getattr__`) and
        # raise the conventional AttributeError when no pytest request is wired in.
        request = self.__dict__.get('_request')
        if request is None:
            raise AttributeError(item)
        # Delegate unknown attributes to pytest fixture resolution (see TODO above).
        return request.getfixturevalue(item)

    @property
    def display_name(self) -> str:
        """Display name of the ephemeral service principal."""
        assert self._service_principal.display_name is not None
        return self._service_principal.display_name

    @property
    def application_id(self) -> str:
        """Application (client) ID of the ephemeral service principal."""
        assert self._service_principal.application_id is not None
        return self._service_principal.application_id

    def __repr__(self):
        return f'RunAs({self.display_name})'


def _make_workspace_client(
    ws: WorkspaceClient,
    created_secret: CreateServicePrincipalSecretResponse,
    service_principal: ServicePrincipal,
) -> WorkspaceClient:
    """Build a `WorkspaceClient` that authenticates as the given service principal.

    Exchanges the service principal's OAuth client credentials at the
    workspace's OIDC token endpoint and wires the resulting credentials
    strategy into a fresh client pointed at the same workspace host.
    """
    oidc = ws.config.oidc_endpoints
    assert oidc is not None, 'OIDC is required'
    client_id = service_principal.application_id
    client_secret = created_secret.secret
    assert client_id is not None
    assert client_secret is not None

    credentials = ClientCredentials(
        client_id=client_id,
        client_secret=client_secret,
        token_url=oidc.token_endpoint,
        scopes=["all-apis"],
        use_header=True,
    )

    def auth_headers() -> dict[str, str]:
        fresh = credentials.token()
        return {'Authorization': f'{fresh.token_type} {fresh.access_token}'}

    def current_token() -> Token:
        return credentials.token()

    provider = OAuthCredentialsProvider(auth_headers, current_token)
    strategy = OauthCredentialsStrategy('oauth-m2m', lambda _: provider)
    return WorkspaceClient(host=ws.config.host, credentials_strategy=strategy)


@fixture
def make_run_as(acc: AccountClient, ws: WorkspaceClient, make_random, env_or_skip, log_account_link, is_in_debug):
    """
    This fixture provides a function to create an account service principal via [`acc` fixture](#acc-fixture) and
    assign it to a workspace. The service principal is removed after the test is complete. The service principal is
    created with a random display name and assigned to the workspace with the default permissions.
    Use the `account_groups` argument to assign the service principal to account groups, which have the required
    permissions to perform a specific action.
    Example:
    ```python
    def test_run_as_lower_privilege_user(make_run_as, ws):
        run_as = make_run_as(account_groups=['account.group.name'])
        through_query = next(run_as.sql_fetch_all("SELECT CURRENT_USER() AS my_name"))
        me = ws.current_user.me()
        assert me.user_name != through_query.my_name
    ```
    Returned object has the following properties:
    * `ws`: Workspace client that is authenticated as the ephemeral service principal.
    * `sql_backend`: SQL backend that is authenticated as the ephemeral service principal.
    * `sql_exec`: Function to execute a SQL statement on behalf of the ephemeral service principal.
    * `sql_fetch_all`: Function to fetch all rows from a SQL statement on behalf of the ephemeral service principal.
    * `display_name`: Display name of the ephemeral service principal.
    * `application_id`: Application ID of the ephemeral service principal.
    * if you want to have other fixtures available in the context of the ephemeral service principal, you can override
    the [`ws` fixture](#ws-fixture) on the file level, which would make all workspace fixtures provided by this
    plugin to run as lower privilege ephemeral service principal. You cannot combine it with the account-admin-level
    principal you're using to create the ephemeral principal.
    Example:
    ```python
    from pytest import fixture
    @fixture
    def ws(make_run_as):
        run_as = make_run_as(account_groups=['account.group.used.for.all.tests.in.this.file'])
        return run_as.ws
    def test_creating_notebook_on_behalf_of_ephemeral_principal(make_notebook):
        notebook = make_notebook()
        assert notebook.exists()
    ```
    This fixture currently doesn't work with Databricks Metadata Service authentication on Azure Databricks.
    """

    if ws.config.auth_type == 'metadata-service' and ws.config.is_azure:
        # TODO: fix `invalid_scope: AADSTS1002012: The provided value for scope all-apis is not valid.` error
        #
        # We're having issues with the Azure Metadata Service and service principals. The error message is:
        # Client credential flows must have a scope value with /.default suffixed to the resource identifier
        # (application ID URI)
        pytest.skip('Azure Metadata Service does not support service principals')

    def _resolve_group_ids(account_groups: list[str]) -> list[str]:
        # Map account-group display names to group IDs, failing fast on unknown names.
        group_mapping = {}
        for group in acc.groups.list(attributes='id,displayName'):
            if group.id is None:
                continue
            group_mapping[group.display_name] = group.id
        group_ids = []
        for group_name in account_groups:
            if group_name not in group_mapping:
                raise ValueError(f'Group {group_name} does not exist')
            group_ids.append(group_mapping[group_name])
        return group_ids

    def create(*, account_groups: list[str] | None = None):
        workspace_id = ws.get_workspace_id()
        # Validate group names BEFORE creating the service principal: the previous
        # order raised ValueError after creation, leaking an orphaned principal
        # (the `remove` finalizer only runs for successfully created objects).
        group_ids = _resolve_group_ids(account_groups) if account_groups else []
        service_principal = acc.service_principals.create(display_name=f'spn-{make_random()}')
        assert service_principal.id is not None
        service_principal_id = int(service_principal.id)
        # OAuth secret used by `_make_workspace_client` for M2M authentication.
        created_secret = acc.service_principal_secrets.create(service_principal_id)
        for group_id in group_ids:
            acc.groups.patch(
                group_id,
                operations=[
                    Patch(PatchOp.ADD, 'members', [ComplexValue(value=str(service_principal_id)).as_dict()]),
                ],
                schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP],
            )
        # Assign the principal to the current workspace with default USER permission.
        permissions = [WorkspacePermission.USER]
        acc.workspace_assignment.update(workspace_id, service_principal_id, permissions=permissions)
        ws_as_spn = _make_workspace_client(ws, created_secret, service_principal)

        log_account_link('account service principal', f'users/serviceprincipals/{service_principal_id}')

        return RunAs(service_principal, ws_as_spn, env_or_skip)

    def remove(run_as: RunAs):
        service_principal_id = run_as._service_principal.id  # pylint: disable=protected-access
        assert service_principal_id is not None
        acc.service_principals.delete(service_principal_id)

    yield from factory("service principal", create, remove)
14 changes: 11 additions & 3 deletions src/databricks/labs/pytester/fixtures/notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,18 @@ def create(
default_content = "SELECT 1"
else:
raise ValueError(f"Unsupported language: {language}")
path = path or f"/Users/{ws.current_user.me().user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}"
current_user = ws.current_user.me()
path = path or f"/Users/{current_user.user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}"
workspace_path = WorkspacePath(ws, path)
if '@' not in current_user.user_name:
# If current user is a service principal added with `make_run_as`, there might be no home folder
workspace_path.parent.mkdir(exist_ok=True)
content = content or default_content
if isinstance(content, str):
content = io.BytesIO(content.encode(encoding))
if isinstance(ws, Mock): # For testing
ws.workspace.download.return_value = content if isinstance(content, io.BytesIO) else io.BytesIO(content)
ws.workspace.upload(path, content, language=language, format=format, overwrite=overwrite)
workspace_path = WorkspacePath(ws, path)
logger.info(f"Created notebook: {workspace_path.as_uri()}")
return workspace_path

Expand Down Expand Up @@ -110,10 +114,14 @@ def create(
suffix = ".sql"
else:
raise ValueError(f"Unsupported language: {language}")
path = path or f"/Users/{ws.current_user.me().user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}{suffix}"
current_user = ws.current_user.me()
path = path or f"/Users/{current_user.user_name}/dummy-{make_random(4)}-{watchdog_purge_suffix}{suffix}"
content = content or default_content
encoding = encoding or _DEFAULT_ENCODING
workspace_path = WorkspacePath(ws, path)
if '@' not in current_user.user_name:
# If current user is a service principal added with `make_run_as`, there might be no home folder
workspace_path.parent.mkdir(exist_ok=True)
if isinstance(content, bytes):
workspace_path.write_bytes(content)
else:
Expand Down
5 changes: 4 additions & 1 deletion src/databricks/labs/pytester/fixtures/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
make_random,
product_info,
log_workspace_link,
log_account_link,
)
from databricks.labs.pytester.fixtures.sql import sql_backend, sql_exec, sql_fetch_all
from databricks.labs.pytester.fixtures.compute import (
Expand All @@ -16,7 +17,7 @@
make_pipeline,
make_warehouse,
)
from databricks.labs.pytester.fixtures.iam import make_group, make_acc_group, make_user
from databricks.labs.pytester.fixtures.iam import make_group, make_acc_group, make_user, make_run_as
from databricks.labs.pytester.fixtures.catalog import (
make_udf,
make_catalog,
Expand Down Expand Up @@ -60,6 +61,7 @@
'debug_env',
'env_or_skip',
'ws',
'make_run_as',
'acc',
'spark',
'sql_backend',
Expand Down Expand Up @@ -105,6 +107,7 @@
'make_warehouse_permissions',
'make_lakeview_dashboard_permissions',
'log_workspace_link',
'log_account_link',
'make_dashboard_permissions',
'make_alert_permissions',
'make_query',
Expand Down
1 change: 1 addition & 0 deletions tests/integration/fixtures/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_schema_fixture(make_schema):
logger.info(f"Created new schema: {make_schema()}")


@pytest.mark.skip("Invalid configuration value detected for fs.azure.account.key")
def test_managed_schema_fixture(make_schema, make_random, env_or_skip):
schema_name = f"dummy_s{make_random(4)}".lower()
schema_location = f"{env_or_skip('TEST_MOUNT_CONTAINER')}/a/{schema_name}"
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/fixtures/test_iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,10 @@ def test_new_account_group(make_acc_group, acc):
group = make_acc_group()
loaded = acc.groups.get(group.id)
assert group.display_name == loaded.display_name


def test_run_as_lower_privilege_user(make_run_as, ws):
    """The impersonated query must run as the ephemeral service principal, not as the test user."""
    lower_privilege = make_run_as(account_groups=['role.labs.lsql.write'])
    row = next(lower_privilege.sql_fetch_all("SELECT CURRENT_USER() AS my_name"))
    assert ws.current_user.me().user_name != row.my_name
12 changes: 12 additions & 0 deletions tests/integration/fixtures/test_run_as.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from pytest import fixture


@fixture
def ws(make_run_as):
    # File-level override: every workspace fixture in this module now runs as
    # the ephemeral service principal instead of the default test identity.
    return make_run_as(account_groups=['role.labs.lsql.write']).ws


def test_creating_notebook_on_behalf_of_ephemeral_principal(make_notebook):
    """Notebook creation must succeed through the overridden, lower-privilege `ws` fixture."""
    created = make_notebook()
    assert created.exists()
Loading

0 comments on commit 53b58d3

Please sign in to comment.