From 265de821a2fe792118b2f4b7831651f3c2486a2f Mon Sep 17 00:00:00 2001
From: etiennebacher
Date: Thu, 3 Oct 2024 21:35:06 +0200
Subject: [PATCH] fix

---
 NEWS.md                                 |  2 +
 R/dataframe__frame.R                    | 40 +++++++++++
 R/lazyframe__lazy.R                     |  2 +-
 man/DataFrame_join_where.Rd             | 50 ++++++++++++++
 man/LazyFrame_join_where.Rd             |  3 +
 tests/testthat/_snaps/after-wrappers.md | 89 +++++++++++++------------
 tests/testthat/test-dataframe.R         | 34 ++++++++++
 tests/testthat/test-lazy.R              | 34 ++++++++++
 8 files changed, 209 insertions(+), 45 deletions(-)
 create mode 100644 man/DataFrame_join_where.Rd

diff --git a/NEWS.md b/NEWS.md
index dff1e7c99..6b2deee65 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -17,6 +17,8 @@
 - New argument `strict` in `$drop()` to determine whether unknown column names
   should trigger an error (#1220).
 - New method `$to_dummies()` for `DataFrame` (#1225).
+- New method `$join_where()` for `DataFrame` and `LazyFrame` to perform
+  inequality joins (#1237).
 
 ### Bug fixes
 
diff --git a/R/dataframe__frame.R b/R/dataframe__frame.R
index 6e02c2dad..d4a5d433b 100644
--- a/R/dataframe__frame.R
+++ b/R/dataframe__frame.R
@@ -2577,3 +2577,43 @@ DataFrame_to_dummies = function(
   .pr$DataFrame$to_dummies(self, columns = columns, separator = separator, drop_first = drop_first) |>
     unwrap("in $to_dummies():")
 }
+
+#' @inherit LazyFrame_join_where title params
+#'
+#' @description
+#' This performs an inner join, so only rows where all predicates are true are
+#' included in the result, and a row from either DataFrame may be included
+#' multiple times in the result.
+#'
+#' Note that the row order of the input DataFrames is not preserved.
+#'
+#' @param other DataFrame to join with.
+#'
+#' @return A DataFrame
+#'
+#' @examples
+#' east = pl$DataFrame(
+#'   id = c(100, 101, 102),
+#'   dur = c(120, 140, 160),
+#'   rev = c(12, 14, 16),
+#'   cores = c(2, 8, 4)
+#' )
+#'
+#' west = pl$DataFrame(
+#'   t_id = c(404, 498, 676, 742),
+#'   time = c(90, 130, 150, 170),
+#'   cost = c(9, 13, 15, 16),
+#'   cores = c(4, 2, 1, 4)
+#' )
+#'
+#' east$join_where(
+#'   west,
+#'   pl$col("dur") < pl$col("time"),
+#'   pl$col("rev") < pl$col("cost")
+#' )
+DataFrame_join_where = function(
+    other,
+    ...,
+    suffix = "_right") {
+  self$lazy()$join_where(other = other$lazy(), ..., suffix = suffix)$collect()
+}
diff --git a/R/lazyframe__lazy.R b/R/lazyframe__lazy.R
index 16126f7a0..e7d10eb5a 100644
--- a/R/lazyframe__lazy.R
+++ b/R/lazyframe__lazy.R
@@ -1384,7 +1384,7 @@ LazyFrame_join = function(
 #' `"x"`.
 #' @param suffix Suffix to append to columns with a duplicate name.
 #'
-#' @return
+#' @return A LazyFrame
 #'
 #' @examples
 #' east = pl$LazyFrame(
diff --git a/man/DataFrame_join_where.Rd b/man/DataFrame_join_where.Rd
new file mode 100644
index 000000000..0281041e7
--- /dev/null
+++ b/man/DataFrame_join_where.Rd
@@ -0,0 +1,50 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/dataframe__frame.R
+\name{DataFrame_join_where}
+\alias{DataFrame_join_where}
+\title{Perform a join based on one or multiple (in)equality predicates}
+\usage{
+DataFrame_join_where(other, ..., suffix = "_right")
+}
+\arguments{
+\item{other}{DataFrame to join with.}
+
+\item{...}{(In)Equality condition to join the two tables on. When a column
+name occurs in both tables, the proper suffix must be applied in the
+predicate. For example, if both tables have a column \code{"x"} that you want to
+use in the conditions, you must refer to the column of the right table as
+\code{"x"}.}
+
+\item{suffix}{Suffix to append to columns with a duplicate name.}
+}
+\value{
+A DataFrame
+}
+\description{
+This performs an inner join, so only rows where all predicates are true are
+included in the result, and a row from either DataFrame may be included
+multiple times in the result.
+
+Note that the row order of the input DataFrames is not preserved.
+}
+\examples{
+east = pl$DataFrame(
+  id = c(100, 101, 102),
+  dur = c(120, 140, 160),
+  rev = c(12, 14, 16),
+  cores = c(2, 8, 4)
+)
+
+west = pl$DataFrame(
+  t_id = c(404, 498, 676, 742),
+  time = c(90, 130, 150, 170),
+  cost = c(9, 13, 15, 16),
+  cores = c(4, 2, 1, 4)
+)
+
+east$join_where(
+  west,
+  pl$col("dur") < pl$col("time"),
+  pl$col("rev") < pl$col("cost")
+)
+}
diff --git a/man/LazyFrame_join_where.Rd b/man/LazyFrame_join_where.Rd
index 99a2c1885..f8b5db3c2 100644
--- a/man/LazyFrame_join_where.Rd
+++ b/man/LazyFrame_join_where.Rd
@@ -17,6 +17,9 @@ use in the conditions, you must refer to the column of the right table as
 
 \item{suffix}{Suffix to append to columns with a duplicate name.}
 }
+\value{
+A LazyFrame
+}
 \description{
 This performs an inner join, so only rows where all predicates are true are
 included in the result, and a row from either LazyFrame may be included
diff --git a/tests/testthat/_snaps/after-wrappers.md b/tests/testthat/_snaps/after-wrappers.md
index 86567cbb7..ebf49d224 100644
--- a/tests/testthat/_snaps/after-wrappers.md
+++ b/tests/testthat/_snaps/after-wrappers.md
@@ -84,19 +84,19 @@
  [17] "first" "flags" "gather_every" "get_column"
  [21] "get_columns" "glimpse" "group_by" "group_by_dynamic"
  [25] "head" "height" "item" "join"
- [29] "join_asof" "last" "lazy" "limit"
- [33] "max" "mean" "median" "min"
- [37] "n_chunks" "null_count" "partition_by" "pivot"
- [41] "print" "quantile" "rechunk" "rename"
- [45] "reverse" "rolling" "sample" "schema"
- [49] "select" "select_seq" "shape" "shift"
- [53] "slice" "sort" "sql" "std"
- [57] "sum" "tail" "to_data_frame" "to_dummies"
- [61] "to_list" "to_raw_ipc" "to_series" "to_struct"
- [65] "transpose" "unique" "unnest" "unpivot"
- [69] "var" "width" "with_columns" "with_columns_seq"
- [73] "with_row_index" "write_csv" "write_ipc" "write_json"
- [77] "write_ndjson" "write_parquet"
+ [29] "join_asof" "join_where" "last" "lazy"
+ [33] "limit" "max" "mean" "median"
+ [37] "min" "n_chunks" "null_count" "partition_by"
+ [41] "pivot" "print" "quantile" "rechunk"
+ [45] "rename" "reverse" "rolling" "sample"
+ [49] "schema" "select" "select_seq" "shape"
+ [53] "shift" "slice" "sort" "sql"
+ [57] "std" "sum" "tail" "to_data_frame"
+ [61] "to_dummies" "to_list" "to_raw_ipc" "to_series"
+ [65] "to_struct" "transpose" "unique" "unnest"
+ [69] "unpivot" "var" "width" "with_columns"
+ [73] "with_columns_seq" "with_row_index" "write_csv" "write_ipc"
+ [77] "write_json" "write_ndjson" "write_parquet"
 
 ---
 
@@ -150,19 +150,19 @@
  [13] "fill_nan" "fill_null" "filter"
  [16] "first" "gather_every" "group_by"
  [19] "group_by_dynamic" "head" "join"
- [22] "join_asof" "last" "limit"
- [25] "max" "mean" "median"
- [28] "min" "print" "profile"
- [31] "quantile" "rename" "reverse"
- [34] "rolling" "schema" "select"
- [37] "select_seq" "serialize" "shift"
- [40] "sink_csv" "sink_ipc" "sink_ndjson"
- [43] "sink_parquet" "slice" "sort"
- [46] "sql" "std" "sum"
- [49] "tail" "to_dot" "unique"
- [52] "unnest" "unpivot" "var"
- [55] "width" "with_columns" "with_columns_seq"
- [58] "with_context" "with_row_index"
+ [22] "join_asof" "join_where" "last"
+ [25] "limit" "max" "mean"
+ [28] "median" "min" "print"
+ [31] "profile" "quantile" "rename"
+ [34] "reverse" "rolling" "schema"
+ [37] "select" "select_seq" "serialize"
+ [40] "shift" "sink_csv" "sink_ipc"
+ [43] "sink_ndjson" "sink_parquet" "slice"
+ [46] "sort" "sql" "std"
+ [49] "sum" "tail" "to_dot"
+ [52] "unique" "unnest" "unpivot"
+ [55] "var" "width" "with_columns"
+ [58] "with_columns_seq" "with_context" "with_row_index"
 
 ---
 
@@ -180,24 +180,25 @@
  [17] "fill_null" "filter"
  [19] "first" "group_by"
  [21] "group_by_dynamic" "join"
- [23] "join_asof" "last"
- [25] "max" "mean"
- [27] "median" "min"
- [29] "optimization_toggle" "print"
- [31] "profile" "quantile"
- [33] "rename" "reverse"
- [35] "rolling" "schema"
- [37] "select" "select_seq"
- [39] "serialize" "shift"
- [41] "sink_csv" "sink_ipc"
- [43] "sink_json" "sink_parquet"
- [45] "slice" "sort_by_exprs"
- [47] "std" "sum"
- [49] "tail" "to_dot"
- [51] "unique" "unnest"
- [53] "unpivot" "var"
- [55] "with_columns" "with_columns_seq"
- [57] "with_context" "with_row_index"
+ [23] "join_asof" "join_where"
+ [25] "last" "max"
+ [27] "mean" "median"
+ [29] "min" "optimization_toggle"
+ [31] "print" "profile"
+ [33] "quantile" "rename"
+ [35] "reverse" "rolling"
+ [37] "schema" "select"
+ [39] "select_seq" "serialize"
+ [41] "shift" "sink_csv"
+ [43] "sink_ipc" "sink_json"
+ [45] "sink_parquet" "slice"
+ [47] "sort_by_exprs" "std"
+ [49] "sum" "tail"
+ [51] "to_dot" "unique"
+ [53] "unnest" "unpivot"
+ [55] "var" "with_columns"
+ [57] "with_columns_seq" "with_context"
+ [59] "with_row_index"
 
 # public and private methods of each class Expr
 
diff --git a/tests/testthat/test-dataframe.R b/tests/testthat/test-dataframe.R
index 88816f6e1..62bef8b07 100644
--- a/tests/testthat/test-dataframe.R
+++ b/tests/testthat/test-dataframe.R
@@ -1761,3 +1761,37 @@ test_that("$to_dummies() works", {
     )
   )
 })
+
+test_that("inequality joins work", {
+  east = pl$DataFrame(
+    id = c(100, 101, 102),
+    dur = c(120, 140, 160),
+    rev = c(12, 14, 16),
+    cores = c(2, 8, 4)
+  )
+  west = pl$DataFrame(
+    t_id = c(404, 498, 676, 742),
+    time = c(90, 130, 150, 170),
+    cost = c(9, 13, 15, 16),
+    cores = c(4, 2, 1, 4)
+  )
+  out = east$join_where(
+    west,
+    pl$col("dur") < pl$col("time"),
+    pl$col("rev") < pl$col("cost")
+  )
+
+  expect_identical(
+    out$to_data_frame(),
+    data.frame(
+      id = rep(c(100, 101), 3:2),
+      dur = rep(c(120, 140), 3:2),
+      rev = rep(c(12, 14), 3:2),
+      cores = rep(c(2, 8), 3:2),
+      t_id = c(498, 676, 742, 676, 742),
+      time = c(130, 150, 170, 150, 170),
+      cost = c(13, 15, 16, 15, 16),
+      cores_right = c(2, 1, 4, 1, 4)
+    )
+  )
+})
\ No newline at end of file
diff --git a/tests/testthat/test-lazy.R b/tests/testthat/test-lazy.R
index ce42b9f88..3466d7c37 100644
--- a/tests/testthat/test-lazy.R
+++ b/tests/testthat/test-lazy.R
@@ -1198,3 +1198,37 @@ test_that("$cast() works", {
     list(x = NA_integer_)
   )
 })
+
+test_that("inequality joins work", {
+  east = pl$LazyFrame(
+    id = c(100, 101, 102),
+    dur = c(120, 140, 160),
+    rev = c(12, 14, 16),
+    cores = c(2, 8, 4)
+  )
+  west = pl$LazyFrame(
+    t_id = c(404, 498, 676, 742),
+    time = c(90, 130, 150, 170),
+    cost = c(9, 13, 15, 16),
+    cores = c(4, 2, 1, 4)
+  )
+  out = east$join_where(
+    west,
+    pl$col("dur") < pl$col("time"),
+    pl$col("rev") < pl$col("cost")
+  )$collect()
+
+  expect_identical(
+    out$to_data_frame(),
+    data.frame(
+      id = rep(c(100, 101), 3:2),
+      dur = rep(c(120, 140), 3:2),
+      rev = rep(c(12, 14), 3:2),
+      cores = rep(c(2, 8), 3:2),
+      t_id = c(498, 676, 742, 676, 742),
+      time = c(130, 150, 170, 150, 170),
+      cost = c(13, 15, 16, 15, 16),
+      cores_right = c(2, 1, 4, 1, 4)
+    )
+  )
+})