Skip to content

Commit

Permalink
Implement $sample() for DataFrame (#399)
Browse files Browse the repository at this point in the history
Co-authored-by: sorhawell <[email protected]>
  • Loading branch information
etiennebacher and sorhawell authored Sep 26, 2023
1 parent 8090f6e commit 33a56d9
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 14 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
## What's changed

- New method `$unnest()` for `LazyFrame` (#397).
- New method `$sample()` for `DataFrame`.

# polars 0.8.1

Expand Down
41 changes: 35 additions & 6 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -1039,21 +1039,20 @@ DataFrame_to_struct = function(name = "") {
#' b = c("one", "two", "three", "four", "five"),
#' c = 6:10
#' )$
#' select(
#' pl$col("b")$to_struct(),
#' pl$col("a", "c")$to_struct()$alias("a_and_c")
#' )
#' select(
#' pl$col("b")$to_struct(),
#' pl$col("a", "c")$to_struct()$alias("a_and_c")
#' )
#' df
#'
#' # by default, all struct columns are unnested
#' df$unnest()
#'
#' # we can specify specific columns to unnest
#' df$unnest("a_and_c")

DataFrame_unnest = function(names = NULL) {
if (is.null(names)) {
names <- names(which(dtypes_are_struct(.pr$DataFrame$schema(self))))
names = names(which(dtypes_are_struct(.pr$DataFrame$schema(self))))
}
unwrap(.pr$DataFrame$unnest(self, names), "in $unnest():")
}
Expand Down Expand Up @@ -1604,3 +1603,33 @@ DataFrame_glimpse = function(..., return_as_string = FALSE) {
DataFrame_explode = function(...) {
self$lazy()$explode(...)$collect()
}

#' Take a sample of rows from a DataFrame
#'
#' @param n Number of rows to return. Cannot be used with `fraction`.
#' @param fraction Fraction of rows to return (between 0 and 1). Cannot be used
#' with `n`.
#' @param with_replacement Allow values to be sampled more than once.
#' @param shuffle If `TRUE`, the order of the sampled rows will be shuffled. If
#' `FALSE` (default), the order of the returned rows will be neither stable nor
#' fully random.
#' @param seed Seed for the random number generator. If set to `NULL` (default),
#' a random seed is generated for each sample operation.
#'
#' @keywords DataFrame
#' @return DataFrame
#' @examples
#' df = pl$DataFrame(iris)
#' df$sample(n = 20)
#' df$sample(frac = 0.1)
DataFrame_sample = function(
n = NULL, fraction = NULL, with_replacement = FALSE, shuffle = FALSE, seed = NULL) {
seed = seed %||% sample(0:10000, 1)
pcase(
!xor(is.null(n), is.null(fraction)), Err_plain("Pass either arg `n` or `fraction`, not both."),
is.null(fraction), .pr$DataFrame$sample_n(self, n, with_replacement, shuffle, seed),
is.null(n), .pr$DataFrame$sample_frac(self, fraction, with_replacement, shuffle, seed),
or_else = Err_plain("internal error")
) |>
unwrap("in $sample():")
}
8 changes: 6 additions & 2 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ DataFrame$melt <- function(id_vars, value_vars, value_name, variable_name) .Call

DataFrame$pivot_expr <- function(values, index, columns, maintain_order, sort_columns, aggregate_expr, separator) .Call(wrap__DataFrame__pivot_expr, self, values, index, columns, maintain_order, sort_columns, aggregate_expr, separator)

DataFrame$sample_n <- function(n, with_replacement, shuffle, seed) .Call(wrap__DataFrame__sample_n, self, n, with_replacement, shuffle, seed)

DataFrame$sample_frac <- function(frac, with_replacement, shuffle, seed) .Call(wrap__DataFrame__sample_frac, self, frac, with_replacement, shuffle, seed)

#' @export
`$.DataFrame` <- function (self, name) { func <- DataFrame[[name]]; environment(func) <- environment(); func }

Expand Down Expand Up @@ -997,6 +1001,8 @@ LazyFrame$slice <- function(offset, length) .Call(wrap__LazyFrame__slice, self,

LazyFrame$with_columns <- function(exprs) .Call(wrap__LazyFrame__with_columns, self, exprs)

LazyFrame$unnest <- function(names) .Call(wrap__LazyFrame__unnest, self, names)

LazyFrame$select <- function(exprs) .Call(wrap__LazyFrame__select, self, exprs)

LazyFrame$select_str_as_lit <- function(exprs) .Call(wrap__LazyFrame__select_str_as_lit, self, exprs)
Expand Down Expand Up @@ -1039,8 +1045,6 @@ LazyFrame$explode <- function(dotdotdot) .Call(wrap__LazyFrame__explode, self, d

LazyFrame$clone_see_me_macro <- function() .Call(wrap__LazyFrame__clone_see_me_macro, self)

LazyFrame$unnest <- function(names) .Call(wrap__LazyFrame__unnest, self, names)

#' @export
`$.LazyFrame` <- function (self, name) { func <- LazyFrame[[name]]; environment(func) <- environment(); func }

Expand Down
41 changes: 41 additions & 0 deletions man/DataFrame_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/DataFrame_unnest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion src/rust/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use polars::frame::hash_join::JoinType;
use polars::prelude as pl;
use polars::prelude::AsOfOptions;


#[allow(unused_imports)]
use std::result::Result;

Expand Down
40 changes: 39 additions & 1 deletion src/rust/src/rdataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl DataFrame {
s.into_series().into()
}

pub fn unnest(&self, names: Vec<String>) -> RResult<Self> {
pub fn unnest(&self, names: Vec<String>) -> RResult<Self> {
self.lazy().unnest(names)?.collect()
}

Expand Down Expand Up @@ -379,6 +379,44 @@ impl DataFrame {
.map_err(|err| err.to_string())
.map(|ok| ok.into())
}

pub fn sample_n(
&self,
n: Robj,
with_replacement: Robj,
shuffle: Robj,
seed: Robj,
) -> RResult<Self> {
self.0
.clone()
.sample_n(
robj_to!(usize, n)?,
robj_to!(bool, with_replacement)?,
robj_to!(bool, shuffle)?,
robj_to!(Option, u64, seed)?,
)
.map_err(polars_to_rpolars_err)
.map(DataFrame)
}

pub fn sample_frac(
&self,
frac: Robj,
with_replacement: Robj,
shuffle: Robj,
seed: Robj,
) -> RResult<Self> {
self.0
.clone()
.sample_frac(
robj_to!(f64, frac)?,
robj_to!(bool, with_replacement)?,
robj_to!(bool, shuffle)?,
robj_to!(Option, u64, seed)?,
)
.map_err(polars_to_rpolars_err)
.map(DataFrame)
}
}
impl DataFrame {
pub fn to_list_result(&self) -> Result<Robj, pl::PolarsError> {
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1050,3 +1050,30 @@ test_that("strictly_immutable = FALSE", {

pl$reset_options()
})

test_that("sample", {
df = pl$DataFrame(iris)

# plain use
expect_identical(df$sample(n = 20)$height, 20)
expect_identical(df$sample(frac = 0.1)$height, 15)

# must pass either n or fraction and not both
expect_error(df$sample(), "Pass either arg")
expect_error(df$sample(n = 2, fraction = 0.1), "not both")

# single check of some conversion errors
ctx = df$sample(frac = 0.1, seed = "not even a written number") |> get_err_ctx()
expect_identical(ctx$PlainErrorMessage, "ParseIntError { kind: InvalidDigit }")

# single check on rust-polars errors
ctx = df$sample(n = 151) |> get_err_ctx()
expect_true(isTRUE(grepl("larger sample than the total population", ctx$PolarsError)))
expect_no_error(df$sample(n = 151, with_replacement = TRUE))

# seed works
expect_identical(
df$sample(fraction = 0.1, seed = 123)$to_data_frame(),
df$sample(fraction = 0.1, seed = "123")$to_data_frame()
)
})

0 comments on commit 33a56d9

Please sign in to comment.