diff --git a/src/main/scala/com/ryanstull/nullsafe/package.scala b/src/main/scala/com/ryanstull/nullsafe/package.scala index e036c2d..7853f4c 100644 --- a/src/main/scala/com/ryanstull/nullsafe/package.scala +++ b/src/main/scala/com/ryanstull/nullsafe/package.scala @@ -362,7 +362,7 @@ package object nullsafe { isAnyRef(tree) && !isPackageOrModule(tree) } - if (transformations.nonEmpty && (dontCheckForNotNull || !nullable(tree) || canFoldInto)) { + if (transformations.nonEmpty && (dontCheckForNotNull || (!nullable(tree) && canFoldInto))) { val prev = transformations.head transformations.update(0, transformation.andThen(prev)) } else { @@ -381,7 +381,7 @@ package object nullsafe { case t: Ident => (t, transformations) case t@Select(qualifier, _) if isPackageOrModule(qualifier) => (t, transformations) //Selects from packages case t@(_: Literal | _: This) => (incorporateBase(t, ignoreCanFold = true), transformations) - case t if t.symbol != null && t.symbol.isStatic => (incorporateBase(t), transformations) //Static methods call + case t if t.symbol != null && t.symbol.isStatic && !t.symbol.isImplicit => (incorporateBase(t), transformations) //Static methods call case TypeApply(Select(qualifier, termName), types) => //Casting val transformation = (qual: Tree) => TypeApply(Select(qual, termName), types) incorporateTransformation(transformation, dontCheckForNotNull = true) @@ -400,6 +400,10 @@ package object nullsafe { val transformation = (arg: Tree) => Apply(prefix, List(arg)) incorporateTransformation(transformation) loop(arg, transformations, canFoldInto = true) + case t@Apply(s@Select(_, _), List(arg)) if t.symbol != null && t.symbol.isStatic && t.symbol.isImplicit => //Implicit def + val transformation = (qual: Tree) => Apply(s, List(qual)) + incorporateTransformation(transformation) + loop(arg, transformations, canFoldInto = true) case Apply(prefix@(_: This | _: Ident | Select(_: This | _: New, _)), args) => //Function with multiple args val applyWithSafeArgs = Apply(prefix, rewriteArgsToNullSafe(args)) (incorporateBase(applyWithSafeArgs), transformations) diff --git a/src/test/scala/com/ryanstull/nullsafe/Tests.scala b/src/test/scala/com/ryanstull/nullsafe/Tests.scala index 9705e13..564e0d2 100644 --- a/src/test/scala/com/ryanstull/nullsafe/Tests.scala +++ b/src/test/scala/com/ryanstull/nullsafe/Tests.scala @@ -382,6 +382,25 @@ class Tests extends FlatSpec { val a: A = A(B(C(D(E(null))))) ?(a.b.c.getD.e.s.toInt) } + + "Handling implicit defs" should "work" in { + import java.{lang => jl} + case class Input(double: jl.Double) + case class Output(doubleOpt: Option[Double]) + + val i1 = Input(3d) + val o1 = Output(opt(i1.double)) + + val i2 = Input(null) + val o2 = Output(opt(i2.double)) + + val i3: Input = null + val o3 = Output(opt(i3.double)) + + assert(o1.doubleOpt.contains(3d)) + assert(o2.doubleOpt.isEmpty) + assert(o3.doubleOpt.isEmpty) + } } //Example of deeply nested domain object