From 33a56d998d5b99fff74106fc2483263cfe6c3846 Mon Sep 17 00:00:00 2001 From: Etienne Bacher <52219252+etiennebacher@users.noreply.github.com> Date: Tue, 26 Sep 2023 22:28:38 +0000 Subject: [PATCH] Implement `$sample()` for `DataFrame` (#399) Co-authored-by: sorhawell --- NEWS.md | 1 + R/dataframe__frame.R | 41 ++++++++++++++++++++++++++++----- R/extendr-wrappers.R | 8 +++++-- man/DataFrame_sample.Rd | 41 +++++++++++++++++++++++++++++++++ man/DataFrame_unnest.Rd | 8 +++---- src/rust/src/lazy/dataframe.rs | 1 - src/rust/src/rdataframe/mod.rs | 40 +++++++++++++++++++++++++++++++- tests/testthat/test-dataframe.R | 27 ++++++++++++++++++++++ 8 files changed, 153 insertions(+), 14 deletions(-) create mode 100644 man/DataFrame_sample.Rd diff --git a/NEWS.md b/NEWS.md index 8b01bf10c..4eab15d33 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,7 @@ ## What's changed - New method `$unnest()` for `LazyFrame` (#397). +- New method `$sample()` for `DataFrame` (#399). # polars 0.8.1 diff --git a/R/dataframe__frame.R b/R/dataframe__frame.R index d9d29b496..a16e124f8 100644 --- a/R/dataframe__frame.R +++ b/R/dataframe__frame.R @@ -1039,10 +1039,10 @@ DataFrame_to_struct = function(name = "") { #' b = c("one", "two", "three", "four", "five"), #' c = 6:10 #' )$ -#' select( -#' pl$col("b")$to_struct(), -#' pl$col("a", "c")$to_struct()$alias("a_and_c") -#' ) +#' select( +#' pl$col("b")$to_struct(), +#' pl$col("a", "c")$to_struct()$alias("a_and_c") +#' ) #' df #' #' # by default, all struct columns are unnested @@ -1050,10 +1050,9 @@ DataFrame_to_struct = function(name = "") { #' #' # we can specify specific columns to unnest #' df$unnest("a_and_c") - DataFrame_unnest = function(names = NULL) { if (is.null(names)) { - names <- names(which(dtypes_are_struct(.pr$DataFrame$schema(self)))) + names = names(which(dtypes_are_struct(.pr$DataFrame$schema(self)))) } unwrap(.pr$DataFrame$unnest(self, names), "in $unnest():") } @@ -1604,3 +1603,33 @@ DataFrame_glimpse = function(..., return_as_string =
FALSE) { DataFrame_explode = function(...) { self$lazy()$explode(...)$collect() } + +#' Take a sample of rows from a DataFrame +#' +#' @param n Number of rows to return. Cannot be used with `fraction`. +#' @param fraction Fraction of rows to return (between 0 and 1). Cannot be used +#' with `n`. +#' @param with_replacement Allow values to be sampled more than once. +#' @param shuffle If `TRUE`, the order of the sampled rows will be shuffled. If +#' `FALSE` (default), the order of the returned rows will be neither stable nor +#' fully random. +#' @param seed Seed for the random number generator. If set to `NULL` (default), +#' a random seed is generated for each sample operation. +#' +#' @keywords DataFrame +#' @return DataFrame +#' @examples +#' df = pl$DataFrame(iris) +#' df$sample(n = 20) +#' df$sample(fraction = 0.1) +DataFrame_sample = function( + n = NULL, fraction = NULL, with_replacement = FALSE, shuffle = FALSE, seed = NULL) { + seed = seed %||% sample(0:10000, 1) + pcase( + !xor(is.null(n), is.null(fraction)), Err_plain("Pass either arg `n` or `fraction`, not both."), + is.null(fraction), .pr$DataFrame$sample_n(self, n, with_replacement, shuffle, seed), + is.null(n), .pr$DataFrame$sample_frac(self, fraction, with_replacement, shuffle, seed), + or_else = Err_plain("internal error") + ) |> + unwrap("in $sample():") +} diff --git a/R/extendr-wrappers.R b/R/extendr-wrappers.R index 77b95a6a8..fcdc9f3a1 100644 --- a/R/extendr-wrappers.R +++ b/R/extendr-wrappers.R @@ -175,6 +175,10 @@ DataFrame$melt <- function(id_vars, value_vars, value_name, variable_name) .Call DataFrame$pivot_expr <- function(values, index, columns, maintain_order, sort_columns, aggregate_expr, separator) .Call(wrap__DataFrame__pivot_expr, self, values, index, columns, maintain_order, sort_columns, aggregate_expr, separator) +DataFrame$sample_n <- function(n, with_replacement, shuffle, seed) .Call(wrap__DataFrame__sample_n, self, n, with_replacement, shuffle, seed) + +DataFrame$sample_frac <-
function(frac, with_replacement, shuffle, seed) .Call(wrap__DataFrame__sample_frac, self, frac, with_replacement, shuffle, seed) + #' @export `$.DataFrame` <- function (self, name) { func <- DataFrame[[name]]; environment(func) <- environment(); func } @@ -997,6 +1001,8 @@ LazyFrame$slice <- function(offset, length) .Call(wrap__LazyFrame__slice, self, LazyFrame$with_columns <- function(exprs) .Call(wrap__LazyFrame__with_columns, self, exprs) +LazyFrame$unnest <- function(names) .Call(wrap__LazyFrame__unnest, self, names) + LazyFrame$select <- function(exprs) .Call(wrap__LazyFrame__select, self, exprs) LazyFrame$select_str_as_lit <- function(exprs) .Call(wrap__LazyFrame__select_str_as_lit, self, exprs) @@ -1039,8 +1045,6 @@ LazyFrame$explode <- function(dotdotdot) .Call(wrap__LazyFrame__explode, self, d LazyFrame$clone_see_me_macro <- function() .Call(wrap__LazyFrame__clone_see_me_macro, self) -LazyFrame$unnest <- function(names) .Call(wrap__LazyFrame__unnest, self, names) - #' @export `$.LazyFrame` <- function (self, name) { func <- LazyFrame[[name]]; environment(func) <- environment(); func } diff --git a/man/DataFrame_sample.Rd b/man/DataFrame_sample.Rd new file mode 100644 index 000000000..45639a836 --- /dev/null +++ b/man/DataFrame_sample.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dataframe__frame.R +\name{DataFrame_sample} +\alias{DataFrame_sample} +\title{Take a sample of rows from a DataFrame} +\usage{ +DataFrame_sample( + n = NULL, + fraction = NULL, + with_replacement = FALSE, + shuffle = FALSE, + seed = NULL +) +} +\arguments{ +\item{n}{Number of rows to return. Cannot be used with \code{fraction}.} + +\item{fraction}{Fraction of rows to return (between 0 and 1). Cannot be used +with \code{n}.} + +\item{with_replacement}{Allow values to be sampled more than once.} + +\item{shuffle}{If \code{TRUE}, the order of the sampled rows will be shuffled.
If +\code{FALSE} (default), the order of the returned rows will be neither stable nor +fully random.} + +\item{seed}{Seed for the random number generator. If set to \code{NULL} (default), +a random seed is generated for each sample operation.} +} +\value{ +DataFrame +} +\description{ +Take a sample of rows from a DataFrame +} +\examples{ +df = pl$DataFrame(iris) +df$sample(n = 20) +df$sample(fraction = 0.1) +} +\keyword{DataFrame} diff --git a/man/DataFrame_unnest.Rd b/man/DataFrame_unnest.Rd index b63b11fa2..4967151a5 100644 --- a/man/DataFrame_unnest.Rd +++ b/man/DataFrame_unnest.Rd @@ -23,10 +23,10 @@ df = pl$DataFrame( b = c("one", "two", "three", "four", "five"), c = 6:10 )$ - select( - pl$col("b")$to_struct(), - pl$col("a", "c")$to_struct()$alias("a_and_c") - ) + select( + pl$col("b")$to_struct(), + pl$col("a", "c")$to_struct()$alias("a_and_c") +) df # by default, all struct columns are unnested diff --git a/src/rust/src/lazy/dataframe.rs b/src/rust/src/lazy/dataframe.rs index bae8f60c7..8705e6f34 100644 --- a/src/rust/src/lazy/dataframe.rs +++ b/src/rust/src/lazy/dataframe.rs @@ -19,7 +19,6 @@ use polars::frame::hash_join::JoinType; use polars::prelude as pl; use polars::prelude::AsOfOptions; - #[allow(unused_imports)] use std::result::Result; diff --git a/src/rust/src/rdataframe/mod.rs b/src/rust/src/rdataframe/mod.rs index b3541efed..245878b27 100644 --- a/src/rust/src/rdataframe/mod.rs +++ b/src/rust/src/rdataframe/mod.rs @@ -299,7 +299,7 @@ impl DataFrame { s.into_series().into() } - pub fn unnest(&self, names: Vec) -> RResult { + pub fn unnest(&self, names: Vec) -> RResult { self.lazy().unnest(names)?.collect() } @@ -379,6 +379,44 @@ impl DataFrame { .map_err(|err| err.to_string()) .map(|ok| ok.into()) } + + pub fn sample_n( + &self, + n: Robj, + with_replacement: Robj, + shuffle: Robj, + seed: Robj, + ) -> RResult { + self.0 + .clone() + .sample_n( + robj_to!(usize, n)?, + robj_to!(bool, with_replacement)?, + robj_to!(bool, shuffle)?, + robj_to!(Option,
u64, seed)?, + ) + .map_err(polars_to_rpolars_err) + .map(DataFrame) + } + + pub fn sample_frac( + &self, + frac: Robj, + with_replacement: Robj, + shuffle: Robj, + seed: Robj, + ) -> RResult { + self.0 + .clone() + .sample_frac( + robj_to!(f64, frac)?, + robj_to!(bool, with_replacement)?, + robj_to!(bool, shuffle)?, + robj_to!(Option, u64, seed)?, + ) + .map_err(polars_to_rpolars_err) + .map(DataFrame) + } } impl DataFrame { pub fn to_list_result(&self) -> Result { diff --git a/tests/testthat/test-dataframe.R b/tests/testthat/test-dataframe.R index 8c6a01295..6cf7293ee 100644 --- a/tests/testthat/test-dataframe.R +++ b/tests/testthat/test-dataframe.R @@ -1050,3 +1050,30 @@ test_that("strictly_immutable = FALSE", { pl$reset_options() }) + +test_that("sample", { + df = pl$DataFrame(iris) + + # plain use + expect_identical(df$sample(n = 20)$height, 20) + expect_identical(df$sample(frac = 0.1)$height, 15) + + # must pass either n or fraction and not both + expect_error(df$sample(), "Pass either arg") + expect_error(df$sample(n = 2, fraction = 0.1), "not both") + + # single check of some conversion errors + ctx = df$sample(frac = 0.1, seed = "not even a written number") |> get_err_ctx() + expect_identical(ctx$PlainErrorMessage, "ParseIntError { kind: InvalidDigit }") + + # single check on rust-polars errors + ctx = df$sample(n = 151) |> get_err_ctx() + expect_true(isTRUE(grepl("larger sample than the total population", ctx$PolarsError))) + expect_no_error(df$sample(n = 151, with_replacement = TRUE)) + + # seed works + expect_identical( + df$sample(fraction = 0.1, seed = 123)$to_data_frame(), + df$sample(fraction = 0.1, seed = "123")$to_data_frame() + ) +})