diff --git a/docs/api_ibis/fugue_ibis.rst b/docs/api_ibis/fugue_ibis.rst
index 91d68a5c..a3fb6842 100644
--- a/docs/api_ibis/fugue_ibis.rst
+++ b/docs/api_ibis/fugue_ibis.rst
@@ -32,6 +32,22 @@ fugue\_ibis
 .. |FugueDataTypes| replace:: :doc:`Fugue Data Types `
 
+fugue\_ibis.dataframe
+---------------------
+
+.. automodule:: fugue_ibis.dataframe
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+fugue\_ibis.execution\_engine
+-----------------------------
+
+.. automodule:: fugue_ibis.execution_engine
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 fugue\_ibis.extensions
 ----------------------
 
diff --git a/docs/api_ray/fugue_ray.rst b/docs/api_ray/fugue_ray.rst
index ff15f483..a09add01 100644
--- a/docs/api_ray/fugue_ray.rst
+++ b/docs/api_ray/fugue_ray.rst
@@ -27,10 +27,10 @@ fugue\_ray
 .. |FugueDataTypes| replace:: :doc:`Fugue Data Types `
 
-fugue\_ray.dateframe
+fugue\_ray.dataframe
 --------------------
 
-.. automodule:: fugue_ray.dateframe
+.. automodule:: fugue_ray.dataframe
    :members:
    :undoc-members:
    :show-inheritance:
diff --git a/fugue/_utils/registry.py b/fugue/_utils/registry.py
index e5224ed4..93f3462a 100644
--- a/fugue/_utils/registry.py
+++ b/fugue/_utils/registry.py
@@ -1,8 +1,9 @@
 from typing import Callable
 
 from triad import conditional_dispatcher
+from triad.utils.dispatcher import ConditionalDispatcher
 
 _FUGUE_ENTRYPOINT = "fugue.plugins"
 
 
-def fugue_plugin(func: Callable) -> Callable:
-    return conditional_dispatcher(entry_point=_FUGUE_ENTRYPOINT)(func)
+def fugue_plugin(func: Callable) -> ConditionalDispatcher:
+    return conditional_dispatcher(entry_point=_FUGUE_ENTRYPOINT)(func)  # type: ignore
diff --git a/fugue/dataframe/__init__.py b/fugue/dataframe/__init__.py
index 47e0f826..d8eadb34 100644
--- a/fugue/dataframe/__init__.py
+++ b/fugue/dataframe/__init__.py
@@ -1,14 +1,20 @@
 # flake8: noqa
-from fugue.dataframe.array_dataframe import ArrayDataFrame
-from fugue.dataframe.arrow_dataframe import ArrowDataFrame
-from fugue.dataframe.dataframe import (
+from .array_dataframe import ArrayDataFrame
+from .arrow_dataframe import ArrowDataFrame
+from .dataframe import (
     DataFrame,
     LocalBoundedDataFrame,
     LocalDataFrame,
     YieldedDataFrame,
 )
-from fugue.dataframe.dataframe_iterable_dataframe import LocalDataFrameIterableDataFrame
-from fugue.dataframe.dataframes import DataFrames
-from fugue.dataframe.iterable_dataframe import IterableDataFrame
-from fugue.dataframe.pandas_dataframe import PandasDataFrame
-from fugue.dataframe.utils import to_local_bounded_df, to_local_df
+from .dataframe_iterable_dataframe import LocalDataFrameIterableDataFrame
+from .dataframes import DataFrames
+from .iterable_dataframe import IterableDataFrame
+from .pandas_dataframe import PandasDataFrame
+from .utils import (
+    get_dataframe_column_names,
+    normalize_dataframe_column_names,
+    rename_dataframe_column_names,
+    to_local_bounded_df,
+    to_local_df,
+)
diff --git a/fugue/dataframe/utils.py b/fugue/dataframe/utils.py
index 9204cc55..9165a22d 100644
--- a/fugue/dataframe/utils.py
+++ b/fugue/dataframe/utils.py
@@ -2,7 +2,7 @@ import json
 import os
 import pickle
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import pandas as pd
 import pyarrow as pa
@@ -17,6 +17,137 @@ from triad.exceptions import InvalidOperationError
 from triad.utils.assertion import assert_arg_not_none
 from triad.utils.assertion import assert_or_throw as aot
+from triad.utils.rename import normalize_names
+
+from .._utils.registry import fugue_plugin
+
+
+@fugue_plugin
+def get_dataframe_column_names(df: Any) -> List[Any]:  # pragma: no cover
+    """A generic function to get column names of any dataframe
+
+    :param df: the dataframe object
+    :return: the column names
+
+    .. note::
+
+        In order to support a new type of dataframe, an implementation must
+        be registered, for example
+
+        .. code-block:: python
+
+            @get_dataframe_column_names.candidate(lambda df: isinstance(df, pa.Table))
+            def _get_pyarrow_dataframe_columns(df: pa.Table) -> List[Any]:
+                return [f.name for f in df.schema]
+    """
+    raise NotImplementedError(f"{type(df)} is not supported")
+
+
+@fugue_plugin
+def rename_dataframe_column_names(df: Any, names: Dict[str, Any]) -> Any:
+    """A generic function to rename column names of any dataframe
+
+    :param df: the dataframe object
+    :param names: the rename operations as a dict: ``old name => new name``
+    :return: the renamed dataframe
+
+    .. note::
+
+        In order to support a new type of dataframe, an implementation must
+        be registered, for example
+
+        .. code-block:: python
+
+            @rename_dataframe_column_names.candidate(
+                lambda df, *args, **kwargs: isinstance(df, pd.DataFrame)
+            )
+            def _rename_pandas_dataframe(
+                df: pd.DataFrame, names: Dict[str, Any]
+            ) -> pd.DataFrame:
+                if len(names) == 0:
+                    return df
+                return df.rename(columns=names)
+    """
+    if len(names) == 0:
+        return df
+    else:  # pragma: no cover
+        raise NotImplementedError(f"{type(df)} is not supported")
+
+
+def normalize_dataframe_column_names(df: Any) -> Tuple[Any, Dict[str, Any]]:
+    """A generic function to normalize any dataframe's column names to follow
+    Fugue naming rules
+
+    .. note::
+
+        This is a temporary solution before
+        :class:`~triad:triad.collections.schema.Schema`
+        can take arbitrary names
+
+    .. admonition:: Examples
+
+        * ``[0,1]`` => ``{"_0":0, "_1":1}``
+        * ``["1a","2b"]`` => ``{"_1a":"1a", "_2b":"2b"}``
+        * ``["*a","-a"]`` => ``{"_a":"*a", "_a_1":"-a"}``
+
+    :param df: a dataframe object
+    :return: the renamed dataframe and the rename operations as a dict that
+        can **undo** the change
+
+    .. seealso::
+
+        * :func:`~.get_dataframe_column_names`
+        * :func:`~.rename_dataframe_column_names`
+        * :func:`~triad:triad.utils.rename.normalize_names`
+    """
+    cols = get_dataframe_column_names(df)
+    names = normalize_names(cols)
+    if len(names) == 0:
+        return df, {}
+    undo = {v: k for k, v in names.items()}
+    return (rename_dataframe_column_names(df, names), undo)
+
+
+@get_dataframe_column_names.candidate(lambda df: isinstance(df, pd.DataFrame))
+def _get_pandas_dataframe_columns(df: pd.DataFrame) -> List[Any]:
+    return list(df.columns)
+
+
+@rename_dataframe_column_names.candidate(
+    lambda df, *args, **kwargs: isinstance(df, pd.DataFrame)
+)
+def _rename_pandas_dataframe(df: pd.DataFrame, names: Dict[str, Any]) -> pd.DataFrame:
+    if len(names) == 0:
+        return df
+    return df.rename(columns=names)
+
+
+@get_dataframe_column_names.candidate(lambda df: isinstance(df, pa.Table))
+def _get_pyarrow_dataframe_columns(df: pa.Table) -> List[Any]:
+    return [f.name for f in df.schema]
+
+
+@rename_dataframe_column_names.candidate(
+    lambda df, *args, **kwargs: isinstance(df, pa.Table)
+)
+def _rename_pyarrow_dataframe(df: pa.Table, names: Dict[str, Any]) -> pa.Table:
+    if len(names) == 0:
+        return df
+    return df.rename_columns([names.get(f.name, f.name) for f in df.schema])
+
+
+@get_dataframe_column_names.candidate(lambda df: isinstance(df, DataFrame))
+def _get_fugue_dataframe_columns(df: "DataFrame") -> List[Any]:
+    return df.schema.names
+
+
+@rename_dataframe_column_names.candidate(
+    lambda df, *args, **kwargs: isinstance(df, DataFrame)
+)
+def _rename_fugue_dataframe(df: "DataFrame", names: Dict[str, Any]) -> "DataFrame":
+    if len(names) == 0:
+        return df
+    return df.rename(columns=names)
 
 
 def _pa_type_eq(t1: pa.DataType, t2: pa.DataType) -> bool:
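
Note: the two fugue_plugin dispatchers above are the extension point of this
change; every backend below registers itself the same way. As a minimal sketch
(MyDataFrame is a hypothetical wrapper type, not part of this diff), a
third-party backend could hook in like this:

    from typing import Any, Dict, List

    import pandas as pd

    from fugue.dataframe.utils import (
        get_dataframe_column_names,
        rename_dataframe_column_names,
    )

    class MyDataFrame:  # hypothetical third-party dataframe wrapper
        def __init__(self, df: pd.DataFrame):
            self.inner = df

    @get_dataframe_column_names.candidate(lambda df: isinstance(df, MyDataFrame))
    def _get_my_dataframe_columns(df: MyDataFrame) -> List[Any]:
        return list(df.inner.columns)

    @rename_dataframe_column_names.candidate(
        lambda df, *args, **kwargs: isinstance(df, MyDataFrame)
    )
    def _rename_my_dataframe(df: MyDataFrame, names: Dict[str, Any]) -> MyDataFrame:
        # mirror the built-in candidates: no-op on an empty rename map
        return MyDataFrame(df.inner.rename(columns=names)) if names else df

    mdf = MyDataFrame(pd.DataFrame([[0, 1]], columns=["a", "b"]))
    assert get_dataframe_column_names(mdf) == ["a", "b"]
    renamed = rename_dataframe_column_names(mdf, {"a": "x"})
    assert get_dataframe_column_names(renamed) == ["x", "b"]
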
diff --git a/fugue_dask/dataframe.py b/fugue_dask/dataframe.py
index 6665d526..7639c698 100644
--- a/fugue_dask/dataframe.py
+++ b/fugue_dask/dataframe.py
@@ -5,6 +5,10 @@ import pyarrow as pa
 from fugue.dataframe import ArrowDataFrame, DataFrame, LocalDataFrame, PandasDataFrame
 from fugue.dataframe.dataframe import _input_schema
+from fugue.dataframe.utils import (
+    get_dataframe_column_names,
+    rename_dataframe_column_names,
+)
 from fugue.exceptions import FugueDataFrameOperationError
 from triad.collections.schema import Schema
 from triad.utils.assertion import assert_arg_not_none, assert_or_throw
@@ -17,6 +21,20 @@ from fugue_dask._utils import DASK_UTILS
 
 
+@get_dataframe_column_names.candidate(lambda df: isinstance(df, pd.DataFrame))
+def _get_dask_dataframe_columns(df: pd.DataFrame) -> List[Any]:
+    return list(df.columns)
+
+
+@rename_dataframe_column_names.candidate(
+    lambda df, *args, **kwargs: isinstance(df, pd.DataFrame)
+)
+def _rename_dask_dataframe(df: pd.DataFrame, names: Dict[str, Any]) -> pd.DataFrame:
+    if len(names) == 0:
+        return df
+    return df.rename(columns=names)
+
+
 class DaskDataFrame(DataFrame):
     """DataFrame that wraps Dask DataFrame. Please also read
     |DataFrameTutorial| to understand this Fugue concept
diff --git a/fugue_ibis/_compat.py b/fugue_ibis/_compat.py
index 249bc2f4..8a06b6f9 100644
--- a/fugue_ibis/_compat.py
+++ b/fugue_ibis/_compat.py
@@ -1,7 +1,7 @@
 # flake8: noqa
 # pylint: disable-all
 
-try:
+try:  # pragma: no cover
     from ibis.expr.types import Table as IbisTable
-except Exception:
+except Exception:  # pragma: no cover
     from ibis.expr.types import TableExpr as IbisTable
diff --git a/fugue_ibis/execution_engine.py b/fugue_ibis/execution_engine.py
index 1ae0c921..1dfac2b4 100644
--- a/fugue_ibis/execution_engine.py
+++ b/fugue_ibis/execution_engine.py
@@ -17,6 +17,7 @@ from triad.utils.assertion import assert_or_throw
 
 from .dataframe import IbisDataFrame
+from ._compat import IbisTable
 import itertools
 
 _JOIN_RIGHT_SUFFIX = "_ibis_y__"
@@ -40,10 +41,9 @@ def __init__(self, execution_engine: ExecutionEngine) -> None:
         self._ibis_engine: IbisExecutionEngine = execution_engine  # type: ignore
 
     def select(self, dfs: DataFrames, statement: str) -> DataFrame:
-        for k, v in dfs.items():
-            self._ibis_engine._to_ibis_dataframe(v).native.alias(k)
-        tb = self._ibis_engine.backend.sql(statement)
-        return self._ibis_engine._to_ibis_dataframe(tb)
+        return self._ibis_engine._to_ibis_dataframe(
+            self._ibis_engine._raw_select(statement, dfs)
+        )
 
 
 class IbisExecutionEngine(ExecutionEngine):
@@ -72,6 +72,9 @@ def _to_ibis_dataframe(
     ) -> IbisDataFrame:  # pragma: no cover
         raise NotImplementedError
 
+    def _compile_sql(self, df: IbisDataFrame) -> str:
+        return str(df.native.compile())
+
     def to_df(self, df: Any, schema: Any = None, metadata: Any = None) -> DataFrame:
         return self._to_ibis_dataframe(df, schema=schema, metadata=metadata)
 
@@ -249,7 +252,7 @@ def take(
             _presort = parse_presort_exp(presort)
         else:
            _presort = partition_spec.presort
-        tbn = self.get_temp_table_name()
+        tbn = "_temp"
         idf = self._to_ibis_dataframe(df)
 
         if len(_presort) == 0:
@@ -264,7 +267,7 @@ def take(
                 f"AS __fugue_take_param FROM {tbn}"
                 f") WHERE __fugue_take_param<={n}"
             )
-            tb = idf.native.alias(tbn).sql(sql)
+            tb = self._raw_select(sql, {tbn: idf})
             return self._to_ibis_dataframe(tb[df.schema.names], metadata=metadata)
 
         sorts: List[str] = []
@@ -277,7 +280,7 @@ def take(
 
         if len(partition_spec.partition_by) == 0:
             sql = f"SELECT * FROM {tbn} {sort_expr} LIMIT {n}"
-            tb = idf.native.alias(tbn).sql(sql)
+            tb = self._raw_select(sql, {tbn: idf})
             return self._to_ibis_dataframe(tb[df.schema.names], metadata=metadata)
 
         pcols = ", ".join(
@@ -289,5 +292,16 @@ def take(
             f"AS __fugue_take_param FROM {tbn}"
             f") WHERE __fugue_take_param<={n}"
         )
-        tb = idf.native.alias(tbn).sql(sql)
+        tb = self._raw_select(sql, {tbn: idf})
         return self._to_ibis_dataframe(tb[df.schema.names], metadata=metadata)
+
+    def _raw_select(self, statement: str, dfs: Dict[str, Any]) -> IbisTable:
+        cte: List[str] = []
+        for k, v in dfs.items():
+            idf = self._to_ibis_dataframe(v)
+            cte.append(k + " AS (" + self._compile_sql(idf) + ")")
+        if len(cte) > 0:
+            sql = "WITH " + ",\n".join(cte) + "\n" + statement
+        else:
+            sql = statement
+        return self.backend.sql(sql)
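
Note: a standalone sketch of the CTE assembly in _raw_select above, for readers
tracing the SQL it hands to the backend (the compiled dict stands in for
self._compile_sql output; the table names and queries are made up for
illustration):

    # hypothetical compiled SQL for two dataframes registered as "a" and "b"
    compiled = {"a": "SELECT 1 AS x", "b": "SELECT 2 AS x"}
    statement = "SELECT * FROM a UNION ALL SELECT * FROM b"
    # each dataframe becomes a named CTE prefixed onto the raw statement
    cte = [k + " AS (" + v + ")" for k, v in compiled.items()]
    sql = ("WITH " + ",\n".join(cte) + "\n" + statement) if cte else statement
    # sql is now:
    #   WITH a AS (SELECT 1 AS x),
    #   b AS (SELECT 2 AS x)
    #   SELECT * FROM a UNION ALL SELECT * FROM b
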
diff --git a/fugue_ray/dataframe.py b/fugue_ray/dataframe.py
index d79ff688..7fa75013 100644
--- a/fugue_ray/dataframe.py
+++ b/fugue_ray/dataframe.py
@@ -6,11 +6,35 @@ import ray.data as rd
 from fugue.dataframe import ArrowDataFrame, DataFrame, LocalDataFrame, PandasDataFrame
 from fugue.dataframe.dataframe import _input_schema
+from fugue.dataframe.utils import (
+    get_dataframe_column_names,
+    rename_dataframe_column_names,
+)
 from fugue.exceptions import FugueDataFrameEmptyError, FugueDataFrameOperationError
+from triad import assert_or_throw
 from triad.collections.schema import Schema
 
-from ._utils.dataframe import build_empty, get_dataset_format, _build_empty_arrow
-from triad import assert_or_throw
+from ._utils.dataframe import _build_empty_arrow, build_empty, get_dataset_format
+
+
+@get_dataframe_column_names.candidate(lambda df: isinstance(df, rd.Dataset))
+def _get_ray_dataframe_columns(df: rd.Dataset) -> List[Any]:
+    fmt = get_dataset_format(df)
+    if fmt == "pandas":
+        return list(df.schema(True).names)
+    elif fmt == "arrow":
+        return [f.name for f in df.schema(True)]
+    raise NotImplementedError(f"{fmt} is not supported")  # pragma: no cover
+
+
+@rename_dataframe_column_names.candidate(
+    lambda df, *args, **kwargs: isinstance(df, rd.Dataset)
+)
+def _rename_ray_dataframe(df: rd.Dataset, names: Dict[str, Any]) -> rd.Dataset:
+    if len(names) == 0:
+        return df
+    new_cols = [names.get(name, name) for name in _get_ray_dataframe_columns(df)]
+    return df.map_batches(lambda b: b.rename_columns(new_cols), batch_format="pyarrow")
 
 
 class RayDataFrame(DataFrame):
diff --git a/fugue_spark/dataframe.py b/fugue_spark/dataframe.py
index 01f6bf81..21aa2a1d 100644
--- a/fugue_spark/dataframe.py
+++ b/fugue_spark/dataframe.py
@@ -10,6 +10,10 @@
     LocalDataFrame,
     PandasDataFrame,
 )
+from fugue.dataframe.utils import (
+    get_dataframe_column_names,
+    rename_dataframe_column_names,
+)
 from fugue.exceptions import FugueDataFrameOperationError
 from triad import SerializableRLock
 from triad.collections.schema import SchemaError
@@ -18,6 +22,22 @@ from fugue_spark._utils.convert import to_cast_expression, to_schema, to_type_safe_input
 
 
+@get_dataframe_column_names.candidate(lambda df: isinstance(df, ps.DataFrame))
+def _get_spark_dataframe_columns(df: ps.DataFrame) -> List[Any]:
+    return [f.name for f in df.schema]
+
+
+@rename_dataframe_column_names.candidate(
+    lambda df, *args, **kwargs: isinstance(df, ps.DataFrame)
+)
+def _rename_spark_dataframe(df: ps.DataFrame, names: Dict[str, Any]) -> ps.DataFrame:
+    if len(names) == 0:
+        return df
+    for k, v in names.items():
+        df = df.withColumnRenamed(k, v)
+    return df
+
+
 class SparkDataFrame(DataFrame):
     """DataFrame that wraps Spark DataFrame. Please also read
     |DataFrameTutorial| to understand this Fugue concept
diff --git a/setup.py b/setup.py
index 833b4408..dc730285 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ def get_version() -> str:
     keywords="distributed spark dask sql dsl domain specific language",
     url="http://github.com/fugue-project/fugue",
     install_requires=[
-        "triad>=0.6.8",
+        "triad>=0.6.9",
        "adagio>=0.2.4",
         "qpd>=0.3.1",
         "fugue-sql-antlr>=0.1.1",
diff --git a/tests/fugue/dataframe/test_utils.py b/tests/fugue/dataframe/test_utils.py
index 1ae235d6..6714ea89 100644
--- a/tests/fugue/dataframe/test_utils.py
+++ b/tests/fugue/dataframe/test_utils.py
@@ -1,7 +1,7 @@
-import concurrent.futures
 import os
 
 import numpy as np
+import pandas as pd
 import pyarrow as pa
 from fugue.dataframe import to_local_bounded_df, to_local_df
 from fugue.dataframe.array_dataframe import ArrayDataFrame
@@ -11,8 +11,11 @@
 from fugue.dataframe.utils import (
     _schema_eq,
     deserialize_df,
+    get_dataframe_column_names,
     get_join_schemas,
+    normalize_dataframe_column_names,
     pickle_df,
+    rename_dataframe_column_names,
     serialize_df,
     unpickle_df,
 )
@@ -201,3 +204,52 @@ def assert_eq(df, df_expected=None, raw=False):
         df_eq(df, deserialize_df(s), throw=True)
 
     raises(ValueError, lambda: deserialize_df('{"x":1}'))
+
+
+def test_get_dataframe_column_names():
+    df = pd.DataFrame([[0, 1, 2]])
+    assert get_dataframe_column_names(df) == [0, 1, 2]
+
+    adf = pa.Table.from_pandas(df)
+    assert get_dataframe_column_names(adf) == ["0", "1", "2"]
+
+    pdf = PandasDataFrame(pd.DataFrame([[0, 1]], columns=["a", "b"]))
+    assert get_dataframe_column_names(pdf) == ["a", "b"]
+
+
+def test_rename_dataframe_column_names():
+    assert rename_dataframe_column_names("dummy", {}) == "dummy"
+    pdf = pd.DataFrame([[0, 1, 2]], columns=["a", "b", "c"])
+    df = rename_dataframe_column_names(pdf, {})
+    assert get_dataframe_column_names(df) == ["a", "b", "c"]
+    df = rename_dataframe_column_names(pdf, {"b": "bb"})
+    assert get_dataframe_column_names(df) == ["a", "bb", "c"]
+
+    adf = pa.Table.from_pandas(pdf)
+    adf = rename_dataframe_column_names(adf, {})
+    assert get_dataframe_column_names(adf) == ["a", "b", "c"]
+    adf = rename_dataframe_column_names(adf, {"b": "bb"})
+    assert get_dataframe_column_names(adf) == ["a", "bb", "c"]
+
+    fdf = PandasDataFrame(pdf)
+    fdf = rename_dataframe_column_names(fdf, {})
+    assert get_dataframe_column_names(fdf) == ["a", "b", "c"]
+    fdf = rename_dataframe_column_names(fdf, {"b": "bb"})
+    assert get_dataframe_column_names(fdf) == ["a", "bb", "c"]
+
+
+def test_normalize_dataframe_column_names():
+    df = pd.DataFrame([[0, 1, 2]], columns=["a", "b", "c"])
+    df, names = normalize_dataframe_column_names(df)
+    assert get_dataframe_column_names(df) == ["a", "b", "c"]
+    assert names == {}
+
+    df = pd.DataFrame([[0, 1, 2]])
+    df, names = normalize_dataframe_column_names(df)
+    assert get_dataframe_column_names(df) == ["_0", "_1", "_2"]
+    assert names == {"_0": 0, "_1": 1, "_2": 2}
+
+    df = pd.DataFrame([[0, 1, 2, 3]], columns=["1", "2", "_2", "大"])
+    df, names = normalize_dataframe_column_names(df)
+    assert get_dataframe_column_names(df) == ["_1", "_2_1", "_2", "_1_1"]
+    assert names == {"_1": "1", "_2_1": "2", "_1_1": "大"}
diff --git a/tests/fugue_dask/test_dataframe.py b/tests/fugue_dask/test_dataframe.py
index f798688f..543d7dfc 100644
--- a/tests/fugue_dask/test_dataframe.py
+++ b/tests/fugue_dask/test_dataframe.py
@@ -13,6 +13,10 @@ from fugue_test.dataframe_suite import DataFrameTests
 from pytest import raises
 from triad.collections.schema import Schema
+from fugue.dataframe.utils import (
+    get_dataframe_column_names,
+    rename_dataframe_column_names,
+)
 
 
 class DaskDataFrameTests(DataFrameTests.Tests):
@@ -195,3 +199,22 @@ def _test_as_array_perf():
         res = df.as_array(type_safe=True)
         ts += (datetime.now() - t).total_seconds()
     print(nts, ts)
+
+
+def test_get_dataframe_column_names():
+    df = pd.from_pandas(pandas.DataFrame([[0, 1, 2]]), npartitions=1)
+    assert get_dataframe_column_names(df) == [0, 1, 2]
+
+
+def test_rename_dataframe_column_names():
+    pdf = pd.from_pandas(
+        pandas.DataFrame([[0, 1, 2]], columns=["a", "b", "c"]), npartitions=1
+    )
+    df = rename_dataframe_column_names(pdf, {})
+    assert isinstance(df, pd.DataFrame)
+    assert get_dataframe_column_names(df) == ["a", "b", "c"]
+
+    pdf = pd.from_pandas(pandas.DataFrame([[0, 1, 2]]), npartitions=1)
+    df = rename_dataframe_column_names(pdf, {0: "_0", 1: "_1", 2: "_2"})
+    assert isinstance(df, pd.DataFrame)
+    assert get_dataframe_column_names(df) == ["_0", "_1", "_2"]
diff --git a/tests/fugue_ray/test_dataframe.py b/tests/fugue_ray/test_dataframe.py
index 9b4f44e2..ebe4098f 100644
--- a/tests/fugue_ray/test_dataframe.py
+++ b/tests/fugue_ray/test_dataframe.py
@@ -1,11 +1,14 @@
 from typing import Any
 
-import dask.dataframe as pd
 import pandas as pd
 import pyarrow as pa
 import ray
 import ray.data as rd
 from fugue.dataframe.array_dataframe import ArrayDataFrame
+from fugue.dataframe.utils import (
+    get_dataframe_column_names,
+    rename_dataframe_column_names,
+)
 from fugue_test.dataframe_suite import DataFrameTests
 from pytest import raises
 from triad import Schema
@@ -107,3 +110,23 @@ def test_ray_num_partitions(self):
         rdf = rd.from_pandas(pd.DataFrame(dict(a=range(10))))
         df = RayDataFrame(rdf.repartition(5))
         assert 5 == df.num_partitions
+
+    def test_get_dataframe_column_names(self):
+        df = rd.from_pandas(pd.DataFrame([[0, 10, 20]], columns=["0", "1", "2"]))
+        assert get_dataframe_column_names(df) == ["0", "1", "2"]
+
+        df = rd.from_arrow(
+            pa.Table.from_pandas(pd.DataFrame([[0, 10, 20]], columns=["0", "1", "2"]))
+        )
+        assert get_dataframe_column_names(df) == ["0", "1", "2"]
+
+    def test_rename_dataframe_column_names(self):
+        rdf = rd.from_pandas(pd.DataFrame([[0, 10, 20]], columns=["a", "b", "c"]))
+        df = rename_dataframe_column_names(rdf, {})
+        assert isinstance(df, rd.Dataset)
+        assert get_dataframe_column_names(df) == ["a", "b", "c"]
+
+        pdf = rd.from_pandas(pd.DataFrame([[0, 10, 20]], columns=["0", "1", "2"]))
+        df = rename_dataframe_column_names(pdf, {"0": "_0", "1": "_1", "2": "_2"})
+        assert isinstance(df, rd.Dataset)
+        assert get_dataframe_column_names(df) == ["_0", "_1", "_2"]
diff --git a/tests/fugue_spark/test_dataframe.py b/tests/fugue_spark/test_dataframe.py
index 2edf552f..b8c591e8 100644
--- a/tests/fugue_spark/test_dataframe.py
+++ b/tests/fugue_spark/test_dataframe.py
@@ -1,11 +1,17 @@
 from datetime import datetime
 from typing import Any
 
+import pandas as pd
 import pyspark
+import pyspark.sql as ps
 from fugue.dataframe.array_dataframe import ArrayDataFrame
 from fugue.dataframe.pandas_dataframe import PandasDataFrame
 from fugue.dataframe.utils import _df_eq as df_eq
-from fugue.dataframe.utils import to_local_bounded_df
+from fugue.dataframe.utils import (
+    get_dataframe_column_names,
+    rename_dataframe_column_names,
+    to_local_bounded_df,
+)
 from fugue_test.dataframe_suite import DataFrameTests
 from pyspark.sql import SparkSession
 from pytest import raises
@@ -117,3 +123,26 @@ def _df(data, schema=None, metadata=None):
     else:
         df = session.createDataFrame(data)
     return SparkDataFrame(df, schema, metadata)
+
+
+def test_get_dataframe_column_names(spark_session):
+    df = spark_session.createDataFrame(
+        pd.DataFrame([[0, 1, 2]], columns=["0", "1", "2"])
+    )
+    assert get_dataframe_column_names(df) == ["0", "1", "2"]
+
+
+def test_rename_dataframe_column_names(spark_session):
+    pdf = spark_session.createDataFrame(
+        pd.DataFrame([[0, 1, 2]], columns=["a", "b", "c"])
+    )
+    df = rename_dataframe_column_names(pdf, {})
+    assert isinstance(df, ps.DataFrame)
+    assert get_dataframe_column_names(df) == ["a", "b", "c"]
+
+    pdf = spark_session.createDataFrame(
+        pd.DataFrame([[0, 1, 2]], columns=["0", "1", "2"])
+    )
+    df = rename_dataframe_column_names(pdf, {"0": "_0", "1": "_1", "2": "_2"})
+    assert isinstance(df, ps.DataFrame)
+    assert get_dataframe_column_names(df) == ["_0", "_1", "_2"]
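
Note: an end-to-end sketch of the normalize/undo contract, mirroring the tests
above (assumes pandas plus fugue with this change installed):

    import pandas as pd

    from fugue.dataframe.utils import (
        normalize_dataframe_column_names,
        rename_dataframe_column_names,
    )

    df = pd.DataFrame([[0, 1, 2]])  # integer column names: 0, 1, 2
    ndf, undo = normalize_dataframe_column_names(df)
    assert list(ndf.columns) == ["_0", "_1", "_2"]
    assert undo == {"_0": 0, "_1": 1, "_2": 2}
    # applying the returned map undoes the normalization
    rdf = rename_dataframe_column_names(ndf, undo)
    assert list(rdf.columns) == [0, 1, 2]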