Add more supported type annotations, fix spark connect issue (#542)
* Add more supported type annotations, fix spark connect issue

* update

* update

* update

* update

* update

* update

* update
goodwanghan authored Jun 14, 2024
1 parent 48b7ab6 commit 1adc576
Showing 13 changed files with 95 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
@@ -38,7 +38,7 @@
    ],
    "postCreateCommand": "make devenv",
    "features": {
-        "ghcr.io/devcontainers/features/docker-in-docker:2": {},
+        "ghcr.io/devcontainers/features/docker-in-docker:2.11.0": {},
        "ghcr.io/devcontainers/features/java:1": {
            "version": "11"
        }
2 changes: 1 addition & 1 deletion .github/workflows/test_all.yml
@@ -25,7 +25,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        python-version: [3.8, "3.10"] # TODO: add back 3.11 when dask-sql is compatible
+        python-version: [3.8, "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v2
5 changes: 5 additions & 0 deletions RELEASE.md
@@ -1,5 +1,10 @@
# Release Notes

+## 0.9.1
+
+- [543](https://github.com/fugue-project/fugue/issues/543) Support type hinting with standard collections
+- [544](https://github.com/fugue-project/fugue/issues/544) Fix Spark connect import issue on worker side
+
## 0.9.0

- [482](https://github.com/fugue-project/fugue/issues/482) Move Fugue SQL dependencies into extra `[sql]` and functions to become soft dependencies
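
A minimal sketch of what [543] enables on Python 3.9+: transformers can annotate inputs and outputs with builtin generics (PEP 585) instead of typing.List/typing.Dict. fugue.api.transform and the schema string are standard Fugue API; the function body itself is only illustrative.

from typing import Any

import pandas as pd

import fugue.api as fa


def add_one(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # builtin list/dict annotations now match the same parameter kinds
    # that typing.List[typing.Dict[str, Any]] always matched
    return [{"a": row["a"] + 1} for row in rows]


res = fa.transform(pd.DataFrame({"a": [1, 2]}), add_one, schema="a:long")
print(res)  # a pandas DataFrame with a = 2, 3
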
29 changes: 13 additions & 16 deletions fugue/dataframe/function_wrapper.py
@@ -20,6 +20,7 @@
    PositionalParam,
    function_wrapper,
)
+from triad.utils.convert import compare_annotations
from triad.utils.iter import EmptyAwareIterable, make_empty_aware

from ..constants import FUGUE_ENTRYPOINT
@@ -37,6 +38,14 @@
from .pandas_dataframe import PandasDataFrame


+def _compare_iter(tp: Any) -> Any:
+    return lambda x: compare_annotations(
+        x, Iterable[tp]  # type:ignore
+    ) or compare_annotations(
+        x, Iterator[tp]  # type:ignore
+    )


@function_wrapper(FUGUE_ENTRYPOINT)
class DataFrameFunctionWrapper(FunctionWrapper):
    @property
@@ -228,10 +237,7 @@ def count(self, df: List[List[Any]]) -> int:
        return len(df)


-@fugue_annotated_param(
-    Iterable[List[Any]],
-    matcher=lambda x: x == Iterable[List[Any]] or x == Iterator[List[Any]],
-)
+@fugue_annotated_param(Iterable[List[Any]], matcher=_compare_iter(List[Any]))
class _IterableListParam(_LocalNoSchemaDataFrameParam):
    @no_type_check
    def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[List[Any]]:
@@ -288,10 +294,7 @@ def count(self, df: List[Dict[str, Any]]) -> int:
        return len(df)


-@fugue_annotated_param(
-    Iterable[Dict[str, Any]],
-    matcher=lambda x: x == Iterable[Dict[str, Any]] or x == Iterator[Dict[str, Any]],
-)
+@fugue_annotated_param(Iterable[Dict[str, Any]], matcher=_compare_iter(Dict[str, Any]))
class _IterableDictParam(_LocalNoSchemaDataFrameParam):
    @no_type_check
    def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[Dict[str, Any]]:
@@ -360,10 +363,7 @@ def format_hint(self) -> Optional[str]:
        return "pandas"


-@fugue_annotated_param(
-    Iterable[pd.DataFrame],
-    matcher=lambda x: x == Iterable[pd.DataFrame] or x == Iterator[pd.DataFrame],
-)
+@fugue_annotated_param(Iterable[pd.DataFrame], matcher=_compare_iter(pd.DataFrame))
class _IterablePandasParam(LocalDataFrameParam):
    @no_type_check
    def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pd.DataFrame]:
@@ -419,10 +419,7 @@ def format_hint(self) -> Optional[str]:
        return "pyarrow"


-@fugue_annotated_param(
-    Iterable[pa.Table],
-    matcher=lambda x: x == Iterable[pa.Table] or x == Iterator[pa.Table],
-)
+@fugue_annotated_param(Iterable[pa.Table], matcher=_compare_iter(pa.Table))
class _IterableArrowParam(LocalDataFrameParam):
    @no_type_check
    def to_input_data(self, df: DataFrame, ctx: Any) -> Iterable[pa.Table]:
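
The decorator changes above all swap an exact `==` match for triad's compare_annotations, which is why setup.py below raises the floor to triad>=0.9.7. Equality only recognizes the identical typing spelling, while compare_annotations, assuming it normalizes equivalent annotations as the release notes imply, also accepts the PEP 585 builtin-generic form. A small sketch of the difference:

from typing import Any, Dict, Iterable, Iterator

from triad.utils.convert import compare_annotations

# plain equality only matches the identical typing form
assert Iterable[Dict[str, Any]] == Iterable[Dict[str, Any]]
assert Iterable[Dict[str, Any]] != Iterator[Dict[str, Any]]


def matcher(x: Any) -> bool:
    # what _compare_iter(Dict[str, Any]) builds: accept either iteration flavor
    return compare_annotations(x, Iterable[Dict[str, Any]]) or compare_annotations(
        x, Iterator[Dict[str, Any]]
    )


assert matcher(Iterator[Dict[str, Any]])
# and, assuming compare_annotations treats builtin generics as equivalent
# (the Python 3.9+ spelling exercised by test f33 below), this passes too:
assert matcher(Iterable[dict[str, Any]])
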
2 changes: 1 addition & 1 deletion fugue_spark/_utils/misc.py
@@ -3,7 +3,7 @@
try:
    from pyspark.sql.connect.session import SparkSession as SparkConnectSession
    from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
-except ImportError:  # pragma: no cover
+except Exception:  # pragma: no cover
    SparkConnectSession = None
    SparkConnectDataFrame = None
import pyspark.sql as ps
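
This one-word change is the [544] fix: on a Spark Connect worker the pyspark.sql.connect import can apparently fail with exceptions other than ImportError, which the old clause let propagate and break the worker. The guarded-import pattern in isolation, with a hypothetical helper (not part of the diff) showing how the sentinel is consumed:

try:
    from pyspark.sql.connect.session import SparkSession as SparkConnectSession
except Exception:  # deliberately broad: the import may fail in several ways
    SparkConnectSession = None


def is_connect_session(obj: object) -> bool:
    # hypothetical helper: safe whether or not the optional import succeeded
    return SparkConnectSession is not None and isinstance(obj, SparkConnectSession)
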
2 changes: 1 addition & 1 deletion fugue_version/__init__.py
@@ -1 +1 @@
-__version__ = "0.9.0"
+__version__ = "0.9.1"
2 changes: 1 addition & 1 deletion setup.py
@@ -38,7 +38,7 @@ def get_version() -> str:
    keywords="distributed spark dask ray sql dsl domain specific language",
    url="http://github.com/fugue-project/fugue",
    install_requires=[
-        "triad>=0.9.6",
+        "triad>=0.9.7",
        "adagio>=0.2.4",
    ],
    extras_require={
16 changes: 15 additions & 1 deletion tests/fugue/dataframe/test_function_wrapper.py
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
import copy
+import sys
from typing import Any, Dict, Iterable, Iterator, List

import pandas as pd
@@ -29,7 +32,10 @@


def test_function_wrapper():
-    for f in [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]:
+    fs = [f20, f21, f212, f22, f23, f24, f25, f26, f30, f31, f32, f35, f36]
+    if sys.version_info >= (3, 9):
+        fs.append(f33)
+    for f in fs:
        df = ArrayDataFrame([[0]], "a:int")
        w = DataFrameFunctionWrapper(f, "^[ldsp][ldsp]$", "[ldspq]")
        res = w.run([df], dict(a=df), ignore_unknown=False, output_schema="a:int")
@@ -372,6 +378,14 @@ def f32(
    return ArrayDataFrame(arr, "a:int").as_dict_iterable()


+def f33(
+    e: list[dict[str, Any]], a: Iterable[dict[str, Any]]
+) -> EmptyAwareIterable[Dict[str, Any]]:
+    e += list(a)
+    arr = [[x["a"]] for x in e]
+    return ArrayDataFrame(arr, "a:int").as_dict_iterable()


def f35(e: pd.DataFrame, a: LocalDataFrame) -> Iterable[pd.DataFrame]:
    e = PandasDataFrame(e, "a:int").as_pandas()
    a = ArrayDataFrame(a, "a:int").as_pandas()
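
The sys.version_info gate on f33 exists because builtin generics became subscriptable at runtime only in Python 3.9; the new `from __future__ import annotations` defers evaluation, but the wrapper still has to resolve the annotation strings eventually. A Fugue-independent illustration:

from __future__ import annotations

import sys
from typing import get_type_hints


def f(e: list[dict[str, int]]) -> None:
    ...


if sys.version_info >= (3, 9):
    # resolves to {'e': list[dict[str, int]], 'return': <class 'NoneType'>}
    print(get_type_hints(f))
else:
    # on 3.8 resolving the deferred string raises
    # TypeError: 'type' object is not subscriptable
    pass
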
42 changes: 42 additions & 0 deletions tests/fugue_dask/test_execution_engine.py
@@ -24,6 +24,8 @@
from fugue_dask.execution_engine import DaskExecutionEngine
from fugue_test.builtin_suite import BuiltInTests
from fugue_test.execution_suite import ExecutionEngineTests
+from fugue.column import col, all_cols
+import fugue.column.functions as ff

_CONF = {
"fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
@@ -50,6 +52,46 @@ def test_get_parallelism(self):
    def test__join_outer_pandas_incompatible(self):
        return

+    # TODO: dask-sql 2024.5.0 has a bug, can't pass the HAVING tests
+    def test_select(self):
+        try:
+            import qpd
+            import dask_sql
+        except ImportError:
+            return
+
+        a = ArrayDataFrame(
+            [[1, 2], [None, 2], [None, 1], [3, 4], [None, 4]], "a:double,b:int"
+        )
+
+        # simple
+        b = fa.select(a, col("b"), (col("b") + 1).alias("c").cast(str))
+        self.df_eq(
+            b,
+            [[2, "3"], [2, "3"], [1, "2"], [4, "5"], [4, "5"]],
+            "b:int,c:str",
+            throw=True,
+        )
+
+        # with distinct
+        b = fa.select(
+            a, col("b"), (col("b") + 1).alias("c").cast(str), distinct=True
+        )
+        self.df_eq(
+            b,
+            [[2, "3"], [1, "2"], [4, "5"]],
+            "b:int,c:str",
+            throw=True,
+        )
+
+        # wildcard
+        b = fa.select(a, all_cols(), where=col("a") + col("b") == 3)
+        self.df_eq(b, [[1, 2]], "a:double,b:int", throw=True)
+
+        # aggregation
+        b = fa.select(a, col("a"), ff.sum(col("b")).cast(float).alias("b"))
+        self.df_eq(b, [[1, 2], [3, 4], [None, 7]], "a:double,b:double", throw=True)
+
    def test_to_df(self):
        e = self.engine
        a = e.to_df([[1, 2], [3, 4]], "a:int,b:int")
4 changes: 2 additions & 2 deletions tests/fugue_duckdb/test_dataframe.py
@@ -15,7 +15,7 @@
class DuckDataFrameTests(DataFrameTests.Tests):
    def df(self, data: Any = None, schema: Any = None) -> DuckDataFrame:
        df = ArrowDataFrame(data, schema)
-        return DuckDataFrame(duckdb.from_arrow(df.native, self.context.session))
+        return DuckDataFrame(duckdb.from_arrow(df.native, connection=self.context.session))

    def test_as_array_special_values(self):
        for func in [
@@ -69,7 +69,7 @@ def test_duck_as_local(self):
class NativeDuckDataFrameTests(DataFrameTests.NativeTests):
    def df(self, data: Any = None, schema: Any = None) -> DuckDataFrame:
        df = ArrowDataFrame(data, schema)
-        return DuckDataFrame(duckdb.from_arrow(df.native, self.context.session)).native
+        return DuckDataFrame(duckdb.from_arrow(df.native, connection=self.context.session)).native

    def to_native_df(self, pdf: pd.DataFrame) -> Any:
        return duckdb.from_df(pdf)
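
Both call sites switch to connection= because recent DuckDB releases appear to accept from_arrow's connection only as a keyword argument, so the old positional form stopped working there. A minimal sketch against an in-memory connection:

import duckdb
import pyarrow as pa

con = duckdb.connect()  # in-memory database
tbl = pa.table({"a": [1, 2, 3]})

rel = duckdb.from_arrow(tbl, connection=con)
print(rel.types)       # [BIGINT]
print(rel.fetchall())  # [(1,), (2,), (3,)]
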
2 changes: 1 addition & 1 deletion tests/fugue_duckdb/test_utils.py
@@ -45,7 +45,7 @@ def test_type_conversion(backend_context):

    def assert_(tp):
        dt = duckdb.from_arrow(
-            pa.Table.from_pydict(dict(a=pa.nulls(2, tp))), con
+            pa.Table.from_pydict(dict(a=pa.nulls(2, tp))), connection=con
        ).types[0]
        assert to_pa_type(dt) == tp
        dt = to_duck_type(tp)
16 changes: 9 additions & 7 deletions tests/fugue_ibis/mock/execution_engine.py
@@ -81,15 +81,17 @@ def sample(
f"one and only one of n and frac should be non-negative, {n}, {frac}"
),
)
tn = self.get_temp_table_name()
idf = self.to_df(df)
tn = f"({idf.native.compile()})"
if seed is not None:
_seed = f",{seed}"
else:
_seed = ""
if frac is not None:
sql = f"SELECT * FROM {tn} USING SAMPLE bernoulli({frac*100} PERCENT)"
sql = f"SELECT * FROM {tn} USING SAMPLE {frac*100}% (bernoulli{_seed})"
else:
sql = f"SELECT * FROM {tn} USING SAMPLE reservoir({n} ROWS)"
if seed is not None:
sql += f" REPEATABLE ({seed})"
idf = self.to_df(df)
_res = f"WITH {tn} AS ({idf.native.compile()}) " + sql
sql = f"SELECT * FROM {tn} USING SAMPLE {n} ROWS (reservoir{_seed})"
_res = f"SELECT * FROM ({sql})" # ibis has a bug to inject LIMIT
return self.to_df(self.backend.sql(_res))

def _register_df(
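
The rewritten sample() inlines the compiled ibis query as a subquery instead of binding it to a temp-table CTE, and it passes the seed inside the sampling method's parentheses (DuckDB's USING SAMPLE <size> (<method>, <seed>) form) rather than appending a REPEATABLE clause; the final SELECT wrapper works around the ibis LIMIT-injection bug noted in the comment. What the two branches emit, with a placeholder standing in for idf.native.compile():

tn = "(SELECT * FROM t)"  # placeholder for the compiled ibis query
frac, n, seed = 0.5, 10, 42
_seed = f",{seed}" if seed is not None else ""

assert (
    f"SELECT * FROM {tn} USING SAMPLE {frac * 100}% (bernoulli{_seed})"
    == "SELECT * FROM (SELECT * FROM t) USING SAMPLE 50.0% (bernoulli,42)"
)
assert (
    f"SELECT * FROM {tn} USING SAMPLE {n} ROWS (reservoir{_seed})"
    == "SELECT * FROM (SELECT * FROM t) USING SAMPLE 10 ROWS (reservoir,42)"
)
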
3 changes: 3 additions & 0 deletions tests/fugue_ibis/test_execution_engine.py
@@ -23,6 +23,9 @@ def test_properties(self):
        assert not self.engine.map_engine.is_distributed
        assert not self.engine.sql_engine.is_distributed

+        assert self.engine.sql_engine.get_temp_table_name(
+        ) != self.engine.sql_engine.get_temp_table_name()
+
    def test_select(self):
        # it can't work properly with DuckDB (hugeint is not recognized)
        pass
