[SPARK-44974][CONNECT] Null out SparkSession/Dataset/KeyValueGroupedDataset on serialization

### What changes were proposed in this pull request?
This PR changes serialization for the connect `SparkSession`, `Dataset`, and `KeyValueGroupedDataset`. While these classes were marked as serializable, in practice they were not, because they reference internal state that is not serializable. Even if we fixed that, we would still have a class clash with server-side classes that share the same names but have a different structure. The latter could be fixed with serialization proxies, but I am going to hold off on that until someone actually needs/wants this.
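
For reference, a serialization proxy would hook into the same mechanism used by this PR: `writeReplace` swaps the instance for a small serializable stand-in whose `readResolve` rebuilds an equivalent object on the receiving side. A minimal sketch of that pattern (the `Heavy`/`HeavyProxy` names are illustrative and not part of this PR):

```scala
// Illustrative stand-in for a class whose full state cannot be serialized.
class Heavy(val name: String) extends Serializable {
  // On serialization, replace this instance with a small proxy.
  private def writeReplace(): Any = new HeavyProxy(name)
}

// The proxy carries only the state needed to rebuild the object.
class HeavyProxy(name: String) extends Serializable {
  // On deserialization, resolve back to a real instance.
  private def readResolve(): Any = new Heavy(name)
}
```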

After this PR these classes are serialized as `null`. This is somewhat suboptimal compared to throwing an exception on serialization; however, it is more compatible with the old behavior, and it makes accidental capture of these classes less of an issue for UDFs.
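
Concretely, this relies on Java serialization's `writeReplace` hook: whatever it returns is written in place of the instance, so returning `null` makes the object deserialize as `null`. A minimal, self-contained sketch of the round trip using only the JDK (the `Handle` class is illustrative):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Illustrative stand-in for the connect SparkSession/Dataset/KeyValueGroupedDataset.
class Handle extends Serializable {
  // Whatever writeReplace returns is serialized instead of this instance.
  private def writeReplace(): Any = null
}

object RoundTrip {
  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(new Handle)
    out.flush()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    // The Handle instance comes back as null after the round trip.
    assert(in.readObject() == null)
  }
}
```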

### Why are the changes needed?
It is more compatible with the old behavior and improves the UX when working with UDFs.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added tests to `ClientDatasetSuite`, `KeyValueGroupedDatasetE2ETestSuite`, `SparkSessionSuite`, and `UserDefinedFunctionE2ETestSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#42688 from hvanhovell/SPARK-44974.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Aug 28, 2023
1 parent 474f64a commit f0b0428
Showing 7 changed files with 55 additions and 0 deletions.
@@ -3352,4 +3352,10 @@ class Dataset[T] private[sql] (
      result.close()
    }
  }

  /**
   * We cannot deserialize a connect [[Dataset]] because of a class clash on the server side. We
   * null out the instance for now.
   */
  private def writeReplace(): Any = null
}
@@ -979,6 +979,12 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
      outputEncoder = outputEncoder)
    udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
  }

  /**
   * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the
   * server side. We null out the instance for now.
   */
  private def writeReplace(): Any = null
}

private object KeyValueGroupedDatasetImpl {
@@ -714,6 +714,12 @@ class SparkSession private[sql] (
  def clearTags(): Unit = {
    client.clearTags()
  }

  /**
   * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
   * We null out the instance for now.
   */
  private def writeReplace(): Any = null
}

// The minimal builder needed to create a spark session.
@@ -28,6 +28,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.util.SparkSerDeUtils

// Add sample tests.
// - sample fraction: simple.sample(0.1)
@@ -172,4 +173,11 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
    val actualPlan = service.getAndClearLatestInputPlan()
    assert(actualPlan.equals(expectedPlan))
  }

  test("serialize as null") {
    val session = newSparkSession()
    val ds = session.range(10)
    val bytes = SparkSerDeUtils.serialize(ds)
    assert(SparkSerDeUtils.deserialize[Dataset[Long]](bytes) == null)
  }
}
@@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.sql.types._
import org.apache.spark.util.SparkSerDeUtils

case class ClickEvent(id: String, timestamp: Timestamp)

@@ -630,6 +631,12 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
      30,
      3)
  }

  test("serialize as null") {
    val kvgds = session.range(10).groupByKey(_ % 2)
    val bytes = SparkSerDeUtils.serialize(kvgds)
    assert(SparkSerDeUtils.deserialize[KeyValueGroupedDataset[Long, Long]](bytes) == null)
  }
}

case class K1(a: Long)
@@ -23,6 +23,7 @@ import scala.util.control.NonFatal
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}

import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.util.SparkSerDeUtils

/**
 * Tests for non-dataframe related SparkSession operations.
@@ -261,4 +262,10 @@ class SparkSessionSuite extends ConnectFunSuite {
      .create()
      .close()
  }

  test("serialize as null") {
    val session = SparkSession.builder().create()
    val bytes = SparkSerDeUtils.serialize(session)
    assert(SparkSerDeUtils.deserialize[SparkSession](bytes) == null)
  }
}
@@ -328,4 +328,19 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
      IntegerType)
    checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5)
  }

  test("nullified SparkSession/Dataset/KeyValueGroupedDataset in UDF") {
    val session: SparkSession = spark
    import session.implicits._
    val df = session.range(0, 10, 1, 1)
    val kvgds = df.groupByKey(_ / 2)
    val f = udf { (i: Long) =>
      assert(session == null)
      assert(df == null)
      assert(kvgds == null)
      i + 1
    }
    val result = df.select(f($"id")).as[Long].head
    assert(result == 1L)
  }
}
