Skip to content

Commit

Permalink
Fix null handling of nested case classes (#267)
Browse files Browse the repository at this point in the history
* Fix null handling of nested case classes

* Add test to verify parquet serialization bug disappeared

* Add missing newline at the end of file

* Remove empty line

* Fixing empty lines, once again

* Set null propagation back to true

* Revert "Set null propagation back to true"

This reverts commit db0b087.

* More nested option deserializer tests

* Make null safe equality work with structs

* Revert "Revert "Set null propagation back to true""

This reverts commit cc0a865.

* Simplify sqlContext access in tests
  • Loading branch information
kmate authored and imarios committed Apr 6, 2018
1 parent ef83774 commit c5e3c11
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 11 deletions.
9 changes: 7 additions & 2 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ class RecordEncoder[F, G <: HList, H <: HList]
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
}

CreateNamedStruct(exprs)
val nullExpr = Literal.create(null, catalystRepr)
val createExpr = CreateNamedStruct(exprs)
If(IsNull(path), nullExpr, createExpr)
}

def fromCatalyst(path: Expression): Expression = {
Expand All @@ -169,6 +171,9 @@ class RecordEncoder[F, G <: HList, H <: HList]
field.encoder.fromCatalyst(fieldPath)
}

NewInstance(classTag.runtimeClass, newInstanceExprs.value.from(exprs), jvmRepr, propagateNull = true)
val nullExpr = Literal.create(null, jvmRepr)
val newArgs = newInstanceExprs.value.from(exprs)
val newExpr = NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
If(IsNull(path), nullExpr, newExpr)
}
}
2 changes: 1 addition & 1 deletion dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ abstract class AbstractTypedColumn[T, U]
def untyped: Column = new Column(expr)

private def equalsTo[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = typed {
if (uencoder.nullable && uencoder.catalystRepr.typeName != "struct") EqualNullSafe(self.expr, other.expr)
if (uencoder.nullable) EqualNullSafe(self.expr, other.expr)
else EqualTo(self.expr, other.expr)
}

Expand Down
4 changes: 2 additions & 2 deletions dataset/src/main/scala/frameless/TypedExpressionEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package frameless

import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, Literal}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If, Literal}
import org.apache.spark.sql.types.StructType

