Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement $sample() for DataFrame #399

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
## What's changed

- New method `$unnest()` for `LazyFrame` (#397).
- New method `$sample()` for `DataFrame`.

# polars 0.8.1

Expand Down
41 changes: 35 additions & 6 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -1039,21 +1039,20 @@ DataFrame_to_struct = function(name = "") {
#' b = c("one", "two", "three", "four", "five"),
#' c = 6:10
#' )$
#' select(
#' pl$col("b")$to_struct(),
#' pl$col("a", "c")$to_struct()$alias("a_and_c")
#' )
#' select(
#' pl$col("b")$to_struct(),
#' pl$col("a", "c")$to_struct()$alias("a_and_c")
#' )
#' df
#'
#' # by default, all struct columns are unnested
#' df$unnest()
#'
#' # we can specify specific columns to unnest
#' df$unnest("a_and_c")

DataFrame_unnest = function(names = NULL) {
if (is.null(names)) {
names <- names(which(dtypes_are_struct(.pr$DataFrame$schema(self))))
names = names(which(dtypes_are_struct(.pr$DataFrame$schema(self))))
}
unwrap(.pr$DataFrame$unnest(self, names), "in $unnest():")
}
Expand Down Expand Up @@ -1604,3 +1603,33 @@ DataFrame_glimpse = function(..., return_as_string = FALSE) {
DataFrame_explode = function(...) {
self$lazy()$explode(...)$collect()
}

#' Take a sample of rows from a DataFrame
#'
#' @param n Number of rows to return. Cannot be used with `fraction`.
#' @param fraction Fraction of rows to return (between 0 and 1). Cannot be used
#' with `n`.
#' @param with_replacement Allow values to be sampled more than once.
#' @param shuffle If `TRUE`, the order of the sampled rows will be shuffled. If
#' `FALSE` (default), the order of the returned rows will be neither stable nor
#' fully random.
#' @param seed Seed for the random number generator. If set to `NULL` (default),
#' a random seed is generated for each sample operation.
#'
#' @keywords DataFrame
#' @return DataFrame
#' @examples
#' df = pl$DataFrame(iris)
#' df$sample(n = 20)
#' df$sample(frac = 0.1)
DataFrame_sample = function(
n = NULL, fraction = NULL, with_replacement = FALSE, shuffle = FALSE, seed = NULL) {
seed = seed %||% sample(0:10000, 1)
pcase(
!xor(is.null(n), is.null(fraction)), Err_plain("Pass either arg `n` or `fraction`, not both."),
is.null(fraction), .pr$DataFrame$sample_n(self, n, with_replacement, shuffle, seed),
is.null(n), .pr$DataFrame$sample_frac(self, fraction, with_replacement, shuffle, seed),
or_else = Err_plain("internal error")
) |>
unwrap("in $sample():")
}
8 changes: 6 additions & 2 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ DataFrame$melt <- function(id_vars, value_vars, value_name, variable_name) .Call

DataFrame$pivot_expr <- function(values, index, columns, maintain_order, sort_columns, aggregate_expr, separator) .Call(wrap__DataFrame__pivot_expr, self, values, index, columns, maintain_order, sort_columns, aggregate_expr, separator)

DataFrame$sample_n <- function(n, with_replacement, shuffle, seed) .Call(wrap__DataFrame__sample_n, self, n, with_replacement, shuffle, seed)

DataFrame$sample_frac <- function(frac, with_replacement, shuffle, seed) .Call(wrap__DataFrame__sample_frac, self, frac, with_replacement, shuffle, seed)

#' @export
`$.DataFrame` <- function (self, name) { func <- DataFrame[[name]]; environment(func) <- environment(); func }

Expand Down Expand Up @@ -997,6 +1001,8 @@ LazyFrame$slice <- function(offset, length) .Call(wrap__LazyFrame__slice, self,

LazyFrame$with_columns <- function(exprs) .Call(wrap__LazyFrame__with_columns, self, exprs)

LazyFrame$unnest <- function(names) .Call(wrap__LazyFrame__unnest, self, names)

LazyFrame$select <- function(exprs) .Call(wrap__LazyFrame__select, self, exprs)

LazyFrame$select_str_as_lit <- function(exprs) .Call(wrap__LazyFrame__select_str_as_lit, self, exprs)
Expand Down Expand Up @@ -1039,8 +1045,6 @@ LazyFrame$explode <- function(dotdotdot) .Call(wrap__LazyFrame__explode, self, d

LazyFrame$clone_see_me_macro <- function() .Call(wrap__LazyFrame__clone_see_me_macro, self)

LazyFrame$unnest <- function(names) .Call(wrap__LazyFrame__unnest, self, names)

#' @export
`$.LazyFrame` <- function (self, name) { func <- LazyFrame[[name]]; environment(func) <- environment(); func }

Expand Down
41 changes: 41 additions & 0 deletions man/DataFrame_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/DataFrame_unnest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion src/rust/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use polars::frame::hash_join::JoinType;
use polars::prelude as pl;
use polars::prelude::AsOfOptions;


#[allow(unused_imports)]
use std::result::Result;

Expand Down
40 changes: 39 additions & 1 deletion src/rust/src/rdataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl DataFrame {
s.into_series().into()
}

pub fn unnest(&self, names: Vec<String>) -> RResult<Self> {
pub fn unnest(&self, names: Vec<String>) -> RResult<Self> {
self.lazy().unnest(names)?.collect()
}

Expand Down Expand Up @@ -379,6 +379,44 @@ impl DataFrame {
.map_err(|err| err.to_string())
.map(|ok| ok.into())
}

pub fn sample_n(
&self,
n: Robj,
with_replacement: Robj,
shuffle: Robj,
seed: Robj,
) -> RResult<Self> {
self.0
.clone()
.sample_n(
robj_to!(usize, n)?,
robj_to!(bool, with_replacement)?,
robj_to!(bool, shuffle)?,
robj_to!(Option, u64, seed)?,
)
.map_err(polars_to_rpolars_err)
.map(DataFrame)
}

pub fn sample_frac(
&self,
frac: Robj,
with_replacement: Robj,
shuffle: Robj,
seed: Robj,
) -> RResult<Self> {
self.0
.clone()
.sample_frac(
robj_to!(f64, frac)?,
robj_to!(bool, with_replacement)?,
robj_to!(bool, shuffle)?,
robj_to!(Option, u64, seed)?,
)
.map_err(polars_to_rpolars_err)
.map(DataFrame)
}
}
impl DataFrame {
pub fn to_list_result(&self) -> Result<Robj, pl::PolarsError> {
Expand Down
27 changes: 27 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1050,3 +1050,30 @@ test_that("strictly_immutable = FALSE", {

pl$reset_options()
})

test_that("sample", {
df = pl$DataFrame(iris)

# plain use
expect_identical(df$sample(n = 20)$height, 20)
expect_identical(df$sample(frac = 0.1)$height, 15)

# must pass either n or fraction and not both
expect_error(df$sample(), "Pass either arg")
expect_error(df$sample(n = 2, fraction = 0.1), "not both")

# single check of some conversion errors
ctx = df$sample(frac = 0.1, seed = "not even a written number") |> get_err_ctx()
expect_identical(ctx$PlainErrorMessage, "ParseIntError { kind: InvalidDigit }")

# single check on rust-polars errors
ctx = df$sample(n = 151) |> get_err_ctx()
expect_true(isTRUE(grepl("larger sample than the total population", ctx$PolarsError)))
expect_no_error(df$sample(n = 151, with_replacement = TRUE))

# seed works
expect_identical(
df$sample(fraction = 0.1, seed = 123)$to_data_frame(),
df$sample(fraction = 0.1, seed = "123")$to_data_frame()
)
})