Bump unit testing coverage
nfx committed Sep 13, 2024
1 parent 8102ee3 commit d22a880
Showing 29 changed files with 543 additions and 345 deletions.
114 changes: 61 additions & 53 deletions README.md

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -157,6 +157,11 @@ exclude_lines = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
exclude_also = [
"import",
"__all__",
"@pytest.fixture"
]

[tool.pylint.main]
# PyLint configuration is adapted from Google Python Style Guide with modifications.
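For context on the new `exclude_also` block: coverage.py treats each entry as a regular expression and drops matching source lines from the report, so plain imports, `__all__` declarations, and `@pytest.fixture`-decorated definitions should no longer count against coverage. A minimal, hypothetical module to illustrate which lines these patterns would hit (the module contents are invented):

```python
import logging            # contains "import" -> line excluded from coverage
import pytest

__all__ = ["make_dummy"]  # matches "__all__" -> line excluded


@pytest.fixture           # matches "@pytest.fixture"; per coverage.py's documented behavior,
def make_dummy():         # matching a decorator line should exclude the whole decorated function
    return "dummy"
```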
58 changes: 0 additions & 58 deletions src/databricks/labs/pytester/environment.py

This file was deleted.

45 changes: 0 additions & 45 deletions src/databricks/labs/pytester/fixtures/baseline.py
@@ -1,22 +1,15 @@
import logging
import random
import string
from datetime import timedelta, datetime, timezone
from functools import partial

from pytest import fixture

from databricks.labs.lsql.backends import StatementExecutionBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import DatabricksError

_LOG = logging.getLogger(__name__)


"""Preserve resources created during tests for at least this long."""
TEST_RESOURCE_PURGE_TIMEOUT = timedelta(hours=1)


@fixture
def make_random():
"""
@@ -175,41 +168,3 @@ def inner(name: str, path: str, *, anchor: bool = True):
_LOG.info(f'Created {name}: {url}')

return inner


@fixture
def sql_backend(ws, env_or_skip) -> StatementExecutionBackend:
"""Create and provide a SQL backend for executing statements.
Requires the environment variable `DATABRICKS_WAREHOUSE_ID` to be set.
"""
warehouse_id = env_or_skip("DATABRICKS_WAREHOUSE_ID")
return StatementExecutionBackend(ws, warehouse_id)


@fixture
def sql_exec(sql_backend):
"""Execute SQL statement and don't return any results."""
return partial(sql_backend.execute)


@fixture
def sql_fetch_all(sql_backend):
"""Fetch all rows from a SQL statement."""
return partial(sql_backend.fetch)


def get_test_purge_time(timeout: timedelta = TEST_RESOURCE_PURGE_TIMEOUT) -> str:
"""Purge time for test objects, representing the (UTC-based) hour from which objects may be purged."""
# Note: this code is duplicated in the workflow installer (WorkflowsDeployment) so that it can avoid the
# transitive pytest deployment from this module.
now = datetime.now(timezone.utc)
purge_deadline = now + timeout
# Round UP to the next hour boundary: that is when resources will be deleted.
purge_hour = purge_deadline + (datetime.min.replace(tzinfo=timezone.utc) - purge_deadline) % timedelta(hours=1)
return purge_hour.strftime("%Y%m%d%H")


def get_purge_suffix() -> str:
"""HEX-encoded purge time suffix for test objects."""
return f'ra{int(get_test_purge_time()):x}'
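
The deleted `get_test_purge_time` helper is superseded in this commit by the `watchdog_remove_after` fixture used in `catalog.py` below. Since the modulo expression is easy to misread, here is a small, self-contained sketch of the same round-up-to-the-next-hour arithmetic with a fixed timestamp (the example date is arbitrary):

```python
from datetime import datetime, timedelta, timezone

# (datetime.min - deadline) % 1h is the gap from the deadline UP to the next hour boundary,
# because datetime.min itself sits exactly on an hour boundary.
deadline = datetime(2024, 9, 13, 14, 20, tzinfo=timezone.utc)
gap = (datetime.min.replace(tzinfo=timezone.utc) - deadline) % timedelta(hours=1)
purge_hour = deadline + gap

assert gap == timedelta(minutes=40)
assert purge_hour.strftime("%Y%m%d%H") == "2024091315"
# get_purge_suffix() then hex-encodes that hour: f'ra{2024091315:x}' == 'ra78a52eb3'
```
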
81 changes: 31 additions & 50 deletions src/databricks/labs/pytester/fixtures/catalog.py
@@ -2,21 +2,20 @@
from collections.abc import Generator, Callable
from pytest import fixture
from databricks.labs.blueprint.commands import CommandExecutor
from databricks.sdk.errors import NotFound
from databricks.sdk.errors import DatabricksError
from databricks.sdk.service.catalog import (
FunctionInfo,
SchemaInfo,
TableInfo,
TableType,
DataSourceFormat,
CatalogInfo,
ColumnInfo,
StorageCredentialInfo,
AwsIamRoleRequest,
AzureServicePrincipal,
)
from databricks.sdk.service.compute import Language
from databricks.labs.pytester.fixtures.baseline import factory, get_test_purge_time
from databricks.labs.pytester.fixtures.baseline import factory

logger = logging.getLogger(__name__)

@@ -46,6 +45,7 @@ def make_table(
make_schema,
make_random,
log_workspace_link,
watchdog_remove_after,
) -> Generator[Callable[..., TableInfo], None, None]:
"""
Create a table and return its info. Remove it after the test. Returns instance of `databricks.sdk.service.catalog.TableInfo`.
@@ -75,33 +75,23 @@ def test_catalog_fixture(make_catalog, make_schema, make_table):
```
"""

def generate_sql_schema(columns: list[ColumnInfo]) -> str:
def generate_sql_schema(columns: list[tuple[str, str]]) -> str:
"""Generate a SQL schema from columns."""
schema = "("
for index, column in enumerate(columns):
schema += escape_sql_identifier(column.name or str(index), maxsplit=0)
if column.type_name is None:
type_name = "STRING"
else:
type_name = column.type_name.value
for index, (col_name, type_name) in enumerate(columns):
schema += escape_sql_identifier(col_name or str(index), maxsplit=0)
schema += f" {type_name}, "
schema = schema[:-2] + ")" # Remove the last ', '
return schema

def generate_sql_column_casting(existing_columns: list[ColumnInfo], new_columns: list[ColumnInfo]) -> str:
def generate_sql_column_casting(existing_columns: list[tuple[str, str]], new_columns: list[tuple[str, str]]) -> str:
"""Generate the SQL to cast columns"""
if any(column.name is None for column in existing_columns):
raise ValueError(f"Columns should have a name: {existing_columns}")
if len(new_columns) > len(existing_columns):
raise ValueError(f"Too many columns: {new_columns}")
select_expressions = []
for index, (existing_column, new_column) in enumerate(zip(existing_columns, new_columns)):
column_name_new = escape_sql_identifier(new_column.name or str(index), maxsplit=0)
if new_column.type_name is None:
type_name = "STRING"
else:
type_name = new_column.type_name.value
select_expression = f"CAST({existing_column.name} AS {type_name}) AS {column_name_new}"
for index, ((existing_name, _), (new_name, new_type)) in enumerate(zip(existing_columns, new_columns)):
column_name_new = escape_sql_identifier(new_name or str(index), maxsplit=0)
select_expression = f"CAST({existing_name} AS {new_type}) AS {column_name_new}"
select_expressions.append(select_expression)
select = ", ".join(select_expressions)
return select
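
With plain `(name, type)` tuples instead of `ColumnInfo`, the two helpers above become straightforward string building. A self-contained sketch of the schema generation, using a stand-in escaper since `escape_sql_identifier` comes from a library dependency not shown in this diff (the backtick quoting is an assumption):

```python
def _escape(name: str) -> str:
    return f"`{name}`"  # assumption: identifiers are backtick-quoted


def sketch_sql_schema(columns: list[tuple[str, str]]) -> str:
    # Mirrors generate_sql_schema: "<escaped name> <type>, " per column, trailing ", " trimmed.
    schema = "("
    for index, (col_name, type_name) in enumerate(columns):
        schema += f"{_escape(col_name or str(index))} {type_name}, "
    return schema[:-2] + ")"


assert sketch_sql_schema([("id", "INT"), ("value", "STRING")]) == "(`id` INT, `value` STRING)"
```
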
@@ -120,7 +110,7 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
tbl_properties: dict[str, str] | None = None,
hiveserde_ddl: str | None = None,
storage_override: str | None = None,
columns: list[ColumnInfo] | None = None,
columns: list[tuple[str, str]] | None = None,
) -> TableInfo:
if schema_name is None:
schema = make_schema(catalog_name=catalog_name)
@@ -154,14 +144,14 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
else:
# These are the columns from the JSON dataset below
dataset_columns = [
ColumnInfo(name="calories_burnt"),
ColumnInfo(name="device_id"),
ColumnInfo(name="id"),
ColumnInfo(name="miles_walked"),
ColumnInfo(name="num_steps"),
ColumnInfo(name="timestamp"),
ColumnInfo(name="user_id"),
ColumnInfo(name="value"),
('calories_burnt', 'STRING'),
('device_id', 'STRING'),
('id', 'STRING'),
('miles_walked', 'STRING'),
('num_steps', 'STRING'),
('timestamp', 'STRING'),
('user_id', 'STRING'),
('value', 'STRING'),
]
select = generate_sql_column_casting(dataset_columns, columns)
# Modified, otherwise it will identify the table as a DB Dataset
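
A test that passes the new tuple-style `columns` argument (for example `make_table(columns=[("first", "STRING"), ("second", "INT")])`, a hypothetical call) has those pairs zipped against the fixed JSON-dataset columns above. A self-contained sketch of the resulting CAST projection (escaping simplified to backticks):

```python
dataset_columns = [("calories_burnt", "STRING"), ("device_id", "STRING")]  # first two fixed columns
requested_columns = [("first", "STRING"), ("second", "INT")]               # invented for illustration

select = ", ".join(
    f"CAST({existing} AS {new_type}) AS `{new_name}`"
    for (existing, _), (new_name, new_type) in zip(dataset_columns, requested_columns)
)
assert select == "CAST(calories_burnt AS STRING) AS `first`, CAST(device_id AS INT) AS `second`"
```
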
@@ -193,9 +183,9 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
storage_location = f"dbfs:/user/hive/warehouse/{schema_name}/{name}"
ddl = f"{ddl} {schema}"
if tbl_properties:
tbl_properties.update({"RemoveAfter": get_test_purge_time()})
tbl_properties.update({"RemoveAfter": watchdog_remove_after})
else:
tbl_properties = {"RemoveAfter": get_test_purge_time()}
tbl_properties = {"RemoveAfter": watchdog_remove_after}

str_properties = ",".join([f" '{k}' = '{v}' " for k, v in tbl_properties.items()])
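
The `RemoveAfter` property is what the watchdog uses to spot stale test tables; `str_properties` folds it, together with any caller-supplied properties, into the table DDL. A sketch of the resulting fragment, with an invented timestamp for `watchdog_remove_after` and an invented extra property (the surrounding `TBLPROPERTIES (...)` wrapper is an assumption, since that part of the fixture is not shown in this hunk):

```python
watchdog_remove_after = "2024091315"           # invented value for illustration
tbl_properties = {"delta.appendOnly": "true"}  # hypothetical caller-supplied property
tbl_properties.update({"RemoveAfter": watchdog_remove_after})

str_properties = ",".join([f" '{k}' = '{v}' " for k, v in tbl_properties.items()])
assert str_properties == " 'delta.appendOnly' = 'true' , 'RemoveAfter' = '2024091315' "
# presumably emitted as part of the DDL, e.g. f"TBLPROPERTIES ({str_properties})"
```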

@@ -238,19 +228,22 @@ def create( # pylint: disable=too-many-locals,too-many-arguments,too-many-state
def remove(table_info: TableInfo):
try:
sql_backend.execute(f"DROP TABLE IF EXISTS {table_info.full_name}")
except RuntimeError as e:
except DatabricksError as e:
if "Cannot drop a view" in str(e):
sql_backend.execute(f"DROP VIEW IF EXISTS {table_info.full_name}")
elif "SCHEMA_NOT_FOUND" in str(e):
logger.warning("Schema was already dropped while executing the test", exc_info=e)
else:
raise e

yield from factory("table", create, remove)


@fixture
def make_schema(sql_backend, make_random, log_workspace_link) -> Generator[Callable[..., SchemaInfo], None, None]:
def make_schema(
sql_backend,
make_random,
log_workspace_link,
watchdog_remove_after,
) -> Generator[Callable[..., SchemaInfo], None, None]:
"""
Create a schema and return its info. Remove it after the test. Returns instance of `databricks.sdk.service.catalog.SchemaInfo`.
@@ -272,20 +265,14 @@ def create(*, catalog_name: str = "hive_metastore", name: str | None = None) ->
if name is None:
name = f"dummy_S{make_random(4)}".lower()
full_name = f"{catalog_name}.{name}".lower()
sql_backend.execute(f"CREATE SCHEMA {full_name} WITH DBPROPERTIES (RemoveAfter={get_test_purge_time()})")
sql_backend.execute(f"CREATE SCHEMA {full_name} WITH DBPROPERTIES (RemoveAfter={watchdog_remove_after})")
schema_info = SchemaInfo(catalog_name=catalog_name, name=name, full_name=full_name)
path = f'explore/data/{schema_info.catalog_name}/{schema_info.name}'
log_workspace_link(f'{schema_info.full_name} schema', path)
return schema_info

def remove(schema_info: SchemaInfo):
try:
sql_backend.execute(f"DROP SCHEMA IF EXISTS {schema_info.full_name} CASCADE")
except RuntimeError as e:
if "SCHEMA_NOT_FOUND" in str(e):
logger.warning("Schema was already dropped while executing the test", exc_info=e)
else:
raise e
sql_backend.execute(f"DROP SCHEMA IF EXISTS {schema_info.full_name} CASCADE")

yield from factory("schema", create, remove)
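
As the `test_catalog_fixture` example in the `make_table` docstring above suggests, these fixtures are meant to be requested directly in tests. A hypothetical test that only needs a scratch schema (the assertion is illustrative, not part of the plugin's documented contract):

```python
def test_creates_scratch_schema(make_schema):
    schema = make_schema(catalog_name="hive_metastore")
    # SchemaInfo carries the generated dummy_S... name plus the lower-cased catalog.schema full name.
    assert schema.full_name.startswith("hive_metastore.dummy_s")
```
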

@@ -390,13 +377,7 @@ def create(
return udf_info

def remove(udf_info: FunctionInfo):
try:
sql_backend.execute(f"DROP FUNCTION IF EXISTS {udf_info.full_name}")
except NotFound as e:
if "SCHEMA_NOT_FOUND" in str(e):
logger.warning("Schema was already dropped while executing the test", exc_info=e)
else:
raise e
sql_backend.execute(f"DROP FUNCTION IF EXISTS {udf_info.full_name}")

yield from factory("table", create, remove)
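
Every fixture here ends with `yield from factory(<name>, create, remove)`. The body of `factory` is not part of this diff, so the following is only a plausible shape, under the assumption that it hands the `create` callable to the test and tears down whatever was created afterwards; the real implementation may differ:

```python
import logging
from collections.abc import Callable, Generator

logger = logging.getLogger(__name__)


def factory(name: str, create: Callable, remove: Callable) -> Generator[Callable, None, None]:
    """Plausible sketch only: track created objects and remove them at teardown."""
    created = []

    def inner(**kwargs):
        obj = create(**kwargs)
        created.append(obj)
        return obj

    yield inner
    for obj in reversed(created):
        try:
            remove(obj)
        except Exception as err:  # keep tearing down remaining objects even if one removal fails
            logger.warning(f"ignoring failure to remove {name}: {err}")
```
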

Diffs for the remaining changed files are not rendered here.
