From d8a6e77abd1f5c386a92fad4051685513be8f86c Mon Sep 17 00:00:00 2001 From: eitsupi <50911393+eitsupi@users.noreply.github.com> Date: Sun, 18 Feb 2024 00:13:22 +0900 Subject: [PATCH] refactor: rewrite `as_polars_df.data.frame` (#817) Co-authored-by: Etienne Bacher <52219252+etiennebacher@users.noreply.github.com> --- NEWS.md | 2 + R/as_polars.R | 42 ++++++++++++++------- R/construction.R | 65 ++++++++++++++++++++++++++++++++- man/as_polars_df.Rd | 27 ++++++++------ tests/testthat/test-as_polars.R | 23 ++++++++++++ 5 files changed, 133 insertions(+), 26 deletions(-) diff --git a/NEWS.md b/NEWS.md index 5224489dc..5fdd0dd4a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,8 @@ each column. The output of `$flags` for `Series` was also improved and now contains `FAST_EXPLODE` for `Series` of type `list` and `array` (#809). - Add string methods `to_lowercase()` and `to_uppercase()` for `Series` (#810). +- `as_polars_df()` for `data.frame` is more memory-efficient and new arguments + `schema` and `schema_overrides` are added (#817). ## Polars R Package 0.14.0 diff --git a/R/as_polars.R b/R/as_polars.R index eb3386eb5..4d8c9fff6 100644 --- a/R/as_polars.R +++ b/R/as_polars.R @@ -1,9 +1,7 @@ #' To polars DataFrame #' #' [as_polars_df()] is a generic function that converts an R object to a -#' polars DataFrame. It is basically a wrapper for [pl$DataFrame()][pl_DataFrame], -#' but has special implementations for Apache Arrow-based objects such as -#' polars [LazyFrame][LazyFrame_class] and [arrow::Table]. +#' [polars DataFrame][DataFrame_class]. #' #' For [LazyFrame][LazyFrame_class] objects, this function is a shortcut for #' [$collect()][LazyFrame_collect] or [$fetch()][LazyFrame_fetch], depending on @@ -67,17 +65,35 @@ as_polars_df.default = function(x, ...) { #' @param make_names_unique A logical flag to replace duplicated column names #' with unique names. If `FALSE` and there are duplicated column names, an #' error is thrown. +#' @inheritParams as_polars_df.ArrowTabular #' @export -as_polars_df.data.frame = function(x, ..., rownames = NULL, make_names_unique = TRUE) { - if ((anyDuplicated(names(x)) > 0) && make_names_unique) { - names(x) = make.unique(names(x), sep = "_") +as_polars_df.data.frame = function( + x, + ..., + rownames = NULL, + make_names_unique = TRUE, + schema = NULL, + schema_overrides = NULL) { + uw = \(res) unwrap(res, "in as_polars_df():") + + if (anyDuplicated(names(x)) > 0) { + col_names_orig = names(x) + if (make_names_unique) { + names(x) = make.unique(col_names_orig, sep = "_") + } else { + Err_plain( + paste( + "conflicting column names not allowed:", + paste(unique(col_names_orig[duplicated(col_names_orig)]), collapse = ", ") + ) + ) |> + uw() + } } if (is.null(rownames)) { - pl$DataFrame(x, make_names_unique = FALSE) + df_to_rpldf(x, schema = schema, schema_overrides = schema_overrides) } else { - uw = \(res) unwrap(res, "in as_polars_df():") - if (length(rownames) != 1L || !is.character(rownames) || is.na(rownames)) { Err_plain("`rownames` must be a single string, or `NULL`") |> uw() @@ -102,7 +118,7 @@ as_polars_df.data.frame = function(x, ..., rownames = NULL, make_names_unique = pl$concat( pl$Series(old_rownames, name = rownames), - pl$DataFrame(x, make_names_unique = FALSE), + df_to_rpldf(x, schema = schema, schema_overrides = schema_overrides), how = "horizontal" ) } @@ -133,7 +149,7 @@ as_polars_df.RPolarsDynamicGroupBy = as_polars_df.RPolarsGroupBy #' @rdname as_polars_df #' @export as_polars_df.RPolarsSeries = function(x, ...) { - pl$DataFrame(x) + pl$select(x) } @@ -197,7 +213,7 @@ as_polars_df.ArrowTabular = function( rechunk = TRUE, schema = NULL, schema_overrides = NULL) { - arrow_to_rdf( + arrow_to_rpldf( x, rechunk = rechunk, schema = schema, @@ -236,7 +252,7 @@ as_polars_df.nanoarrow_array_stream = function(x, ...) { } } - out = do.call(pl$DataFrame, data_cols) + out = do.call(pl$select, data_cols) } else { out = pl$DataFrame() # TODO: support creating 0-row DataFrame } diff --git a/R/construction.R b/R/construction.R index 8c304c066..2271845e6 100644 --- a/R/construction.R +++ b/R/construction.R @@ -11,7 +11,7 @@ #' @param schema_overrides named list of DataTypes. Cast some columns to the DataType. #' @noRd #' @return RPolarsDataFrame -arrow_to_rdf = function(at, schema = NULL, schema_overrides = NULL, rechunk = TRUE) { +arrow_to_rpldf = function(at, schema = NULL, schema_overrides = NULL, rechunk = TRUE) { # new column names by schema, #todo get names if schema not NULL n_cols = at$num_columns @@ -21,7 +21,10 @@ arrow_to_rdf = function(at, schema = NULL, schema_overrides = NULL, rechunk = TR ) col_names = names(new_schema) - if (length(col_names) != n_cols) stop("schema length does not match column length") + if (length(col_names) != n_cols) { + Err_plain("schema length does not match column length") |> + unwrap() + } data_cols = list() # dictionaries cannot be built in different batches (categorical does not allow @@ -199,3 +202,61 @@ arrow_to_rseries_result = function(name, values, rechunk = TRUE) { res } + + +#' Internal function of `as_polars_df()` for `data.frame` class objects. +#' +#' This is a copy of `arrow_to_rpldf` +#' @noRd +#' @return RPolarsDataFrame +df_to_rpldf = function(x, ..., schema = NULL, schema_overrides = NULL) { + n_cols = ncol(x) + + new_schema = unpack_schema( + schema = schema %||% names(x), + schema_overrides = schema_overrides + ) + col_names = names(new_schema) + + if (length(col_names) != n_cols) { + Err_plain("schema length does not match column length") |> + unwrap() + } + + data_cols = list() + + for (i in seq_len(n_cols)) { + column = as_polars_series(x[[i]]) + col_name = col_names[i] + + data_cols[[col_name]] = column + } + + if (length(data_cols)) { + out = do.call(pl$select, data_cols) + } else { + out = pl$DataFrame() + } + + cast_these_fields = mapply( + new_schema, + out$schema, + FUN = \(new_field, df_field) { + if (is.null(new_field) || new_field == df_field) NULL else new_field + }, + SIMPLIFY = FALSE + ) |> (\(l) l[!sapply(l, is.null)])() + + if (length(cast_these_fields)) { + out = out$with_columns( + mapply( + cast_these_fields, + names(cast_these_fields), + FUN = \(dtype, name) pl$col(name)$cast(dtype), + SIMPLIFY = FALSE + ) |> unname() + ) + } + + out +} diff --git a/man/as_polars_df.Rd b/man/as_polars_df.Rd index d616401b8..b23e72f30 100644 --- a/man/as_polars_df.Rd +++ b/man/as_polars_df.Rd @@ -19,7 +19,14 @@ as_polars_df(x, ...) \method{as_polars_df}{default}(x, ...) -\method{as_polars_df}{data.frame}(x, ..., rownames = NULL, make_names_unique = TRUE) +\method{as_polars_df}{data.frame}( + x, + ..., + rownames = NULL, + make_names_unique = TRUE, + schema = NULL, + schema_overrides = NULL +) \method{as_polars_df}{RPolarsDataFrame}(x, ...) @@ -70,6 +77,13 @@ If \code{x} already has a column with that name, an error is thrown. with unique names. If \code{FALSE} and there are duplicated column names, an error is thrown.} +\item{schema}{named list of DataTypes, or character vector of column names. +Should be the same length as the number of columns of \code{x}. +If schema names or types do not match \code{x}, the columns will be renamed/recast. +If \code{NULL} (default), convert columns as is.} + +\item{schema_overrides}{named list of DataTypes. Cast some columns to the DataType.} + \item{n_rows}{Number of rows to fetch. Defaults to \code{Inf}, meaning all rows.} \item{type_coercion}{Boolean. Coerce types such that operations succeed and @@ -109,22 +123,13 @@ into the resulting DataFrame. Useful in interactive mode to not lock R session.} \item{rechunk}{A logical flag (default \code{TRUE}). Make sure that all data of each column is in contiguous memory.} - -\item{schema}{named list of DataTypes, or character vector of column names. -Should be the same length as the number of columns of \code{x}. -If schema names or types do not match \code{x}, the columns will be renamed/recast. -If \code{NULL} (default), convert columns as is.} - -\item{schema_overrides}{named list of DataTypes. Cast some columns to the DataType.} } \value{ a \link[=DataFrame_class]{DataFrame} } \description{ \code{\link[=as_polars_df]{as_polars_df()}} is a generic function that converts an R object to a -polars DataFrame. It is basically a wrapper for \link[=pl_DataFrame]{pl$DataFrame()}, -but has special implementations for Apache Arrow-based objects such as -polars \link[=LazyFrame_class]{LazyFrame} and \link[arrow:Table-class]{arrow::Table}. +\link[=DataFrame_class]{polars DataFrame}. } \details{ For \link[=LazyFrame_class]{LazyFrame} objects, this function is a shortcut for diff --git a/tests/testthat/test-as_polars.R b/tests/testthat/test-as_polars.R index 5287dd5d2..d28d24586 100644 --- a/tests/testthat/test-as_polars.R +++ b/tests/testthat/test-as_polars.R @@ -99,6 +99,29 @@ test_that("as_polars_df throws error when make_names_unique = FALSE and there ar }) +test_that("schema option and schema_overrides for as_polars_df.data.frame", { + df = data.frame(a = 1:3, b = 4:6) + pl_df_1 = as_polars_df(df, schema = list(a = pl$String, b = pl$Int32)) + pl_df_2 = as_polars_df(df, schema = c("x", "y")) + pl_df_3 = as_polars_df(df, schema_overrides = list(a = pl$String)) + + expect_equal( + pl_df_1$to_data_frame(), + data.frame(a = as.character(1:3), b = 4L:6L) + ) + expect_equal( + pl_df_2$to_data_frame(), + data.frame(x = 1:3, y = 4:6) + ) + expect_equal( + pl_df_3$to_data_frame(), + data.frame(a = as.character(1:3), b = 4:6) + ) + + expect_error(as_polars_df(mtcars, schema = "cyl"), "schema length does not match") +}) + + if (requireNamespace("arrow", quietly = TRUE) && requireNamespace("nanoarrow", quietly = TRUE)) { make_as_polars_series_cases = function() { tibble::tribble(