Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Oct 3, 2024
1 parent a2e7cf1 commit 265de82
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 45 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
- New argument `strict` in `$drop()` to determine whether unknown column names
should trigger an error (#1220).
- New method `$to_dummies()` for `DataFrame` (#1225).
- New method `$join_where()` for `DataFrame` and `LazyFrame` to perform
inequality joins (#1237).

### Bug fixes

Expand Down
40 changes: 40 additions & 0 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2577,3 +2577,43 @@ DataFrame_to_dummies = function(
.pr$DataFrame$to_dummies(self, columns = columns, separator = separator, drop_first = drop_first) |>
unwrap("in $to_dummies():")
}

#' @inherit LazyFrame_join_where title params
#'
#' @description
#' This performs an inner join, so only rows where all predicates are true are
#' included in the result, and a row from either DataFrame may be included
#' multiple times in the result.
#'
#' Note that the row order of the input DataFrames is not preserved.
#'
#' @param other DataFrame to join with.
#'
#' @return A DataFrame
#'
#' @examples
#' east = pl$DataFrame(
#' id = c(100, 101, 102),
#' dur = c(120, 140, 160),
#' rev = c(12, 14, 16),
#' cores = c(2, 8, 4)
#' )
#'
#' west = pl$DataFrame(
#' t_id = c(404, 498, 676, 742),
#' time = c(90, 130, 150, 170),
#' cost = c(9, 13, 15, 16),
#' cores = c(4, 2, 1, 4)
#' )
#'
#' east$join_where(
#' west,
#' pl$col("dur") < pl$col("time"),
#' pl$col("rev") < pl$col("cost")
#' )
DataFrame_join_where = function(
other,
...,
suffix = "_right") {
self$lazy()$join_where(self, other = other, ..., suffix = suffix)$collect()
}
2 changes: 1 addition & 1 deletion R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ LazyFrame_join = function(
#' `"x<suffix>"`.
#' @param suffix Suffix to append to columns with a duplicate name.
#'
#' @return
#' @return A LazyFrame
#'
#' @examples
#' east = pl$LazyFrame(
Expand Down
50 changes: 50 additions & 0 deletions man/DataFrame_join_where.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/LazyFrame_join_where.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

89 changes: 45 additions & 44 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,19 @@
[17] "first" "flags" "gather_every" "get_column"
[21] "get_columns" "glimpse" "group_by" "group_by_dynamic"
[25] "head" "height" "item" "join"
[29] "join_asof" "last" "lazy" "limit"
[33] "max" "mean" "median" "min"
[37] "n_chunks" "null_count" "partition_by" "pivot"
[41] "print" "quantile" "rechunk" "rename"
[45] "reverse" "rolling" "sample" "schema"
[49] "select" "select_seq" "shape" "shift"
[53] "slice" "sort" "sql" "std"
[57] "sum" "tail" "to_data_frame" "to_dummies"
[61] "to_list" "to_raw_ipc" "to_series" "to_struct"
[65] "transpose" "unique" "unnest" "unpivot"
[69] "var" "width" "with_columns" "with_columns_seq"
[73] "with_row_index" "write_csv" "write_ipc" "write_json"
[77] "write_ndjson" "write_parquet"
[29] "join_asof" "join_where" "last" "lazy"
[33] "limit" "max" "mean" "median"
[37] "min" "n_chunks" "null_count" "partition_by"
[41] "pivot" "print" "quantile" "rechunk"
[45] "rename" "reverse" "rolling" "sample"
[49] "schema" "select" "select_seq" "shape"
[53] "shift" "slice" "sort" "sql"
[57] "std" "sum" "tail" "to_data_frame"
[61] "to_dummies" "to_list" "to_raw_ipc" "to_series"
[65] "to_struct" "transpose" "unique" "unnest"
[69] "unpivot" "var" "width" "with_columns"
[73] "with_columns_seq" "with_row_index" "write_csv" "write_ipc"
[77] "write_json" "write_ndjson" "write_parquet"

---

Expand Down Expand Up @@ -150,19 +150,19 @@
[13] "fill_nan" "fill_null" "filter"
[16] "first" "gather_every" "group_by"
[19] "group_by_dynamic" "head" "join"
[22] "join_asof" "last" "limit"
[25] "max" "mean" "median"
[28] "min" "print" "profile"
[31] "quantile" "rename" "reverse"
[34] "rolling" "schema" "select"
[37] "select_seq" "serialize" "shift"
[40] "sink_csv" "sink_ipc" "sink_ndjson"
[43] "sink_parquet" "slice" "sort"
[46] "sql" "std" "sum"
[49] "tail" "to_dot" "unique"
[52] "unnest" "unpivot" "var"
[55] "width" "with_columns" "with_columns_seq"
[58] "with_context" "with_row_index"
[22] "join_asof" "join_where" "last"
[25] "limit" "max" "mean"
[28] "median" "min" "print"
[31] "profile" "quantile" "rename"
[34] "reverse" "rolling" "schema"
[37] "select" "select_seq" "serialize"
[40] "shift" "sink_csv" "sink_ipc"
[43] "sink_ndjson" "sink_parquet" "slice"
[46] "sort" "sql" "std"
[49] "sum" "tail" "to_dot"
[52] "unique" "unnest" "unpivot"
[55] "var" "width" "with_columns"
[58] "with_columns_seq" "with_context" "with_row_index"

---

Expand All @@ -180,24 +180,25 @@
[17] "fill_null" "filter"
[19] "first" "group_by"
[21] "group_by_dynamic" "join"
[23] "join_asof" "last"
[25] "max" "mean"
[27] "median" "min"
[29] "optimization_toggle" "print"
[31] "profile" "quantile"
[33] "rename" "reverse"
[35] "rolling" "schema"
[37] "select" "select_seq"
[39] "serialize" "shift"
[41] "sink_csv" "sink_ipc"
[43] "sink_json" "sink_parquet"
[45] "slice" "sort_by_exprs"
[47] "std" "sum"
[49] "tail" "to_dot"
[51] "unique" "unnest"
[53] "unpivot" "var"
[55] "with_columns" "with_columns_seq"
[57] "with_context" "with_row_index"
[23] "join_asof" "join_where"
[25] "last" "max"
[27] "mean" "median"
[29] "min" "optimization_toggle"
[31] "print" "profile"
[33] "quantile" "rename"
[35] "reverse" "rolling"
[37] "schema" "select"
[39] "select_seq" "serialize"
[41] "shift" "sink_csv"
[43] "sink_ipc" "sink_json"
[45] "sink_parquet" "slice"
[47] "sort_by_exprs" "std"
[49] "sum" "tail"
[51] "to_dot" "unique"
[53] "unnest" "unpivot"
[55] "var" "with_columns"
[57] "with_columns_seq" "with_context"
[59] "with_row_index"

# public and private methods of each class Expr

Expand Down
34 changes: 34 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1761,3 +1761,37 @@ test_that("$to_dummies() works", {
)
)
})

test_that("inequality joins work", {
east = pl$DataFrame(
id = c(100, 101, 102),
dur = c(120, 140, 160),
rev = c(12, 14, 16),
cores = c(2, 8, 4)
)
west = pl$DataFrame(
t_id = c(404, 498, 676, 742),
time = c(90, 130, 150, 170),
cost = c(9, 13, 15, 16),
cores = c(4, 2, 1, 4)
)
out = east$join_where(
west,
pl$col("dur") < pl$col("time"),
pl$col("rev") < pl$col("cost")
)

expect_identical(
out$to_data_frame(),
data.frame(
id = rep(c(100, 101), 3:2),
dur = rep(c(120, 140), 3:2),
rev = rep(c(12, 14), 3:2),
cores = rep(c(2, 8), 3:2),
t_id = c(498, 676, 742, 676, 742),
time = c(130, 150, 170, 150, 170),
cost = c(13, 15, 16, 15, 16),
cores_right = c(2, 1, 4, 1, 4)
)
)
})
34 changes: 34 additions & 0 deletions tests/testthat/test-lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1198,3 +1198,37 @@ test_that("$cast() works", {
list(x = NA_integer_)
)
})

test_that("inequality joins work", {
east = pl$LazyFrame(
id = c(100, 101, 102),
dur = c(120, 140, 160),
rev = c(12, 14, 16),
cores = c(2, 8, 4)
)
west = pl$LazyFrame(
t_id = c(404, 498, 676, 742),
time = c(90, 130, 150, 170),
cost = c(9, 13, 15, 16),
cores = c(4, 2, 1, 4)
)
out = east$join_where(
west,
pl$col("dur") < pl$col("time"),
pl$col("rev") < pl$col("cost")
)$collect()

expect_identical(
out$to_data_frame(),
data.frame(
id = rep(c(100, 101), 3:2),
dur = rep(c(120, 140), 3:2),
rev = rep(c(12, 14), 3:2),
cores = rep(c(2, 8), 3:2),
t_id = c(498, 676, 742, 676, 742),
time = c(130, 150, 170, 150, 170),
cost = c(13, 15, 16, 15, 16),
cores_right = c(2, 1, 4, 1, 4)
)
)
})

0 comments on commit 265de82

Please sign in to comment.