From 60496472f9390e690bf9290f9c98a9d37b526710 Mon Sep 17 00:00:00 2001 From: Etienne Bacher <52219252+etiennebacher@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:49:20 +0200 Subject: [PATCH] init [skip ci] --- R/expr__expr.R | 21 ++++++++++++++------- R/extendr-wrappers.R | 2 +- src/rust/src/lazy/dsl.rs | 18 ++++++++++++------ 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/R/expr__expr.R b/R/expr__expr.R index 52afcc79e..f90636b75 100644 --- a/R/expr__expr.R +++ b/R/expr__expr.R @@ -1022,26 +1022,33 @@ Expr_to_physical = use_extendr_wrapper #' @param dtype DataType to cast to. #' @param strict If `TRUE` (default), an error will be thrown if cast failed at #' resolve time. +#' @param wrap_numerical If `TRUE`, numeric casts wrap overflowing values instead +#' of marking the cast as invalid. +#' #' @return Expr #' @examples -#' df = pl$DataFrame(a = 1:3, b = c(1, 2, 3)) +#' df = pl$DataFrame(a = 1:3, b = c(1, 2, 3), c = c(100, 200, 300)) #' df$with_columns( -#' pl$col("a")$cast(pl$dtypes$Float64), -#' pl$col("b")$cast(pl$dtypes$Int32) +#' pl$col("a")$cast(pl$Float64), +#' pl$col("b")$cast(pl$Int32) #' ) #' #' # strict FALSE, inserts null for any cast failure -#' pl$lit(c(100, 200, 300))$cast(pl$dtypes$UInt8, strict = FALSE)$to_series() +#' df$with_columns(pl$col("c")$cast(pl$UInt8, strict = FALSE)) +#' +#' # wrap_numerical doesn't error in case of overflow but rather wraps the value +#' # to fit in the datatype +#' df$with_columns(pl$col("c")$cast(pl$UInt8, wrap_numerical = TRUE)) #' #' # strict TRUE, raise any failure as an error when query is executed. #' tryCatch( #' { -#' pl$lit("a")$cast(pl$dtypes$Float64, strict = TRUE)$to_series() +#' pl$lit("a")$cast(pl$Float64, strict = TRUE)$to_series() #' }, #' error = function(e) e #' ) -Expr_cast = function(dtype, strict = TRUE) { - .pr$Expr$cast(self, dtype, strict) +Expr_cast = function(dtype, strict = TRUE, wrap_numerical = FALSE) { + .pr$Expr$cast(self, dtype, strict, wrap_numerical) } #' Compute the square root of the elements diff --git a/R/extendr-wrappers.R b/R/extendr-wrappers.R index 166d790a9..0a611b96c 100644 --- a/R/extendr-wrappers.R +++ b/R/extendr-wrappers.R @@ -526,7 +526,7 @@ RPolarsExpr$xor <- function(other) .Call(wrap__RPolarsExpr__xor, self, other) RPolarsExpr$to_physical <- function() .Call(wrap__RPolarsExpr__to_physical, self) -RPolarsExpr$cast <- function(data_type, strict) .Call(wrap__RPolarsExpr__cast, self, data_type, strict) +RPolarsExpr$cast <- function(data_type, strict, wrap_numerical) .Call(wrap__RPolarsExpr__cast, self, data_type, strict, wrap_numerical) RPolarsExpr$sort_with <- function(descending, nulls_last) .Call(wrap__RPolarsExpr__sort_with, self, descending, nulls_last) diff --git a/src/rust/src/lazy/dsl.rs b/src/rust/src/lazy/dsl.rs index 8cd203ad1..46b5d9895 100644 --- a/src/rust/src/lazy/dsl.rs +++ b/src/rust/src/lazy/dsl.rs @@ -227,14 +227,20 @@ impl RPolarsExpr { self.0.clone().to_physical().into() } - pub fn cast(&self, data_type: &RPolarsDataType, strict: bool) -> Self { + pub fn cast(&self, data_type: &RPolarsDataType, strict: bool, wrap_numerical: bool) -> Self { + use polars::chunked_array::cast::CastOptions; let dt = data_type.0.clone(); - if strict { - self.0.clone().strict_cast(dt) + + let options = if wrap_numerical { + CastOptions::Overflowing + } else if strict { + CastOptions::Strict } else { - self.0.clone().cast(dt) - } - .into() + CastOptions::NonStrict + }; + + let expr = self.0.clone().cast_with_options(dt, options); + expr.into() } pub fn sort_with(&self, descending: bool, nulls_last: bool) -> Self {