Skip to content

Commit

Permalink
Use IfExpr to check when input to log2 is <=0 and return null (apache…
Browse files Browse the repository at this point in the history
…#506)

## Which issue does this PR close?

Closes apache#485 

## Rationale for this change

Compatibility with how Spark handles logarithms of values <=0. 

## What changes are included in this PR?

Use IfExpr to check when input to log2 is <=0 and return null.  This is done to match Spark's behavior, which in turn is implemented to match Hive's behavior.

## How are these changes tested?

The existing test for `ln`, `log2` and `log10` was modified so that it includes negative numbers as part of the inputs being tested.
  • Loading branch information
lithium323 authored Jul 15, 2024
1 parent de8c55e commit 6f9b56a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
6 changes: 3 additions & 3 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ The following Spark expressions are currently available. Any known compatibility
| Exp | |
| Floor | |
| IsNaN | |
| Log | log(0) will produce `-Infinity` unlike Spark which returns `null` |
| Log2 | log2(0) will produce `-Infinity` unlike Spark which returns `null` |
| Log10 | log10(0) will produce `-Infinity` unlike Spark which returns `null` |
| Log | |
| Log2 | |
| Log10 | |
| Pow | |
| Round | |
| Signum | Signum does not differentiate between `0.0` and `-0.0` |
Expand Down
14 changes: 11 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1703,18 +1703,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
optExprWithInfo(optExpr, expr, child)
}

// The expression for `log` functions is defined as null on numbers less than or equal
// to 0. This matches Spark and Hive behavior, where non positive values eval to null
// instead of NaN or -Infinity.
case Log(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("ln", childExpr)
optExprWithInfo(optExpr, expr, child)

case Log10(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("log10", childExpr)
optExprWithInfo(optExpr, expr, child)

case Log2(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
val optExpr = scalarExprToProto("log2", childExpr)
optExprWithInfo(optExpr, expr, child)

Expand Down Expand Up @@ -2393,6 +2396,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
expression
}

def nullIfNegative(expression: Expression): Expression = {
val zero = Literal.default(expression.dataType)
If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
}

/**
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
withParquetTable(
(0 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
(-5 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
"tbl",
withDictionary = dictionary.toBoolean) {
checkSparkAnswerWithTol(
Expand Down

0 comments on commit 6f9b56a

Please sign in to comment.