Skip to content

Commit

Permalink
Fix Typeable for recursive case classes (#195)
Browse files Browse the repository at this point in the history
Added a new `Typeable.recursive` constructor.
Use it in the Typeable macro implementation.
  • Loading branch information
joroKr21 authored Jan 10, 2024
1 parent 963b0db commit f6fdde9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 65 deletions.
35 changes: 20 additions & 15 deletions modules/typeable/src/main/scala/shapeless3/typeable/typeable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package shapeless3.typeable

import scala.compiletime.*
import scala.quoted.*

/**
Expand Down Expand Up @@ -135,7 +134,7 @@ object Typeable extends Typeable0:
"scala.runtime.BoxedUnit"
)

def isAnyValClass[T](clazz: Class[T]) =
def isAnyValClass[T](clazz: Class[T]): Boolean =
clazz == classOf[jl.Byte] ||
clazz == classOf[jl.Short] ||
clazz == classOf[jl.Integer] ||
Expand Down Expand Up @@ -165,11 +164,10 @@ object Typeable extends Typeable0:
* Typeable instance for `Iterable`. Note that the contents be will tested for conformance to the element type.
*/
given iterableTypeable[CC[t] <: Iterable[t], T](using CCTag: ClassTag[CC[Any]], tt: Typeable[T]): Typeable[CC[T]] with
def castable(t: Any): Boolean =
t match
case (cc: CC[?] @unchecked) if CCTag.runtimeClass.isAssignableFrom(t.getClass) =>
cc.forall(_.castable[T])
case _ => false
def castable(t: Any): Boolean = t match
case cc: CC[?] @unchecked if CCTag.runtimeClass.isAssignableFrom(t.getClass) =>
cc.forall(_.castable[T])
case _ => false
def describe = s"${CCTag.runtimeClass.getSimpleName}[${tt.describe}]"

/**
Expand All @@ -180,11 +178,10 @@ object Typeable extends Typeable0:
tk: Typeable[K],
tv: Typeable[V]
): Typeable[M[K, V]] with
def castable(t: Any): Boolean =
t match
case (m: Map[Any, Any] @unchecked) if MTag.runtimeClass.isAssignableFrom(t.getClass) =>
m.forall { case (k, v) => k.castable[K] && v.castable[V] }
case _ => false
def castable(t: Any): Boolean = t match
case m: Map[Any, Any] @unchecked if MTag.runtimeClass.isAssignableFrom(t.getClass) =>
m.forall { case (k, v) => k.castable[K] && v.castable[V] }
case _ => false
def describe = s"${MTag.runtimeClass.getSimpleName}[${tk.describe}, ${tv.describe}]"

/** Typeable instance for simple monomorphic types */
Expand Down Expand Up @@ -251,17 +248,25 @@ object Typeable extends Typeable0:
def castable(t: Any): Boolean = elems.exists(_.castable(t))
def describe = name

/** Allows constructing a Typeable instance for recursive types by tying the knot with a lazy val. */
def recursive[T](f: Typeable[T] => Typeable[T]): Typeable[T] = new Typeable[T]:
lazy val delegate = f(this)
export delegate.*

trait Typeable0:
inline def mkDefaultTypeable[T]: Typeable[T] = ${ TypeableMacros.impl[T] }
inline def mkDefaultTypeable[T]: Typeable[T] =
${ TypeableMacros.impl[T] }

inline given [T]: Typeable[T] = mkDefaultTypeable[T]
inline given [T]: Typeable[T] =
Typeable.recursive: self =>
given Typeable[T] = self
mkDefaultTypeable[T]

object TypeableMacros:
import Typeable.*

def impl[T: Type](using Quotes): Expr[Typeable[T]] =
import quotes.reflect.*
import util.*

val TypeableType = TypeRepr.of[Typeable[?]] match
case tp: AppliedType => tp.tycon
Expand Down
104 changes: 54 additions & 50 deletions modules/typeable/src/test/scala/shapeless3/typeable/typeable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
package shapeless3.typeable

class TypeableTests:
import java.lang as jl

import org.junit.Test
import TypeableTests.*
import org.junit.Assert.*

import syntax.typeable.*
import org.junit.Test
import shapeless3.test.*
import syntax.typeable.*

import java.lang as jl

@Test
def testPrimitives: Unit =
def testPrimitives(): Unit =
val b: Any = 23.toByte
val cb = b.cast[Byte]
assertTrue(cb.isDefined)
Expand Down Expand Up @@ -64,7 +64,7 @@ class TypeableTests:
assertTrue(cu.isDefined)

@Test
def testBoxedPrimitives: Unit =
def testBoxedPrimitives(): Unit =
val b: Any = 23.toByte
val cb = b.cast[jl.Byte]
assertTrue(cb.isDefined)
Expand Down Expand Up @@ -98,7 +98,7 @@ class TypeableTests:
assertTrue(cbl.isDefined)

@Test
def testUnerased: Unit =
def testUnerased(): Unit =
val li: Any = List(1, 2, 3, 4)
val cli = li.cast[List[Int]]
assertTrue(cli.isDefined)
Expand Down Expand Up @@ -140,17 +140,12 @@ class TypeableTests:
trait Poly[T]

@Test
def testErased: Unit =
illTyped("""
Typeable[Int => String]
""")

illTyped("""
Typeable[Poly[Int]]
""")
def testErased(): Unit =
illTyped("Typeable[Int => String]")
illTyped("Typeable[Poly[Int]]")

@Test
def testAnys: Unit =
def testAnys(): Unit =
val v: Any = 23
val cv = v.cast[AnyVal]
assertTrue(cv.isDefined)
Expand All @@ -166,28 +161,28 @@ class TypeableTests:
assertTrue(cr2.isEmpty)

@Test
def testNull: Unit =
def testNull(): Unit =
val n: Any = null
val cn = n.cast[AnyVal]
assertTrue(!cn.isDefined)
assertTrue(cn.isEmpty)

val cn1 = n.cast[AnyRef]
assertTrue(!cn1.isDefined)
assertTrue(cn1.isEmpty)

val cn2 = n.cast[Int]
assertTrue(!cn2.isDefined)
assertTrue(cn2.isEmpty)

val cn3 = n.cast[String]
assertTrue(!cn3.isDefined)
assertTrue(cn3.isEmpty)

val cn4 = n.cast[List[Int]]
assertTrue(!cn4.isDefined)
assertTrue(cn4.isEmpty)

val cn7 = n.cast[(Int, String)]
assertTrue(!cn7.isDefined)
assertTrue(cn7.isEmpty)

@Test
def testExistentials: Unit =
def testExistentials(): Unit =
val l: Any = List(1, 2, 3, 4)
val cl = l.cast[List[?]]
assertTrue(cl.isDefined)
Expand All @@ -196,7 +191,7 @@ class TypeableTests:
assertTrue(cl2.isEmpty)

@Test
def testTraits: Unit =
def testTraits(): Unit =
trait A
trait B
trait C
Expand All @@ -213,7 +208,7 @@ class TypeableTests:
assertTrue(cd3.isEmpty)

@Test
def testIntersections: Unit =
def testIntersections(): Unit =
trait A
trait B
trait C
Expand All @@ -233,7 +228,7 @@ class TypeableTests:
assertTrue(cd4.isEmpty)

@Test
def testUnions: Unit =
def testUnions(): Unit =
class A
class B
class C
Expand All @@ -252,7 +247,7 @@ class TypeableTests:
assertTrue(cd3.isEmpty)

@Test
def testNarrowTo: Unit =
def testNarrowTo(): Unit =
trait A
trait B
class C extends A with B
Expand All @@ -267,16 +262,11 @@ class TypeableTests:
val cc2 = b.narrowTo[C]
assertTrue(cc2.isDefined)

illTyped("""
val ca = b.narrowTo[A]
""")

illTyped("""
val cb = a.narrowTo[B]
""")
illTyped("val ca = b.narrowTo[A]")
illTyped("val cb = a.narrowTo[B]")

@Test
def testTuples: Unit =
def testTuples(): Unit =
val p: Any = (23, "foo")
val cp = p.cast[(Int, String)]
assertTrue(cp.isDefined)
Expand All @@ -298,7 +288,7 @@ class TypeableTests:
assertTrue(cm3.isEmpty)

@Test
def testOption: Unit =
def testOption(): Unit =
val o: Any = Option(23)
val co = o.cast[Option[Int]]
assertTrue(co.isDefined)
Expand All @@ -313,7 +303,7 @@ class TypeableTests:
assertTrue(co4.isDefined)

@Test
def testEither: Unit =
def testEither(): Unit =
val ei: Any = Left[Int, String](23)
val cei = ei.cast[Either[Int, String]]
assertTrue(cei.isDefined)
Expand Down Expand Up @@ -345,7 +335,7 @@ class TypeableTests:
case class Baz[A, B](a: A, b: B, i: Int)

@Test
def testProducts: Unit =
def testProducts(): Unit =
val foo: Any = Foo(23, "foo", true)
val iBar: Any = Bar(23)
val sBar: Any = Bar("bar")
Expand Down Expand Up @@ -390,14 +380,14 @@ class TypeableTests:
case class Gen4[A](i: Int)(a: A) extends Abs[A](a)

@Test
def testIllegalProducts: Unit =
illTyped("""Typeable[Gen1[Int]]""")
illTyped("""Typeable[Gen2[Int]]""")
illTyped("""Typeable[Gen3[Int]]""")
illTyped("""Typeable[Gen4[Int]]""")
def testIllegalProducts(): Unit =
illTyped("Typeable[Gen1[Int]]")
illTyped("Typeable[Gen2[Int]]")
illTyped("Typeable[Gen3[Int]]")
illTyped("Typeable[Gen4[Int]]")

@Test
def testTypeCase: Unit =
def testTypeCase(): Unit =

def typeCase[T: Typeable](t: Any): Option[T] =
val T = TypeCase[T]
Expand All @@ -423,7 +413,7 @@ class TypeableTests:
assertEquals(None, typeCase[String](List(("foo", 23)): Any))

@Test
def testSingletons: Unit =
def testSingletons(): Unit =
object ObjA
object ObjB

Expand Down Expand Up @@ -464,7 +454,7 @@ class TypeableTests:
class C extends A with B

@Test
def testToString: Unit =
def testToString(): Unit =
def typeableString[T](t: T)(implicit tp: Typeable[T]) = tp.toString

val i: Int = 7
Expand Down Expand Up @@ -516,7 +506,7 @@ class TypeableTests:
assertEquals("TypeCase[List[Int]]", tc.toString)

@Test
def testNested: Unit =
def testNested(): Unit =

trait A1[T]:
class C(val t: T)
Expand Down Expand Up @@ -583,11 +573,25 @@ class TypeableTests:
assertEquals(None, ttD2I.cast(d2SF))

@Test
def testValInNestedCaseClass: Unit =
def testValInNestedCaseClass(): Unit =
// See https://github.com/milessabin/shapeless/issues/812
object X:
case class A()
case class B(a: A):
private[this] val aa = a
object Test:
Typeable[X.B]

@Test
def testTree(): Unit =
val t = Typeable[Tree]
val v: Any = Tree(Vector(Tree(Vector.empty), Tree(Vector.empty)))
assert(t.castable(v))
assert(!t.castable("tree"))
assert(t.cast(v).contains(v))
assert(t.cast("tree").isEmpty)
assertEquals(t.describe, "Tree")
assertEquals(t.toString, "Typeable[Tree]")

object TypeableTests:
final case class Tree(children: Vector[Tree])

0 comments on commit f6fdde9

Please sign in to comment.