Observation
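Spark SQL can optimize a join only if the join condition is based on the equality operator: equijoins can be planned as hash or sort-merge joins, while a condition without any equality predicate forces Spark to fall back to a nested-loop (or Cartesian product) join. This means we can consider equijoins and non-equijoins separately.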
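Why equality matters can be illustrated without Spark. The plain-Scala sketch below is only an analogy for the physical join strategies, not Spark's implementation, and the names `nestedLoopJoin` and `hashJoin` are hypothetical. A nested-loop join has to evaluate an arbitrary predicate on every pair of rows, while an equality condition lets one side be indexed by key so that only matching rows are ever paired.

```scala
object JoinCostSketch {
  // Nested-loop join: works for any predicate, but must evaluate it on every pair.
  // Returns the matching pairs together with the number of predicate evaluations.
  def nestedLoopJoin[T, U](xs: Seq[T], ys: Seq[U])(p: (T, U) => Boolean): (Seq[(T, U)], Int) = {
    var evaluations = 0
    val matches = for {
      x <- xs
      y <- ys
      if { evaluations += 1; p(x, y) }
    } yield (x, y)
    (matches, evaluations)
  }

  // Hash join: only possible when the condition is key equality, f(x) == g(y).
  // The right side is indexed by key once, so each left row meets only its matches.
  def hashJoin[T, U, K](xs: Seq[T], ys: Seq[U])(f: T => K, g: U => K): Seq[(T, U)] = {
    val index: Map[K, Seq[U]] = ys.groupBy(g)
    for {
      x <- xs
      y <- index.getOrElse(f(x), Seq.empty[U])
    } yield (x, y)
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(1, 2, 3, 4)
    val right = Seq(2, 4, 6)
    val (pairs, evaluations) = nestedLoopJoin(left, right)(_ == _)
    println(pairs)                                 // List((2,2), (4,4))
    println(evaluations)                           // 12
    println(hashJoin(left, right)(x => x, y => y)) // List((2,2), (4,4))
  }
}
```

With 4 and 3 rows the nested loop performs 12 predicate evaluations to find the same two pairs the hash join reaches directly; on large inputs this n × m cost is what an equality key lets the planner avoid.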
Equijoin
An equijoin can be implemented in a type-safe manner by mapping both Datasets to (key, value) tuples, joining on the keys, and reshaping the result:
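import org.apache.spark.sql.{Dataset, Encoder}

// Joins ds1 and ds2 on the keys extracted by f and g, returning typed (T, U) pairs.
def safeEquiJoin[T, U, K](ds1: Dataset[T], ds2: Dataset[U])
    (f: T => K, g: U => K)
    (implicit e1: Encoder[(K, T)], e2: Encoder[(K, U)], e3: Encoder[(T, U)]): Dataset[(T, U)] = {
  // Attach the join key to every record. Encoders are passed explicitly because
  // e1 and e2 have the same type when T == U, which would make implicit lookup ambiguous.
  val ds1_ = ds1.map(x => (f(x), x))(e1)
  val ds2_ = ds2.map(x => (g(x), x))(e2)
  // Join on the key columns, then drop the keys and keep the original values.
  ds1_.joinWith(ds2_, ds1_("_1") === ds2_("_1")).map(x => (x._1._2, x._2._2))(e3)
}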
Non-equijoin
A non-equijoin can be expressed with relational algebra operators as R ⋈θ S = σθ(R × S), that is, a selection over the Cartesian product, and this converts directly to code.
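The identity R ⋈θ S = σθ(R × S) translates almost literally into a sketch on local Scala collections (illustrative only, no Spark involved; `ThetaJoinSketch`, `cross` and `thetaJoin` are hypothetical names). `cross` builds the Cartesian product and `thetaJoin` applies the selection, using `theta.tupled` just like the `.filter(p.tupled)` calls in the Spark versions.

```scala
object ThetaJoinSketch {
  // Cartesian product R × S: every pairing of a row from r with a row from s.
  def cross[R, S](r: Seq[R], s: Seq[S]): Seq[(R, S)] =
    for { a <- r; b <- s } yield (a, b)

  // θ-join as a selection over the Cartesian product: R ⋈θ S = σθ(R × S).
  // theta.tupled turns (R, S) => Boolean into ((R, S)) => Boolean for filtering pairs.
  def thetaJoin[R, S](r: Seq[R], s: Seq[S])(theta: (R, S) => Boolean): Seq[(R, S)] =
    cross(r, s).filter(theta.tupled)

  def main(args: Array[String]): Unit = {
    // θ is an arbitrary comparison, here "less than".
    println(thetaJoin(Seq(1, 5, 10), Seq(3, 7))(_ < _)) // List((1,3), (1,7), (5,7))
  }
}
```

An equijoin is the special case where θ is equality, so this generic form is always correct but gives up the key-based optimizations.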
Spark 2.0
Enable cross joins and use joinWith with a trivially true predicate:
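import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.lit
// Allow Cartesian products, which are disabled by default in Spark 2.x.
spark.conf.set("spark.sql.crossJoin.enabled", true)
// Pair every record of ds1 with every record of ds2, then keep the pairs satisfying p.
def safeNonEquiJoin[T, U](ds1: Dataset[T], ds2: Dataset[U])
    (p: (T, U) => Boolean): Dataset[(T, U)] = {
  ds1.joinWith(ds2, lit(true)).filter(p.tupled)
}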
Spark 2.1
Use the crossJoin method:
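import org.apache.spark.sql.{Dataset, Encoder}

// Wrap each record in Tuple1 so each side is a single column (a struct for case classes);
// the cross-joined DataFrame then has exactly two columns and can be read back as (T, U).
def safeNonEquiJoin[T, U](ds1: Dataset[T], ds2: Dataset[U])
    (p: (T, U) => Boolean)
    (implicit e1: Encoder[Tuple1[T]], e2: Encoder[Tuple1[U]], e3: Encoder[(T, U)]): Dataset[(T, U)] = {
  ds1.map(Tuple1(_))(e1).crossJoin(ds2.map(Tuple1(_))(e2)).as[(T, U)](e3).filter(p.tupled)
}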
Examples
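// Assumes a SparkSession named spark is in scope (as in spark-shell).
import spark.implicits._

case class LabeledPoint(label: String, x: Double, y: Double)
case class Category(id: Long, name: String)

val points1 = Seq(LabeledPoint("foo", 1.0, 2.0)).toDS
val points2 = Seq(
  LabeledPoint("bar", 3.0, 5.6), LabeledPoint("foo", -1.0, 3.0)
).toDS
val categories = Seq(Category(1, "foo"), Category(2, "bar")).toDS

// Match points to categories where label == name.
safeEquiJoin(points1, categories)(_.label, _.name)
// Pair points from points1 with points from points2 that have a smaller x.
safeNonEquiJoin(points1, points2)(_.x > _.x)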
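To make the expected output of these calls concrete, the same data can be joined with plain Scala collections. This is a local approximation of what the two calls compute (it does not run Spark, and the object name `ExampleExpectations` is hypothetical); with this data each join should yield exactly one pair.

```scala
object ExampleExpectations {
  case class LabeledPoint(label: String, x: Double, y: Double)
  case class Category(id: Long, name: String)

  val points1 = Seq(LabeledPoint("foo", 1.0, 2.0))
  val points2 = Seq(LabeledPoint("bar", 3.0, 5.6), LabeledPoint("foo", -1.0, 3.0))
  val categories = Seq(Category(1, "foo"), Category(2, "bar"))

  // Equijoin on label == name, mirroring safeEquiJoin(points1, categories)(_.label, _.name).
  val equi: Seq[(LabeledPoint, Category)] =
    for { p <- points1; c <- categories if p.label == c.name } yield (p, c)

  // Non-equijoin on p1.x > p2.x, mirroring safeNonEquiJoin(points1, points2)(_.x > _.x).
  val nonEqui: Seq[(LabeledPoint, LabeledPoint)] =
    for { p1 <- points1; p2 <- points2 if p1.x > p2.x } yield (p1, p2)

  def main(args: Array[String]): Unit = {
    equi.foreach(println)    // (LabeledPoint(foo,1.0,2.0),Category(1,foo))
    nonEqui.foreach(println) // (LabeledPoint(foo,1.0,2.0),LabeledPoint(foo,-1.0,3.0))
  }
}
```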
Notes
- It should be noted that these methods are qualitatively different from a direct joinWith application and require expensive DeserializeToObject / SerializeFromObject transformations (by contrast, a direct joinWith can use logical operations on the data). This is similar to the behavior described in Spark 2.0 Dataset vs DataFrame.
- If you’re not limited to the Spark SQL API, frameless provides interesting type-safe extensions for Datasets (as of this writing it supports only Spark 2.0):
    import frameless.TypedDataset
    val typedPoints1 = TypedDataset.create(points1)
    val typedPoints2 = TypedDataset.create(points2)
    typedPoints1.join(typedPoints2, typedPoints1('x), typedPoints2('x))
- The Dataset API is not stable in 1.6, so I don’t think it makes sense to use it there.
- Of course, this design and the descriptive names are not necessary. You can easily use a type class to add these methods implicitly to Dataset, and there is no conflict with the built-in signatures, so both can be called joinWith.
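The implicit-extension idea from the last note can be sketched as follows. To keep it runnable without Spark, a hypothetical `Table[T]` class stands in for Dataset and already has a "built-in" `joinWith(other, condition)`; the hypothetical implicit class `KeyedJoinOps` adds a key-based `joinWith(other)(f, g)` under the same name. Scala only looks for an implicit conversion when no member `joinWith` is applicable to the given arguments, so both forms coexist. Wiring the same pattern to Dataset and safeEquiJoin is an assumption not shown here.

```scala
object ImplicitJoinSketch {
  // Hypothetical stand-in for Dataset: holds rows locally and has a built-in joinWith.
  final case class Table[T](rows: Seq[T]) {
    // "Built-in" signature: both arguments in a single parameter list.
    def joinWith[U](other: Table[U], condition: (T, U) => Boolean): Table[(T, U)] =
      Table(for { a <- rows; b <- other.rows if condition(a, b) } yield (a, b))
  }

  // Extension adding a key-based joinWith with the same name but a different shape.
  implicit class KeyedJoinOps[T](val self: Table[T]) extends AnyVal {
    def joinWith[U, K](other: Table[U])(f: T => K, g: U => K): Table[(T, U)] = {
      val index = other.rows.groupBy(g) // key -> right rows with that key
      Table(self.rows.flatMap(a => index.getOrElse(f(a), Seq.empty[U]).map(b => (a, b))))
    }
  }

  def main(args: Array[String]): Unit = {
    val left = Table(Seq("apple", "avocado", "banana"))
    val right = Table(Seq('a', 'c'))
    // Two arguments: resolves to the member method.
    println(left.joinWith(right, (s: String, c: Char) => s.head == c).rows) // List((apple,a), (avocado,a))
    // One argument: the member is not applicable, so the implicit class is used.
    println(left.joinWith(right)(_.head, c => c).rows)                     // List((apple,a), (avocado,a))
  }
}
```

Only one extension overload is defined on purpose: a key-based and a predicate-based extension that differ only in their second parameter list would be ambiguous, because Scala resolves overloads using the first parameter list.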