Created
November 4, 2019 12:38
-
-
Save HarborZeng/c0f8a928dfe75aaf85cf9aab3708593a to your computer and use it in GitHub Desktop.
Explore what a Spark join is
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
/**
 * Exploratory spec that pins down the rows produced by each Spark DataFrame
 * join type on two tiny in-memory tables sharing a `cid` column.
 *
 *  - `customers(cid, name)` holds cids 1, 2 and 3.
 *  - `orders(oid, cid, amount)` holds five orders: one for cid 1, three for
 *    cid 2 and one for cid 1000.
 *
 * Customer 3 has no order and order 5 (cid 1000) has no customer, so every
 * join type yields a distinguishable result. Each test joins the two frames,
 * prints the result with `show()` and asserts the exact collected rows.
 *
 * The local SparkSession started by this suite is stopped in [[afterAll]].
 */
class JoinSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  @transient lazy val logger: Logger = Logger.getLogger(getClass)

  // Root logger at INFO; the "org" and "akka" logger hierarchies are quietened to WARN.
  Logger.getRootLogger.setLevel(Level.INFO)
  Logger.getLogger("org").setLevel(Level.WARN)
  Logger.getLogger("akka").setLevel(Level.WARN)

  /** Local-mode session used by every test in this suite. */
  lazy val spark: SparkSession = {
    SparkSession
      .builder()
      .master("local")
      .appName("spark test of join")
      .getOrCreate()
  }

  // Note: this eager val forces `spark` as soon as the suite is instantiated.
  val sc = spark.sparkContext

  /** customers(cid, name): cid 3 has no matching order. */
  private val customers = spark.createDataFrame(
    sc.parallelize(Seq(
      Row(1, "harbor"),
      Row(2, "mr.wu"),
      Row(3, "babaozhou")
    )),
    schema = StructType(
      Array[StructField](
        StructField("cid", IntegerType),
        StructField("name", StringType)
      ))
  )

  /** orders(oid, cid, amount): order 5 references cid 1000, which has no customer. */
  private val orders = spark.createDataFrame(
    sc.parallelize(Seq(
      Row(1, 1, 50.0d),
      Row(2, 2, 10d),
      Row(3, 2, 10d),
      Row(4, 2, 10d),
      Row(5, 1000, 19d)
    )), schema = StructType(
      Array[StructField](
        StructField("oid", IntegerType),
        StructField("cid", IntegerType),
        StructField("amount", DoubleType)
      ))
  )

  /**
   * Stops the session once all tests have run so the SparkContext started by
   * this suite is not left running after the suite finishes.
   */
  override protected def afterAll(): Unit =
    try spark.stop()
    finally super.afterAll()

  // Inner join: only orders whose cid also appears in customers survive.
  "SparkJoin inner" should "collect only matching rows from both sides" in {
    val innerJoinResultDF = orders.join(customers, Seq("cid"), joinType = "inner")
    innerJoinResultDF.show()
    val innerJoinResult = innerJoinResultDF.collect()
    innerJoinResult should (have size 4 and contain allOf(
      // |cid|oid|amount| name|
      Row(1, 1, 50.0d, "harbor"),
      Row(2, 2, 10d, "mr.wu"),
      Row(2, 3, 10d, "mr.wu"),
      Row(2, 4, 10d, "mr.wu")
    ))
  }

  // Cross join: 5 orders x 3 customers = 15 rows; both cid columns are kept.
  "SparkJoin cross" should "create a Cartesian Product" in {
    val crossJoinResultDF = orders.crossJoin(customers)
    crossJoinResultDF.show()
    val crossJoinResult = crossJoinResultDF.collect()
    crossJoinResult should (have size 15 and contain allOf(
      // |oid| cid| amount| cid| name|
      Row(1, 1, 50.0, 1, "harbor"),
      Row(1, 1, 50.0, 2, "mr.wu"),
      Row(1, 1, 50.0, 3, "babaozhou"),
      Row(2, 2, 10.0, 1, "harbor"),
      Row(2, 2, 10.0, 2, "mr.wu"),
      Row(2, 2, 10.0, 3, "babaozhou"),
      Row(3, 2, 10.0, 1, "harbor"),
      Row(3, 2, 10.0, 2, "mr.wu"),
      Row(3, 2, 10.0, 3, "babaozhou"),
      Row(4, 2, 10.0, 1, "harbor"),
      Row(4, 2, 10.0, 2, "mr.wu"),
      Row(4, 2, 10.0, 3, "babaozhou"),
      Row(5, 1000, 19.0, 1, "harbor"),
      Row(5, 1000, 19.0, 2, "mr.wu"),
      Row(5, 1000, 19.0, 3, "babaozhou")
    ))
  }

  // Full outer join: the inner-join rows plus null-padded rows for the
  // unmatched customer (cid 3) and the unmatched order (cid 1000).
  "SparkJoin outer" should "full outer join 相比内连接多了null数据" in {
    // "outer", "full", "fullouter", "full_outer"
    val outerJoinResultDF = orders.join(customers, Seq("cid"), joinType = "outer")
    outerJoinResultDF.show()
    val outerJoinResult = outerJoinResultDF.collect()
    outerJoinResult should (have size 6 and contain allOf(
      // | cid| oid| amount| name|
      Row(1, 1, 50.0d, "harbor"),
      Row(2, 2, 10d, "mr.wu"),
      Row(2, 3, 10d, "mr.wu"),
      Row(2, 4, 10d, "mr.wu"),
      Row(3, null, null, "babaozhou"),
      Row(1000, 5, 19d, null)
    ))
  }

  // Left outer join with orders on the left: every order is kept; order 5
  // has no customer, so its name is null.
  "SparkJoin left order join customer" should "left outer join 相比内连接多了左边key不为null的null数据" in {
    // "leftouter", "left", "left_outer"
    val leftJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left")
    leftJoinResultDF.show()
    val leftJoinResult = leftJoinResultDF.collect()
    leftJoinResult should (have size 5 and contain allOf(
      // | cid| oid| amount| name|
      Row(1, 1, 50.0d, "harbor"),
      Row(2, 2, 10d, "mr.wu"),
      Row(2, 3, 10d, "mr.wu"),
      Row(2, 4, 10d, "mr.wu"),
      Row(1000, 5, 19d, null)
    ))
  }

  // Left outer join with customers on the left: every customer is kept;
  // customer 3 has no order, so oid and amount are null.
  "SparkJoin left customer join order" should "left outer join 相比内连接多了左边key不为null的null数据" in {
    // "leftouter", "left", "left_outer"
    val leftJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left")
    leftJoinResultDF.show()
    val leftJoinResult = leftJoinResultDF.collect()
    leftJoinResult should (have size 5 and contain allOf(
      // |cid| name| oid| amount|
      Row(1, "harbor", 1, 50.0d),
      Row(2, "mr.wu", 2, 10d),
      Row(2, "mr.wu", 3, 10d),
      Row(2, "mr.wu", 4, 10d),
      Row(3, "babaozhou", null, null)
    ))
  }

  // Right outer join with orders on the right: same rows as
  // "left order join customer", only the column order differs.
  "SparkJoin right customer join order" should "和left order join customer除了列数据顺序不一样之外其他都一样" in {
    // "rightouter", "right", "right_outer"
    val rightJoinResultDF = customers.join(orders, Seq("cid"), joinType = "right")
    rightJoinResultDF.show()
    val rightJoinResult = rightJoinResultDF.collect()
    rightJoinResult should (have size 5 and contain allOf(
      // | cid| name| oid| amount|
      Row(1, "harbor", 1, 50.0d),
      Row(2, "mr.wu", 2, 10d),
      Row(2, "mr.wu", 3, 10d),
      Row(2, "mr.wu", 4, 10d),
      Row(1000, null, 5, 19d)
    ))
  }

  // Right outer join with customers on the right: same rows as
  // "left customer join order", only the column order differs.
  "SparkJoin right order join customer" should "和left customer join order除了列数据顺序不一样之外其他都一样" in {
    // "rightouter", "right", "right_outer"
    val rightJoinResultDF = orders.join(customers, Seq("cid"), joinType = "right")
    rightJoinResultDF.show()
    val rightJoinResult = rightJoinResultDF.collect()
    rightJoinResult should (have size 5 and contain allOf(
      // |cid| oid| amount| name|
      Row(1, 1, 50.0d, "harbor"),
      Row(2, 2, 10d, "mr.wu"),
      Row(2, 3, 10d, "mr.wu"),
      Row(2, 4, 10d, "mr.wu"),
      Row(3, null, null, "babaozhou")
    ))
  }

  // Left semi join with orders on the left: the inner-join rows, but
  // carrying only the orders columns (no name).
  "SparkJoin left_semi order join customer" should "和inner除了少了右边独有的列之外其他都一样" in {
    // "leftsemi", "left_semi"
    val leftSemiJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left_semi")
    leftSemiJoinResultDF.show()
    val leftSemiJoinResult = leftSemiJoinResultDF.collect()
    leftSemiJoinResult should (have size 4 and contain allOf(
      // |cid|oid|amount|
      Row(1, 1, 50.0d),
      Row(2, 2, 10d),
      Row(2, 3, 10d),
      Row(2, 4, 10d)
    ))
  }

  // Left semi join with customers on the left: each matching customer
  // appears once, even though customer 2 has three orders.
  "SparkJoin left_semi custom join order" should "和inner除了少了右边独有的列还有去重了之外其他都一样" in {
    // "leftsemi", "left_semi"
    val leftSemiJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left_semi")
    leftSemiJoinResultDF.show()
    val leftSemiJoinResult = leftSemiJoinResultDF.collect()
    leftSemiJoinResult should (have size 2 and contain allOf(
      // |cid |name |
      Row(1, "harbor"),
      Row(2, "mr.wu")
    ))
  }

  // Left anti join with customers on the left: the customers without any
  // order; together with the left_semi result this is the whole customers table.
  "SparkJoin left_anti custom join order" should "和 left_semi custom join order 组合在一起就是完整的 customs 表" in {
    // "leftanti", "left_anti"
    val leftAntiJoinResultDF = customers.join(orders, Seq("cid"), joinType = "left_anti")
    leftAntiJoinResultDF.show()
    val leftAntiJoinResult = leftAntiJoinResultDF.collect()
    leftAntiJoinResult should (have size 1 and contain (
      // |cid|name|
      Row(3, "babaozhou")
    ))
  }

  // Left anti join with orders on the left: the orders without a customer;
  // together with the left_semi result this is the whole orders table.
  "SparkJoin left_anti order join customer" should "和 left_semi order join customer 组合在一起就是完整的 orders 表" in {
    // "leftanti", "left_anti"
    val leftAntiJoinResultDF = orders.join(customers, Seq("cid"), joinType = "left_anti")
    leftAntiJoinResultDF.show()
    val leftAntiJoinResult = leftAntiJoinResultDF.collect()
    leftAntiJoinResult should (have size 1 and contain (
      // |cid|oid|amount|
      Row(1000, 5, 19d)
    ))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment