Skip to content

Commit

Permalink
Add is_frame_base and xfail dask expr enabled failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
milesgranger committed Feb 29, 2024
1 parent e6e5a5f commit 39acd1c
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 36 deletions.
7 changes: 3 additions & 4 deletions dask_ml/ensemble/_blockwise.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
import sklearn.base
from sklearn.utils.validation import check_is_fitted

from ..base import ClassifierMixin, RegressorMixin
from ..utils import check_array
from ..utils import check_array, is_frame_base


class BlockwiseBase(sklearn.base.BaseEstimator):
Expand Down Expand Up @@ -62,7 +61,7 @@ def _predict(self, X):
dtype=np.dtype(dtype),
chunks=chunks,
)
elif isinstance(X, dd.DataFrame):
elif is_frame_base(X):
meta = np.empty((0, len(self.classes_)), dtype=dtype)
combined = X.map_partitions(
_predict_stack, estimators=self.estimators_, meta=meta
Expand Down Expand Up @@ -184,7 +183,7 @@ def _collect_probas(self, X):
chunks=chunks,
meta=meta,
)
elif isinstance(X, dd.DataFrame):
elif is_frame_base(X):
# TODO: replace with a _predict_proba_stack version.
# This current raises; dask.dataframe doesn't like map_partitions that
# return new axes.
Expand Down
31 changes: 20 additions & 11 deletions dask_ml/linear_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,34 @@
import numpy as np
from multipledispatch import dispatch

if not dask.config.get("dataframe.query-planning"):
try:
import dask_expr

@dispatch(dd._Frame)
_dask_expr_avail = True
except ImportError:
dask_expr = None
_dask_expr_avail = False


if dask.config.get("dataframe.query-planning") and _dask_expr_avail:

@dispatch(dask_expr.FrameBase)
def exp(A):
return da.exp(A)

@dispatch(dd._Frame)
@dispatch(dask_expr.FrameBase)
def absolute(A):
return da.absolute(A)

@dispatch(dd._Frame)
@dispatch(dask_expr.FrameBase)
def sign(A):
return da.sign(A)

@dispatch(dd._Frame)
@dispatch(dask_expr.FrameBase)
def log1p(A):
return da.log1p(A)

@dispatch(dd._Frame) # noqa: F811
@dispatch(dask_expr.FrameBase) # noqa: F811
def add_intercept(X): # noqa: F811
columns = X.columns
if "intercept" in columns:
Expand All @@ -33,23 +42,23 @@ def add_intercept(X): # noqa: F811

else:

@dispatch(dd.DataFrame)
@dispatch(dd._Frame)
def exp(A):
return da.exp(A)

@dispatch(dd.DataFrame)
@dispatch(dd._Frame)
def absolute(A):
return da.absolute(A)

@dispatch(dd.DataFrame)
@dispatch(dd._Frame)
def sign(A):
return da.sign(A)

@dispatch(dd.DataFrame)
@dispatch(dd._Frame)
def log1p(A):
return da.log1p(A)

@dispatch(dd.DataFrame) # noqa: F811
@dispatch(dd._Frame) # noqa: F811
def add_intercept(X): # noqa: F811
columns = X.columns
if "intercept" in columns:
Expand Down
15 changes: 14 additions & 1 deletion dask_ml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
logger = logging.getLogger()


def is_frame_base(inst):
if getattr(dd, "_dask_expr_enabled", lambda: False)():
from dask_expr import FrameBase

return isinstance(inst, FrameBase)
return isinstance(inst, dd._Frame)


def _svd_flip_copy(x, y, u_based_decision=True):
# If the array is locked, copy the array and transpose it
# This happens with a very large array > 1TB
Expand Down Expand Up @@ -212,7 +220,12 @@ def check_array(

def _assert_eq(l, r, name=None, **kwargs):
array_types = (np.ndarray, da.Array)
frame_types = (pd.core.generic.NDFrame, dd.DataFrame)
if getattr(dd, "_dask_expr_enabled", lambda: False)():
from dask_expr import FrameBase

frame_types = (pd.core.generic.NDFrame, FrameBase)
else:
frame_types = (pd.core.generic.NDFrame, dd._Frame)
if isinstance(l, array_types):
assert_eq_ar(l, r, **kwargs)
elif isinstance(l, frame_types):
Expand Down
10 changes: 5 additions & 5 deletions dask_ml/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sklearn.metrics
from sklearn.utils.validation import check_is_fitted

from dask_ml.utils import _timer
from dask_ml.utils import _timer, is_frame_base

from ._partial import fit
from ._utils import copy_learned_attributes
Expand Down Expand Up @@ -241,7 +241,7 @@ def transform(self, X):
return X.map_blocks(
_transform, estimator=self._postfit_estimator, meta=meta
)
elif isinstance(X, dd.DataFrame):
elif is_frame_base(X):
if meta is None:
# dask-dataframe relies on dd.core.no_default
# for infering meta
Expand Down Expand Up @@ -324,7 +324,7 @@ def predict(self, X):
)
return result

