Skip to content

Commit

Permalink
refactor: rewrite as_polars_df.data.frame (#817)
Browse files Browse the repository at this point in the history
Co-authored-by: Etienne Bacher <[email protected]>
  • Loading branch information
eitsupi and etiennebacher authored Feb 17, 2024
1 parent 39f4bfa commit d8a6e77
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 26 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
each column. The output of `$flags` for `Series` was also improved and now
contains `FAST_EXPLODE` for `Series` of type `list` and `array` (#809).
- Add string methods `to_lowercase()` and `to_uppercase()` for `Series` (#810).
- `as_polars_df()` for `data.frame` is now more memory-efficient and gains the
new arguments `schema` and `schema_overrides` (#817).

## Polars R Package 0.14.0

Expand Down
42 changes: 29 additions & 13 deletions R/as_polars.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#' To polars DataFrame
#'
#' [as_polars_df()] is a generic function that converts an R object to a
#' polars DataFrame. It is basically a wrapper for [pl$DataFrame()][pl_DataFrame],
#' but has special implementations for Apache Arrow-based objects such as
#' polars [LazyFrame][LazyFrame_class] and [arrow::Table].
#' [polars DataFrame][DataFrame_class].
#'
#' For [LazyFrame][LazyFrame_class] objects, this function is a shortcut for
#' [$collect()][LazyFrame_collect] or [$fetch()][LazyFrame_fetch], depending on
Expand Down Expand Up @@ -67,17 +65,35 @@ as_polars_df.default = function(x, ...) {
#' @param make_names_unique A logical flag to replace duplicated column names
#' with unique names. If `FALSE` and there are duplicated column names, an
#' error is thrown.
#' @inheritParams as_polars_df.ArrowTabular
#' @export
as_polars_df.data.frame = function(x, ..., rownames = NULL, make_names_unique = TRUE) {
if ((anyDuplicated(names(x)) > 0) && make_names_unique) {
names(x) = make.unique(names(x), sep = "_")
as_polars_df.data.frame = function(
x,
...,
rownames = NULL,
make_names_unique = TRUE,
schema = NULL,
schema_overrides = NULL) {
uw = \(res) unwrap(res, "in as_polars_df():")

if (anyDuplicated(names(x)) > 0) {
col_names_orig = names(x)
if (make_names_unique) {
names(x) = make.unique(col_names_orig, sep = "_")
} else {
Err_plain(
paste(
"conflicting column names not allowed:",
paste(unique(col_names_orig[duplicated(col_names_orig)]), collapse = ", ")
)
) |>
uw()
}
}

if (is.null(rownames)) {
pl$DataFrame(x, make_names_unique = FALSE)
df_to_rpldf(x, schema = schema, schema_overrides = schema_overrides)
} else {
uw = \(res) unwrap(res, "in as_polars_df():")

if (length(rownames) != 1L || !is.character(rownames) || is.na(rownames)) {
Err_plain("`rownames` must be a single string, or `NULL`") |>
uw()
Expand All @@ -102,7 +118,7 @@ as_polars_df.data.frame = function(x, ..., rownames = NULL, make_names_unique =

pl$concat(
pl$Series(old_rownames, name = rownames),
pl$DataFrame(x, make_names_unique = FALSE),
df_to_rpldf(x, schema = schema, schema_overrides = schema_overrides),
how = "horizontal"
)
}
Expand Down Expand Up @@ -133,7 +149,7 @@ as_polars_df.RPolarsDynamicGroupBy = as_polars_df.RPolarsGroupBy
#' @rdname as_polars_df
#' @export
as_polars_df.RPolarsSeries = function(x, ...) {
pl$DataFrame(x)
pl$select(x)
}


Expand Down Expand Up @@ -197,7 +213,7 @@ as_polars_df.ArrowTabular = function(
rechunk = TRUE,
schema = NULL,
schema_overrides = NULL) {
arrow_to_rdf(
arrow_to_rpldf(
x,
rechunk = rechunk,
schema = schema,
Expand Down Expand Up @@ -236,7 +252,7 @@ as_polars_df.nanoarrow_array_stream = function(x, ...) {
}
}

out = do.call(pl$DataFrame, data_cols)
out = do.call(pl$select, data_cols)
} else {
out = pl$DataFrame() # TODO: support creating 0-row DataFrame
}
Expand Down
65 changes: 63 additions & 2 deletions R/construction.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#' @param schema_overrides named list of DataTypes. Cast some columns to the DataType.
#' @noRd
#' @return RPolarsDataFrame
arrow_to_rdf = function(at, schema = NULL, schema_overrides = NULL, rechunk = TRUE) {
arrow_to_rpldf = function(at, schema = NULL, schema_overrides = NULL, rechunk = TRUE) {
# new column names by schema, #todo get names if schema not NULL
n_cols = at$num_columns

Expand All @@ -21,7 +21,10 @@ arrow_to_rdf = function(at, schema = NULL, schema_overrides = NULL, rechunk = TR
)
col_names = names(new_schema)

if (length(col_names) != n_cols) stop("schema length does not match column length")
if (length(col_names) != n_cols) {
Err_plain("schema length does not match column length") |>
unwrap()
}

data_cols = list()
# dictionaries cannot be built in different batches (categorical does not allow
Expand Down Expand Up @@ -199,3 +202,61 @@ arrow_to_rseries_result = function(name, values, rechunk = TRUE) {

res
}


#' Internal function of `as_polars_df()` for `data.frame` class objects.
#'
#' Converts every column of `x` to a polars Series with `as_polars_series()`,
#' assembles the Series into a DataFrame with `pl$select()`, and finally casts
#' each column whose dtype differs from the one requested through `schema` /
#' `schema_overrides`. The implementation is adapted from `arrow_to_rpldf()`.
#'
#' @param x A `data.frame`.
#' @param ... Not used.
#' @param schema Forwarded to `unpack_schema()`. When `NULL`, the column names
#' of `x` are used. The resolved schema must have one entry per column of `x`.
#' @param schema_overrides Forwarded to `unpack_schema()`.
#' @noRd
#' @return RPolarsDataFrame
df_to_rpldf = function(x, ..., schema = NULL, schema_overrides = NULL) {
  n_cols = ncol(x)

  # Resolve the target schema; its names become the output column names.
  new_schema = unpack_schema(
    schema = schema %||% names(x),
    schema_overrides = schema_overrides
  )
  col_names = names(new_schema)

  if (length(col_names) != n_cols) {
    Err_plain("schema length does not match column length") |>
      unwrap()
  }

  # Convert column by column, keyed by the target name.
  # NOTE(review): assignment by name means a duplicated name in `col_names`
  # would overwrite an earlier column -- confirm `unpack_schema()` or the
  # callers rule out duplicated names.
  data_cols = list()

  for (i in seq_len(n_cols)) {
    column = as_polars_series(x[[i]])
    col_name = col_names[i]

    data_cols[[col_name]] = column
  }

  if (length(data_cols) > 0L) {
    out = do.call(pl$select, data_cols)
  } else {
    # `do.call(pl$select, list())` is avoided for a zero-column input.
    out = pl$DataFrame()
  }

  # Keep only the fields whose requested dtype is set and differs from the
  # dtype the conversion produced. `vapply()` guarantees a logical mask even
  # when there are no fields at all (`sapply()` would return `list()`).
  cast_these_fields = mapply(
    new_schema,
    out$schema,
    FUN = \(new_field, df_field) {
      if (is.null(new_field) || new_field == df_field) NULL else new_field
    },
    SIMPLIFY = FALSE
  ) |> (\(l) l[!vapply(l, is.null, logical(1L))])()

  if (length(cast_these_fields) > 0L) {
    out = out$with_columns(
      mapply(
        cast_these_fields,
        names(cast_these_fields),
        FUN = \(dtype, name) pl$col(name)$cast(dtype),
        SIMPLIFY = FALSE
      ) |> unname()
    )
  }

  out
}
27 changes: 16 additions & 11 deletions man/as_polars_df.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/test-as_polars.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,29 @@ test_that("as_polars_df throws error when make_names_unique = FALSE and there ar
})


test_that("schema option and schema_overrides for as_polars_df.data.frame", {
  # Both source columns are integer vectors.
  src = data.frame(a = 1:3, b = 4:6)

  # `schema` given as a named list of dtypes, as a character vector of new
  # column names, and `schema_overrides` touching a single column.
  typed = as_polars_df(src, schema = list(a = pl$String, b = pl$Int32))
  renamed = as_polars_df(src, schema = c("x", "y"))
  overridden = as_polars_df(src, schema_overrides = list(a = pl$String))

  # Column `a` is cast to string, `b` stays integer.
  expect_equal(
    typed$to_data_frame(),
    data.frame(a = as.character(1:3), b = 4L:6L)
  )

  # Only the column names change.
  expect_equal(
    renamed$to_data_frame(),
    data.frame(x = 1:3, y = 4:6)
  )

  # Only the overridden column is cast.
  expect_equal(
    overridden$to_data_frame(),
    data.frame(a = as.character(1:3), b = 4:6)
  )

  # A single-name schema cannot describe every column of `mtcars`.
  expect_error(as_polars_df(mtcars, schema = "cyl"), "schema length does not match")
})


if (requireNamespace("arrow", quietly = TRUE) && requireNamespace("nanoarrow", quietly = TRUE)) {
make_as_polars_series_cases = function() {
tibble::tribble(
Expand Down

0 comments on commit d8a6e77

Please sign in to comment.