object TypedExpressionEncoder {
Expand All @@ -27,7 +27,7 @@ object TypedExpressionEncoder {
val in = BoundReference(0, encoder.jvmRepr, encoder.nullable)

val (out, toRowExpressions) = encoder.toCatalyst(in) match {
case x: CreateNamedStruct =>
case If(_, _, x: CreateNamedStruct) =>
val out = BoundReference(0, encoder.catalystRepr, encoder.nullable)

(out, x.flatten)
Expand Down
27 changes: 27 additions & 0 deletions dataset/src/test/scala/frameless/RecordEncoderTests.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package frameless

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.scalatest.Matchers
import shapeless.{HList, LabelledGeneric}
import shapeless.test.illTyped
Expand All @@ -12,6 +14,8 @@ object TupleWithUnits {
// Convenience constructor: takes only the Int and String values and fills the
// four remaining constructor positions with the unit value ().
def apply(_1: Int, _2: String): TupleWithUnits = TupleWithUnits((), _1, (), (), _2, ())
}

// Test fixture: a record whose single field is an optional nested record.
// The tests below use it to check how an empty Option of a nested case class
// is encoded and decoded.
case class OptionalNesting(o: Option[TupleWithUnits])

class RecordEncoderTests extends TypedDatasetSuite with Matchers {
test("Unable to encode products made from units only") {
illTyped("""TypedEncoder[UnitsOnly]""")
Expand All @@ -35,4 +39,27 @@ class RecordEncoderTests extends TypedDatasetSuite with Matchers {
df.collect shouldEqual tds.toDF.collect
ds.collect.toSeq shouldEqual tds.collect.run
}

// Serialization direction: a record holding an empty Option of a nested case
// class is converted to a DataFrame and inspected for nulls.
test("Empty nested record value becomes null on serialization") {
val ds = TypedDataset.create(Seq(OptionalNesting(Option.empty)))
val df = ds.toDF
// The dataset holds a single row; a count of 0 after na.drop presumably means
// that row was dropped because its only column was written as null.
df.na.drop.count shouldBe 0
}

// Deserialization direction: a row whose nested struct column is null must
// decode to a record whose optional field is None.
test("Empty nested record value becomes none on deserialization") {
  // Build the DataFrame by hand so that the null struct value reaches the decoder directly.
  val rdd = sc.parallelize(Seq(Row(null)))
  val schema = TypedEncoder[OptionalNesting].catalystRepr.asInstanceOf[StructType]
  val df = session.createDataFrame(rdd, schema)
  val ds = TypedDataset.createUnsafe(df)(TypedEncoder[OptionalNesting])
  // Compare the whole Option instead of calling .get: an unexpectedly empty
  // result then fails the assertion with a readable message rather than
  // throwing NoSuchElementException.
  ds.firstOption.run.map(_.o) shouldBe Some(None)
}

// The nested struct is present while both of its fields are null, so the
// expected decoding is Some(X2(None, None)) rather than None.
test("Deeply nested optional values have correct deserialization") {
// Row(true, Row(null, null)): a non-null nested row carrying two null fields.
val rdd = sc.parallelize(Seq(Row(true, Row(null, null))))
type NestedOptionPair = X2[Boolean, Option[X2[Option[Int], Option[String]]]]
val schema = TypedEncoder[NestedOptionPair].catalystRepr.asInstanceOf[StructType]
val df = session.createDataFrame(rdd, schema)
val ds = TypedDataset.createUnsafe(df)(TypedEncoder[NestedOptionPair])
ds.firstOption.run.get shouldBe X2(true, Some(X2(None, None)))
}
}
45 changes: 39 additions & 6 deletions dataset/src/test/scala/frameless/forward/WriteTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,35 @@ package frameless

import java.util.UUID

import org.apache.spark.sql.SparkSession
import org.scalacheck.Prop._
import org.scalacheck.{Gen, Prop}
import org.scalacheck.{Arbitrary, Gen, Prop}

class WriteTests extends TypedDatasetSuite {
test("write") {

// Generator for Nested, built from an arbitrary Double and an arbitrary String.
val genNested = for {
d <- Arbitrary.arbitrary[Double]
as <- Arbitrary.arbitrary[String]
} yield Nested(d, as)

// Generator for OptionFieldsOnly; both fields are produced through Gen.option,
// the second one wrapping the Nested generator.
val genOptionFieldsOnly = for {
o1 <- Gen.option(Arbitrary.arbitrary[Int])
o2 <- Gen.option(genNested)
} yield OptionFieldsOnly(o1, o2)

// Generator for WriteExample: an arbitrary Int and String plus two optional
// nested records (Nested and OptionFieldsOnly) produced with Gen.option.
val genWriteExample = for {
i <- Arbitrary.arbitrary[Int]
s <- Arbitrary.arbitrary[String]
on <- Gen.option(genNested)
ooo <- Gen.option(genOptionFieldsOnly)
} yield WriteExample(i, s, on, ooo)

test("write csv") {
def prop[A: TypedEncoder](data: List[A]): Prop = {
val filePath = s"$TEST_OUTPUT_DIR/${UUID.randomUUID()}"
val input = TypedDataset.create(data)
input.write.csv(filePath)

val dataset = TypedDataset.createUnsafe(
implicitly[SparkSession].sqlContext.read.schema(input.schema).csv(filePath))
val dataset = TypedDataset.createUnsafe(sqlContext.read.schema(input.schema).csv(filePath))

dataset.collect().run().groupBy(identity) ?= input.collect().run().groupBy(identity)
}
Expand All @@ -23,4 +39,21 @@ class WriteTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
}

}
// Property test: generated WriteExample lists are written to parquet, read
// back with the input schema, and compared with the input.
test("write parquet") {
// Writes `data` to a fresh UUID-named path under TEST_OUTPUT_DIR, reads it
// back using the input dataset's schema, and compares both sides.
def prop[A: TypedEncoder](data: List[A]): Prop = {
val filePath = s"$TEST_OUTPUT_DIR/${UUID.randomUUID()}"
val input = TypedDataset.create(data)
input.write.parquet(filePath)

val dataset = TypedDataset.createUnsafe(sqlContext.read.schema(input.schema).parquet(filePath))

// Grouping by identity makes the comparison independent of row order while
// still accounting for duplicate rows.
dataset.collect().run().groupBy(identity) ?= input.collect().run().groupBy(identity)
}

check(forAll(Gen.listOf(genWriteExample))(prop[WriteExample]))
}
}

// Leaf record used by the write tests: one Double field and one String field.
case class Nested(i: Double, v: String)
// Record whose fields are all optional: an optional Int and an optional Nested record.
case class OptionFieldsOnly(o1: Option[Int], o2: Option[Nested])
// Top-level record for the parquet write test: two required fields plus two
// optional nested records, one of which itself contains only optional fields.
case class WriteExample(i: Int, s: String, on: Option[Nested], ooo: Option[OptionFieldsOnly])

0 comments on commit c5e3c11

Please sign in to comment.