diff --git a/compiler/src/dotty/tools/dotc/transform/Bridges.scala b/compiler/src/dotty/tools/dotc/transform/Bridges.scala index 9f0c32f89a45..4c1b2bb98c98 100644 --- a/compiler/src/dotty/tools/dotc/transform/Bridges.scala +++ b/compiler/src/dotty/tools/dotc/transform/Bridges.scala @@ -31,12 +31,9 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { /** Only use the superclass of `root` as a parent class. This means * overriding pairs that have a common implementation in a trait parent - * are also counted. The only exception is generation of bridges for traits, - * where we want to be able to deduplicate bridges already defined in parents. + * are also counted. */ - override lazy val parents = - if(root.is(Trait)) super.parents - else Array(root.superClass) + override def parents = Array(root.superClass) override def exclude(sym: Symbol) = !sym.isOneOf(MethodOrModule) || sym.isAllOf(Module | JavaDefined) || super.exclude(sym) @@ -45,6 +42,25 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { OverridingPairs.isOverridingPair(sym1, sym2, parent.thisType) } + /** Usually, we don't want to create bridges for methods defined in traits, but for some cases it is necessary. + * For example with SAM methods combined with covariant result types (see issue #15402). + * In some cases, creating a bridge inside a trait to an erased generic method leads to incorrect + * interface method lookup and infinite loops at run-time. (e.g., in cats' `WriterTApplicative.map`). + * To avoid that issue, we limit bridges to methods with the same set of parameters and a different, covariant result type. + * We also ignore non-public methods (see `DottyBackendTests.invocationReceivers` for a test case). + */ + private class TraitBridgesCursor(using Context) extends BridgesCursor{ + // Get full list of parents to deduplicate already defined bridges in the parents + override lazy val parents: Array[Symbol] = + root.info.parents.map(_.classSymbol).toArray + + override protected def matches(sym1: Symbol, sym2: Symbol): Boolean = + sym1.signature.consistentParams(sym2.signature) && super.matches(sym1, sym2) + + override def exclude(sym: Symbol) = + !sym.isPublic || super.exclude(sym) + } + val site = root.thisType private var toBeRemoved = immutable.Set[Symbol]() @@ -174,14 +190,12 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { * time deferred methods in `stats` that are replaced by a bridge with the same signature. */ def add(stats: List[untpd.Tree]): List[untpd.Tree] = - // When adding bridges to traits ignore non-public methods - // see `DottyBackendTests.invocationReceivers` - val opc = inContext(preErasureCtx) { new BridgesCursor } + val opc = inContext(preErasureCtx) { + if (root.is(Trait)) new TraitBridgesCursor + else new BridgesCursor + } while opc.hasNext do - if - !opc.overriding.is(Deferred) && - (!root.is(Trait) || opc.overridden.isPublic) - then + if !opc.overriding.is(Deferred) then addBridgeIfNeeded(opc.overriding, opc.overridden) opc.next() if bridges.isEmpty then stats diff --git a/tests/run/i15402b/ImplJava_2.java b/tests/run/i15402b/ImplJava_2.java new file mode 100644 index 000000000000..d4ac03240db5 --- /dev/null +++ b/tests/run/i15402b/ImplJava_2.java @@ -0,0 +1,25 @@ +interface FooJavaFromScala extends NamedScala { + default FooJavaFromScala self() { + return this; + } + NamedScala foo(NamedScala x); +} + +interface FooJavaFromJava extends NamedJava { + default FooJavaFromJava self() { + return this; + } + NamedJava foo(NamedJava x); +} + +class BarJavaFromJava implements FooJavaFromJava { + public NamedJava foo(NamedJava x) { + return x; + } +} + +class BarJavaFromScala implements FooJavaFromScala { + public NamedScala foo(NamedScala x) { + return x; + } +} diff --git a/tests/run/i15402b/ImplScala_2.scala b/tests/run/i15402b/ImplScala_2.scala new file mode 100644 index 000000000000..fc6cf49f53a2 --- /dev/null +++ b/tests/run/i15402b/ImplScala_2.scala @@ -0,0 +1,7 @@ +trait FooScalaFromScala extends NamedScala: + def self: FooScalaFromScala = this + def foo(x: NamedScala): NamedScala + +trait FooScalaFromJava extends NamedJava: + def self: FooScalaFromJava = this + def foo(x: NamedJava): NamedJava \ No newline at end of file diff --git a/tests/run/i15402b/NamedJava_1.java b/tests/run/i15402b/NamedJava_1.java new file mode 100644 index 000000000000..297812a05e91 --- /dev/null +++ b/tests/run/i15402b/NamedJava_1.java @@ -0,0 +1,3 @@ +interface NamedJava { + NamedJava self(); +} diff --git a/tests/run/i15402b/NamedScala_1.scala b/tests/run/i15402b/NamedScala_1.scala new file mode 100644 index 000000000000..f5e4944035bc --- /dev/null +++ b/tests/run/i15402b/NamedScala_1.scala @@ -0,0 +1,2 @@ +trait NamedScala: + def self: NamedScala diff --git a/tests/run/i15402b/Usage_3.scala b/tests/run/i15402b/Usage_3.scala new file mode 100644 index 000000000000..d254dcc8a3ad --- /dev/null +++ b/tests/run/i15402b/Usage_3.scala @@ -0,0 +1,18 @@ +class Names(xs: List[NamedScala | NamedJava]): + def mkString = xs.map{ + case n: NamedScala => n.self + case n: NamedJava => n.self + }.mkString(",") + +object Names: + def single[T <: NamedScala](t: T): Names = Names(List(t)) + def single[T <: NamedJava](t: T): Names = Names(List(t)) + + +@main def Test() = + Names.single[FooJavaFromJava](identity).mkString + Names.single[FooJavaFromScala](identity).mkString + Names(List(new BarJavaFromJava())).mkString + Names(List(new BarJavaFromScala())).mkString + Names.single[FooScalaFromJava](identity).mkString // failing in #15402 + Names.single[FooScalaFromScala](identity).mkString // failing in #15402 \ No newline at end of file diff --git a/tests/run/mixin-signatures.check b/tests/run/mixin-signatures.check index 30e41c49f623..34adaf49d461 100644 --- a/tests/run/mixin-signatures.check +++ b/tests/run/mixin-signatures.check @@ -61,9 +61,10 @@ interface Foo1 { } interface Foo2 { + public abstract java.lang.Object Base.f(java.lang.Object) + generic: public abstract R Base.f(T) public default java.lang.Object Foo2.f(java.lang.String) generic: public default R Foo2.f(java.lang.String) - public default java.lang.Object Foo2.f(java.lang.Object) public abstract java.lang.Object Base.g(java.lang.Object) generic: public abstract R Base.g(T) public abstract java.lang.Object Foo2.g(java.lang.String)