elif isinstance(X, dd.DataFrame):
elif is_frame_base(X):
if meta is None:
meta = dd.core.no_default
return X.map_partitions(
Expand Down Expand Up @@ -369,7 +369,7 @@ def predict_proba(self, X):
meta=meta,
chunks=(X.chunks[0], len(self._postfit_estimator.classes_)),
)
elif isinstance(X, dd.DataFrame):
elif is_frame_base(X):
if meta is None:
meta = dd.core.no_default
return X.map_partitions(
Expand Down Expand Up @@ -619,7 +619,7 @@ def _first_block(dask_object):
dask_object.to_delayed().flatten()[0], shape, dask_object.dtype
)

if isinstance(dask_object, dd.DataFrame):
if is_frame_base(dask_object):
return dask_object.get_partition(0)

else:
Expand Down
4 changes: 2 additions & 2 deletions tests/ensemble/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_bad_chunking_raises(self):
# this should *really* be a ValueError...
clf.fit(X, y)

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED, reason="dask-expr computing early into np.ndarray"
)
def test_hard_voting_frame(self):
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_soft_voting_array(self):
score = clf.score(X, y)
assert isinstance(score, float)

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'Scalar' object has no attribute '_chunks'",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/model_selection/test_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def additional_calls(scores):
await asyncio.sleep(0.1)


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED, reason="TypeError: 'coroutine' object is not iterable"
)
@gen_cluster(client=True)
Expand Down
12 changes: 6 additions & 6 deletions tests/preprocessing/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_input_types(self, dask_df, pandas_df):
exclude="n_samples_seen_",
)

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: can't set attribute 'divisions'",
)
Expand Down Expand Up @@ -437,7 +437,7 @@ def test_encode_subset_of_columns(self, daskify):

tm.assert_frame_equal(result, df)

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: can't set attribute 'divisions'",
)
Expand Down Expand Up @@ -496,7 +496,7 @@ def test_inverse_transform(self):


class TestOrdinalEncoder:
@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: can't set attribute 'divisions'",
)
Expand Down Expand Up @@ -544,7 +544,7 @@ def test_transform_raises(self):
de.transform(dummy.drop("B", axis="columns"))
assert rec.match("Columns of 'X' do not match the training")

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: can't set attribute 'divisions'",
)
Expand Down Expand Up @@ -635,7 +635,7 @@ def test_transformed_shape(self):
# dask array with nan rows
assert a.transform(X_nan_rows).shape[1] == n_cols

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="TypeError: No dispatch for <class 'dask_expr._collection.Scalar'>",
)
Expand Down Expand Up @@ -667,7 +667,7 @@ def test_transformer_params(self):
assert pf._transformer.interaction_only is pf.interaction_only
assert pf._transformer.include_bias is pf.include_bias

@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="TypeError: No dispatch for <class 'dask_expr._collection.Scalar'>",
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_parallel_post_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_laziness():
assert 0 < x.compute() < 1


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'MapPartitions' object has no attribute 'shape' / AttributeError: can't set attribute '_meta'",
)
Expand All @@ -81,7 +81,7 @@ def test_predict_meta_override():
assert_eq_ar(result, expected)


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'MapPartitions' object has no attribute 'shape'",
)
Expand All @@ -108,7 +108,7 @@ def test_predict_proba_meta_override():
assert_eq_ar(result, expected)


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'Scalar' object has no attribute 'shape'",
)
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_predict_correct_output_dtype():
assert wrap_output.dtype == base_output.dtype


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'MapPartitions' object has no attribute 'shape'",
)
Expand Down Expand Up @@ -185,7 +185,7 @@ def test_predict(kind):
assert_eq_ar(result, expected)


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'MapPartitions' object has no attribute 'shape'",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_fit_shuffle_blocks():
)


@pytest.mark.skipif(
@pytest.mark.xfail(
DASK_EXPR_ENABLED,
reason="AttributeError: 'Scalar' object has no attribute 'shape'",
)
Expand Down

0 comments on commit 39acd1c

Please sign in to comment.