From f0b04286022e0774d78b9adcf4aeabc181a3ec89 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 28 Aug 2023 15:05:18 +0200 Subject: [PATCH] [SPARK-44974][CONNECT] Null out SparkSession/Dataset/KeyValueGroupedDataset on serialization ### What changes were proposed in this pull request? This PR changes the serialization for connect `SparkSession`, `Dataset`, and `KeyValueGroupedDataset`. While these were marked as serializable they were not, because they refer to bits and pieces that are not serializable. Even if we were to fix this, then we still have a class clash problem with server side classes that have the same name, but have different structure. The latter can be fixed with serialization proxies, but I am going to hold that until someone actually needs/wants this. After this PR these classes are serialized as null. This is a somewhat suboptimal solution compared to throwing exceptions on serialization, however this is more compatible compared to the old situation, and makes accidental capture of these classes less of an issue for UDFs. ### Why are the changes needed? More compatible with the old situation. Improved UX when working with UDFs. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests to `ClientDatasetSuite`, `KeyValueGroupedDatasetE2ETestSuite`, `SparkSessionSuite`, and `UserDefinedFunctionE2ETestSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #42688 from hvanhovell/SPARK-44974. 
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../main/scala/org/apache/spark/sql/Dataset.scala | 6 ++++++ .../apache/spark/sql/KeyValueGroupedDataset.scala | 6 ++++++ .../scala/org/apache/spark/sql/SparkSession.scala | 6 ++++++ .../org/apache/spark/sql/ClientDatasetSuite.scala | 8 ++++++++ .../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 7 +++++++ .../org/apache/spark/sql/SparkSessionSuite.scala | 7 +++++++ .../sql/UserDefinedFunctionE2ETestSuite.scala | 15 +++++++++++++++ 7 files changed, 55 insertions(+) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 3c89e649020f4..1d83f196b53b1 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3352,4 +3352,10 @@ class Dataset[T] private[sql] ( result.close() } } + + /** + * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We + * null out the instance for now. + */ + private def writeReplace(): Any = null } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 202891c66d748..88c8b6a4f8bad 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -979,6 +979,12 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( outputEncoder = outputEncoder) udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction } + + /** + * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the + * server side. We null out the instance for now. 
+ */ + private def writeReplace(): Any = null } private object KeyValueGroupedDatasetImpl { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index e902e04e24611..7882ea6401354 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -714,6 +714,12 @@ class SparkSession private[sql] ( def clearTags(): Unit = { client.clearTags() } + + /** + * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. + * We null out the instance for now. + */ + private def writeReplace(): Any = null } // The minimal builder needed to create a spark session. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index a521c6745a90c..aab31d97e8c9d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.util.SparkSerDeUtils // Add sample tests. 
// - sample fraction: simple.sample(0.1) @@ -172,4 +173,11 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { val actualPlan = service.getAndClearLatestInputPlan() assert(actualPlan.equals(expectedPlan)) } + + test("serialize as null") { + val session = newSparkSession() + val ds = session.range(10) + val bytes = SparkSerDeUtils.serialize(ds) + assert(SparkSerDeUtils.deserialize[Dataset[Long]](bytes) == null) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index 3e979be73a754..98a947826e3de 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.sql.test.{QueryTest, SQLHelper} import org.apache.spark.sql.types._ +import org.apache.spark.util.SparkSerDeUtils case class ClickEvent(id: String, timestamp: Timestamp) @@ -630,6 +631,12 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { 30, 3) } + + test("serialize as null") { + val kvgds = session.range(10).groupByKey(_ % 2) + val bytes = SparkSerDeUtils.serialize(kvgds) + assert(SparkSerDeUtils.deserialize[KeyValueGroupedDataset[Long, Long]](bytes) == null) + } } case class K1(a: Long) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 90fe8f57d0713..4c858262c6ef5 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.sql.test.ConnectFunSuite +import org.apache.spark.util.SparkSerDeUtils /** * Tests for non-dataframe related SparkSession operations. @@ -261,4 +262,10 @@ class SparkSessionSuite extends ConnectFunSuite { .create() .close() } + + test("serialize as null") { + val session = SparkSession.builder().create() + val bytes = SparkSerDeUtils.serialize(session) + assert(SparkSerDeUtils.deserialize[SparkSession](bytes) == null) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index 0af8c78a1da85..fbc2c1c266262 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -328,4 +328,19 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { IntegerType) checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5) } + + test("nullified SparkSession/Dataset/KeyValueGroupedDataset in UDF") { + val session: SparkSession = spark + import session.implicits._ + val df = session.range(0, 10, 1, 1) + val kvgds = df.groupByKey(_ / 2) + val f = udf { (i: Long) => + assert(session == null) + assert(df == null) + assert(kvgds == null) + i + 1 + } + val result = df.select(f($"id")).as[Long].head + assert(result == 1L) + } }