Skip to content

Commit

Permalink
fix(spark): casting date/time requires timezone (#318)
Browse files Browse the repository at this point in the history
When casting a date type to a string, Spark requires that a timezone is specified,
otherwise it will not resolve the logical plan.
The timezone is ignored for non-date/time values.

Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman authored Dec 19, 2024
1 parent 2d44a92 commit af5a615
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtens
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, Decimal}
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -153,7 +154,12 @@ class ToSparkExpression(

override def visit(expr: SExpression.Cast): Expression = {
val childExp = expr.input().accept(this)
Cast(childExp, ToSubstraitType.convert(expr.getType))
val tt = ToSubstraitType.convert(expr.getType)
val tz = childExp.dataType match {
case DateType => Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE))
case _ => None
}
Cast(childExp, tt, tz)
}

override def visit(expr: exp.FieldReference): Expression = {
Expand Down Expand Up @@ -197,6 +203,7 @@ class ToSparkExpression(
val list = expr.options().asScala.map(e => e.accept(this))
In(value, list)
}

override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
val eArgs = expr.arguments().asScala
val args = eArgs.zipWithIndex.map {
Expand Down
2 changes: 1 addition & 1 deletion spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
"q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
"q9", // requires implementation of named_struct()
"q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
"q51", "q83", "q84", // TBD
"q51", "q84", // TBD
"q72" //requires implementation of date_add()
)
// spotless:on
Expand Down

0 comments on commit af5a615

Please sign in to comment.