Fix joins to include partial null results #767

Open · wants to merge 3 commits into base: main
@@ -106,6 +106,7 @@ class ChrononKryoRegistrator extends KryoRegistrator {
"org.apache.spark.sql.types.TimestampType$",
"org.apache.spark.util.sketch.BitArray",
"org.apache.spark.util.sketch.BloomFilterImpl",
"org.apache.spark.util.collection.BitSet",
"org.apache.spark.util.collection.CompactBuffer",
"scala.reflect.ClassTag$$anon$1",
"scala.math.Ordering$$anon$4",
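The only change in this file is the new `org.apache.spark.util.collection.BitSet` entry, presumably because the reworked join path below can put Spark's internal `BitSet` on the wire. A minimal sketch of how a string-listed class like this is typically registered with Kryo; the `BitSetRegistrator` name and the config wiring shown are illustrative, not Chronon's actual registrator:

```scala
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.serializer.KryoRegistrator

// Illustrative registrator: registers the newly listed class by name so Kryo
// does not fall back to writing the fully qualified class name per instance.
class BitSetRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(Class.forName("org.apache.spark.util.collection.BitSet"))
  }
}

// Hypothetical wiring via Spark config (Chronon sets this up elsewhere):
//   new SparkConf()
//     .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
//     .set("spark.kryo.registrator", classOf[BitSetRegistrator].getName)
```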
5 changes: 3 additions & 2 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -222,10 +222,11 @@ object Extensions {
bloomFilter
}

def removeNulls(cols: Seq[String]): DataFrame = {
def removeNulls(cols: Seq[String], includePartial: Boolean): DataFrame = {
logger.info(s"filtering nulls from columns: [${cols.mkString(", ")}]")
val comparison = if (includePartial) "OR" else "AND"
// do not use != or <> operator with null, it doesn't return false ever!
df.filter(cols.map(_ + " IS NOT NULL").mkString(" AND "))
df.filter(cols.map(_ + " IS NOT NULL").mkString(s" $comparison "))
}

def nullSafeJoin(right: DataFrame, keys: Seq[String], joinType: String): DataFrame = {
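To make the new `includePartial` flag concrete, here is a minimal, self-contained sketch of the two filter strings `removeNulls` now builds; the local SparkSession and the `user`/`email` column names are illustrative, not part of the PR:

```scala
import org.apache.spark.sql.SparkSession

object RemoveNullsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("removeNulls-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (Some("u1"), Some("a@x.com")),
      (Some("u2"), None),                          // partially-null key
      (None: Option[String], None: Option[String]) // fully-null key
    ).toDF("user", "email")
    val keys = Seq("user", "email")

    // includePartial = false -> "user IS NOT NULL AND email IS NOT NULL":
    // drops the partially-null row along with the fully-null one.
    df.filter(keys.map(_ + " IS NOT NULL").mkString(" AND ")).show()

    // includePartial = true -> "user IS NOT NULL OR email IS NOT NULL":
    // keeps the partially-null row; only rows where every key is null are dropped.
    df.filter(keys.map(_ + " IS NOT NULL").mkString(" OR ")).show()

    spark.stop()
  }
}
```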
7 changes: 3 additions & 4 deletions spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -179,7 +179,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
def temporalEntities(queriesUnfilteredDf: DataFrame, resolution: Resolution = FiveMinuteResolution): DataFrame = {

// Add extra column to the queries and generate the key hash.
val queriesDf = queriesUnfilteredDf.removeNulls(keyColumns)
val queriesDf = queriesUnfilteredDf.removeNulls(keyColumns, includePartial = true)
val timeBasedPartitionColumn = "ds_of_ts"
val queriesWithTimeBasedPartition = queriesDf.withTimeBasedColumn(timeBasedPartitionColumn)

@@ -282,7 +282,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],

val queriesDf = skewFilter
.map { queriesUnfilteredDf.filter }
.getOrElse(queriesUnfilteredDf.removeNulls(keyColumns))
.getOrElse(queriesUnfilteredDf.removeNulls(keyColumns, includePartial = true))

val TimeRange(minQueryTs, maxQueryTs) = queryTimeRange.getOrElse(queriesDf.timeRange)
val hopsRdd = hopsAggregate(minQueryTs, resolution)
@@ -506,8 +506,7 @@ object GroupBy {
val processedInputDf = bloomMapOpt.map { skewFilteredDf.filterBloom }.getOrElse { skewFilteredDf }

// at-least one of the keys should be present in the row.
val nullFilterClause = groupByConf.keyColumns.toScala.map(key => s"($key IS NOT NULL)").mkString(" OR ")
val nullFiltered = processedInputDf.filter(nullFilterClause)
val nullFiltered = processedInputDf.removeNulls(groupByConf.keyColumns.toScala, includePartial = true)
if (showDf) {
logger.info(s"printing input date for groupBy: ${groupByConf.metaData.name}")
nullFiltered.prettyPrint()
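The hand-built `nullFilterClause` could be dropped because `removeNulls(keys, includePartial = true)` now yields the same "at least one key present" predicate. A small sketch of the equivalence, with illustrative key names:

```scala
val keys = Seq("user", "email")

// Previous inline clause: "(user IS NOT NULL) OR (email IS NOT NULL)"
val oldClause = keys.map(k => s"($k IS NOT NULL)").mkString(" OR ")

// Clause built by removeNulls(keys, includePartial = true):
// "user IS NOT NULL OR email IS NOT NULL" -- the same predicate, minus the parentheses.
val newClause = keys.map(_ + " IS NOT NULL").mkString(" OR ")
```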
29 changes: 20 additions & 9 deletions spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -25,7 +25,7 @@ import ai.chronon.spark.Extensions._
import com.google.gson.Gson
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{coalesce, col, udf}
import org.apache.spark.sql.functions.{coalesce, col, lit, udf, when}
import org.apache.spark.util.sketch.BloomFilter

import scala.collection.compat._
@@ -137,23 +137,34 @@ object JoinUtils {
s"Column '$column' has mismatched data types - left type: $leftDataType vs. right type $rightDataType")
}

val joinedDf = leftDf.join(rightDf, keys.toSeq, joinType)
val leftJoinDf = leftDf.alias("leftDf")
val rightJoinDf = rightDf.alias("rightDf")

val joinCondition = keys
.map { key =>
(col(s"leftDf.$key") === col(s"rightDf.$key")) or
(col(s"leftDf.$key").isNull and col(s"rightDf.$key").isNull)
}
.reduce(_ && _)

val joinedDf = leftJoinDf.join(rightJoinDf, joinCondition, joinType)

// find columns that exist both on left and right that are not keys and coalesce them
val selects = keys.map(col) ++
leftDf.columns.flatMap { colName =>
val selects = keys.map(k => col(s"leftDf.$k").as(k)) ++
leftJoinDf.columns.flatMap { colName =>
if (keys.contains(colName)) {
None
} else if (sharedColumns.contains(colName)) {
Some(coalesce(leftDf(colName), rightDf(colName)).as(colName))
Some(coalesce(col(s"leftDf.$colName"), col(s"rightDf.$colName")).as(colName))
} else {
Some(leftDf(colName))
Some(col(s"leftDf.$colName").as(colName))
}
} ++
rightDf.columns.flatMap { colName =>
rightJoinDf.columns.flatMap { colName =>
if (sharedColumns.contains(colName)) {
None // already selected previously
} else {
Some(rightDf(colName))
Some(col(s"rightDf.$colName").as(colName))
}
}
val finalDf = joinedDf.select(selects.toSeq: _*)
@@ -365,7 +376,7 @@ object JoinUtils {
}.toSet

// Form the final WHERE clause for injection
s"$groupByKeyExpression in (${valueSet.mkString(sep = ",")})"
s"$groupByKeyExpression in (${valueSet.mkString(sep = ",")}) or $groupByKeyExpression is null"
}
.foreach { whereClause =>
val currentWheres = Option(source.rootQuery.getWheres).getOrElse(new util.ArrayList[String]())
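The reworked join aliases both sides, treats `NULL == NULL` as a key match, and then coalesces shared non-key columns so left values win. A minimal sketch of the same join condition on toy data; the data, key names, and the `local[*]` session are illustrative. Note that Spark's built-in null-safe equality operator `<=>` expresses the same "equal, or both null" semantics more compactly:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object NullSafeJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("null-safe-join-sketch").getOrCreate()
    import spark.implicits._

    val left = Seq(("u1", Some("a@x.com"), 1), ("u2", None: Option[String], 2))
      .toDF("user", "email", "left_val").alias("leftDf")
    val right = Seq(("u1", Some("a@x.com"), 10), ("u2", None: Option[String], 20))
      .toDF("user", "email", "right_val").alias("rightDf")
    val keys = Seq("user", "email")

    // Same shape as the PR's joinCondition: a plain === never matches NULL keys,
    // so rows with a null key column would otherwise silently drop out of the join.
    val joinCondition = keys
      .map { key =>
        (col(s"leftDf.$key") === col(s"rightDf.$key")) or
          (col(s"leftDf.$key").isNull and col(s"rightDf.$key").isNull)
      }
      .reduce(_ && _)
    left.join(right, joinCondition, "left").show()

    // Equivalent, more compact condition using Spark's null-safe equality.
    val nullSafeCondition = keys.map(k => col(s"leftDf.$k") <=> col(s"rightDf.$k")).reduce(_ && _)
    left.join(right, nullSafeCondition, "left").show()

    spark.stop()
  }
}
```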
110 changes: 109 additions & 1 deletion spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -27,14 +27,16 @@ import ai.chronon.spark.stats.SummaryJob
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, StringType => SparkStringType}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession, types}
import org.junit.Assert._
import org.junit.Test
import org.scalatest.Assertions.intercept

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

import java.time.{LocalDateTime, ZoneOffset}
import java.time.format.DateTimeFormatter
import scala.collection.JavaConverters._
import scala.util.ScalaJavaConversions.ListOps

@@ -1298,4 +1300,110 @@ class JoinTest {
assert(
thrown2.getMessage.contains("Table or view not found") && thrown3.getMessage.contains("Table or view not found"))
}

@Test
def testPartialJoins(): Unit = {
import spark.implicits._

val userTable = s"$namespace.users"
spark.sql(s"DROP TABLE IF EXISTS $userTable")

val start = "2024-01-01"
val end = "2024-12-01"

val schema = StructType(Array(
StructField("user", types.StringType, nullable = true),
StructField("email", types.StringType, nullable = true),
StructField("order", types.StringType, nullable = true),
StructField("amount", types.IntegerType, nullable = true),
StructField("created", types.LongType, nullable = true),
StructField("ds", types.StringType, nullable = true),
))

val rows = Seq(
("user", "email", "order", 0, LocalDateTime.of(2024, 5, 10, 0, 0)),
("user", "email", "order", 0, LocalDateTime.of(2024, 5, 10, 6, 0)),
("user", "email2", "order", 1, LocalDateTime.of(2024, 5, 10, 0, 0)),
("user", "email2", "order", 1, LocalDateTime.of(2024, 5, 10, 6, 0)),
("user", null, "order", null, LocalDateTime.of(2024, 5, 10, 0, 0)),
("user", null, "order", null, LocalDateTime.of(2024, 5, 10, 6, 0))
).map({
case (user, email, order, amount, dt) => Row(
user,
email,
order,
amount,
dt.toInstant(ZoneOffset.UTC).toEpochMilli,
dt.format(DateTimeFormatter.ofPattern("yyyy-MM-dd"))
)
})

val events = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

TableUtils(events.sparkSession).insertPartitions(events, tableName = userTable)

val source = Builders.Source.events(
query = Builders.Query(
selects = Builders.Selects("user", "email", "order", "amount"),
timeColumn = "created",
startPartition = start
),
table = userTable
)

def makeJoinConf(keys: Seq[String]) = {
val groupBy = Builders.GroupBy(
sources = Seq(source),
keyColumns = keys,
aggregations = Seq(
Builders.Aggregation(
operation = Operation.COUNT,
inputColumn = "order",
windows = Seq(new Window(30, TimeUnit.DAYS))
)
),
accuracy = Accuracy.TEMPORAL,
metaData = Builders.MetaData(
name = "unit_test.user_email_orders"
)
)

Builders.Join(
left = source,
joinParts = Seq(
Builders.JoinPart(
groupBy = groupBy
)
),
metaData = Builders.MetaData(
name = "test.user_features",
namespace = namespace,
team = "chronon"
)
)
}

def assertJoin(columnNames: Seq[String]) = {
val joinConf = makeJoinConf(columnNames)
val join = new Join(joinConf, end, tableUtils)
val results = join.computeJoin()
val columns = columnNames.map(col)

val window = org.apache.spark.sql.expressions.Window.partitionBy(columns: _*).orderBy(col("ts").desc)

val mostRecentResults = results.withColumn("rank", row_number().over(window))
.filter(col("rank") === 1)

val expectedCounts = List(1, 1, 1)
val actualCounts = mostRecentResults.map(r => r.getAs[Long]("unit_test_user_email_orders_order_count_30d")).collectAsList().asScala

assertEquals(expectedCounts, actualCounts)
}

// Check string nulls
assertJoin(Seq("user", "email"))

// Check non-string nulls
assertJoin(Seq("user", "amount"))
}
}