Skip to content

Commit

Permalink
Add support for derives clause for messages and sealed oneofs.
Browse files Browse the repository at this point in the history
For #1584
  • Loading branch information
thesamet committed Oct 7, 2023
1 parent 8fa0b8e commit 8ee068e
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,10 @@ class DescriptorImplicits private[compiler] (
case SealedOneofStyle.Optional => messageOptions.getSealedOneofExtendsList.asScala.toSeq
}

def derives: Seq[String] = messageOptions.getDerivesList.asScala

def sealedOneofDerives: Seq[String] = messageOptions.getSealedOneofDerivesList.asScala

def nestedTypes: Seq[Descriptor] = message.getNestedTypes.asScala.toSeq

def isMapEntry: Boolean = message.getOptions.getMapEntry
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,8 @@ class ProtobufGenerator(

def printMessage(printer: FunctionalPrinter, message: Descriptor): FunctionalPrinter = {
val fullName = message.scalaType.fullNameWithMaybeRoot(message)
val derives =
if (message.derives.nonEmpty) message.derives.mkString(" derives ", ", ", "") else ""
printer
.call(new SealedOneofsGenerator(message, implicits).generateSealedOneofTrait)
.call(generateScalaDoc(message))
Expand All @@ -1353,7 +1355,7 @@ class ProtobufGenerator(
.indent
.indent
.call(printConstructorFieldList(message))
.add(s") extends ${message.baseClasses.mkString(" with ")} {")
.add(s") extends ${message.baseClasses.mkString(" with ")}${derives} {")
.call(generateSerializedSizeForPackedFields(message))
.call(generateSerializedSize(message))
.call(generateWriteTo(message))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ class SealedOneofsGenerator(message: Descriptor, implicits: DescriptorImplicits)
val oneof = message.getRealOneofs.get(0)
val typeMapperName = message.sealedOneofTypeMapper.name
val baseClasses = message.sealedOneofBaseClasses
val derives =
if (message.sealedOneofDerives.nonEmpty)
s"derives ${message.sealedOneofDerives.mkString(", ")} "
else ""
val bases =
if (baseClasses.nonEmpty)
s"extends ${baseClasses.mkString(" with ")} "
s"extends ${baseClasses.mkString(" with ")} $derives"
else ""

if (message.sealedOneofStyle != SealedOneofStyle.Optional) {
val sealedOneofNonEmptyName = message.sealedOneofNonEmptyScalaType.nameSymbol
val sealedOneofNonEmptyType = message.sealedOneofNonEmptyScalaType.fullName
val bases =
if (message.sealedOneofBaseClasses.nonEmpty)
s"extends ${message.sealedOneofBaseClasses.mkString(" with ")} "
else ""
val companionBases =
if (message.sealedOneofCompanionExtendsOption.nonEmpty)
s"extends ${message.sealedOneofCompanionExtendsOption.mkString(" with ")} "
Expand Down
24 changes: 24 additions & 0 deletions e2e/src/main/protobuf-scala3/derives/cases.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
syntax = "proto3";

package com.thesamet.proto.e2e.derives;

import "scalapb/scalapb.proto";

message Foo {
option (scalapb.message).derives = "scalapb.derives.Show";
option (scalapb.message).derives = "scalapb.derives.TC";

int32 a = 1;
string b = 2;
}

message M1 {}
message M2 {}

message Expr {
option (scalapb.message).sealed_oneof_derives = "scalapb.derives.Show";
oneof sealed_value {
M1 m1 = 1;
M2 m2 = 2;
}
}
24 changes: 24 additions & 0 deletions e2e/src/main/protobuf-scala3/derives/filelevel.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
syntax = "proto3";

package com.thesamet.proto.e2e.derives;

import "scalapb/scalapb.proto";

message Foo {
option (scalapb.message).derives = "scalapb.derives.Show";
option (scalapb.message).derives = "scalapb.derives.TC";

int32 a = 1;
string b = 2;
}

message M1 {}
message M2 {}

message Expr {
option (scalapb.message).sealed_oneof_derives = "scalapb.derives.Show";
oneof sealed_value {
M1 m1 = 1;
M2 m2 = 2;
}
}
21 changes: 21 additions & 0 deletions e2e/src/main/scala-3/scalapb/derives/typeclasses.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package scalapb.derives

import scalapb.{GeneratedMessage, GeneratedSealedOneof}
import scala.compiletime.erasedValue
import scala.annotation.targetName

trait Show[T]:
def show(t: T): String

object Show:
@targetName("derivedMessage") def derived[T <: GeneratedMessage]: Show[T] = new Show[T]:
def show(t: T): String = t.toProtoString

@targetName("derivedSealedOneof") def derived[T <: GeneratedSealedOneof]: Show[T] = new Show[T]:
def show(t: T): String = "Sealed!"

trait TC[T]:
def tc(t: T): Unit

object TC:
def derived[T <: GeneratedMessage]: TC[T] = null
21 changes: 21 additions & 0 deletions e2e/src/test/scalajvm-3/DerivesSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers
import org.scalatest.OptionValues
import scalapb.derives.{Show, TC}
import com.thesamet.proto.e2e.derives.cases

class DerivesSpec extends AnyFlatSpec with Matchers with OptionValues:
"Show typeclass" should "be summonable for derives.foo" in:
val s = summon[Show[cases.Foo]]
s.show(cases.Foo(3,"xyz")) must be(
"""a: 3
|b: "xyz"
|""".stripMargin)

"TC typeclass" should "be summonable for derives.foo and be null" in:
val s = summon[TC[cases.Foo]]
s must be(null)

"Show" should "return sealed for Expr" in:
val s = summon[Show[cases.Expr]]
s.show(cases.M1()) must be("Sealed!")
12 changes: 12 additions & 0 deletions e2e/src/test/scalajvm-3/Scala3CompatSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import com.thesamet.proto.e2e.scala3.issue1576.Foo
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers
import org.scalatest.OptionValues
import scalapb.lenses.{MessageLens, ObjectLens}

class Scala3CompatSpec extends AnyFlatSpec with Matchers with OptionValues {
"message lens" should "extend MessageLens and ObjectLens" in {
classOf[MessageLens[_, _]].isAssignableFrom(classOf[Foo.FooLens[_]]) must be(true)
classOf[ObjectLens[_, _]].isAssignableFrom(classOf[Foo.FooLens[_]]) must be(true)
}
}
6 changes: 6 additions & 0 deletions protobuf/scalapb/scalapb.proto
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ message MessageOptions {
// Additional classes and traits to mix in to generated sealed oneof base trait's companion object.
repeated string sealed_oneof_companion_extends = 10;

// Adds a derives clause to the message case class
repeated string derives = 11;

// Additional classes and traits to add to the derives clause of a sealed oneof.
repeated string sealed_oneof_derives = 12;

extensions 1000 to max;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ package scalapb.options
* file-level value, and can be overridden by the field-level setting.
* @param sealedOneofCompanionExtends
* Additional classes and traits to mix in to generated sealed oneof base trait's companion object.
* @param derives
* Adds a derives clause to the message case class
* @param sealedOneofDerives
* Additional classes and traits to add to the derives clause of a sealed oneof.
*/
@SerialVersionUID(0L)
final case class MessageOptions(
Expand All @@ -42,6 +46,8 @@ final case class MessageOptions(
unknownFieldsAnnotations: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty,
noDefaultValuesInConstructor: _root_.scala.Option[_root_.scala.Boolean] = _root_.scala.None,
sealedOneofCompanionExtends: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty,
derives: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty,
sealedOneofDerives: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty,
unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty
) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[MessageOptions] with _root_.scalapb.ExtendableMessage[MessageOptions] {
@transient
Expand Down Expand Up @@ -88,6 +94,14 @@ final case class MessageOptions(
val __value = __item
__size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(10, __value)
}
derives.foreach { __item =>
val __value = __item
__size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(11, __value)
}
sealedOneofDerives.foreach { __item =>
val __value = __item
__size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(12, __value)
}
__size += unknownFields.serializedSize
__size
}
Expand Down Expand Up @@ -141,6 +155,14 @@ final case class MessageOptions(
val __m = __v
_output__.writeString(10, __m)
};
derives.foreach { __v =>
val __m = __v
_output__.writeString(11, __m)
};
sealedOneofDerives.foreach { __v =>
val __m = __v
_output__.writeString(12, __m)
};
unknownFields.writeTo(_output__)
}
def clearExtends = copy(`extends` = _root_.scala.Seq.empty)
Expand Down Expand Up @@ -180,6 +202,14 @@ final case class MessageOptions(
def addSealedOneofCompanionExtends(__vs: _root_.scala.Predef.String *): MessageOptions = addAllSealedOneofCompanionExtends(__vs)
def addAllSealedOneofCompanionExtends(__vs: Iterable[_root_.scala.Predef.String]): MessageOptions = copy(sealedOneofCompanionExtends = sealedOneofCompanionExtends ++ __vs)
def withSealedOneofCompanionExtends(__v: _root_.scala.Seq[_root_.scala.Predef.String]): MessageOptions = copy(sealedOneofCompanionExtends = __v)
def clearDerives = copy(derives = _root_.scala.Seq.empty)
def addDerives(__vs: _root_.scala.Predef.String *): MessageOptions = addAllDerives(__vs)
def addAllDerives(__vs: Iterable[_root_.scala.Predef.String]): MessageOptions = copy(derives = derives ++ __vs)
def withDerives(__v: _root_.scala.Seq[_root_.scala.Predef.String]): MessageOptions = copy(derives = __v)
def clearSealedOneofDerives = copy(sealedOneofDerives = _root_.scala.Seq.empty)
def addSealedOneofDerives(__vs: _root_.scala.Predef.String *): MessageOptions = addAllSealedOneofDerives(__vs)
def addAllSealedOneofDerives(__vs: Iterable[_root_.scala.Predef.String]): MessageOptions = copy(sealedOneofDerives = sealedOneofDerives ++ __vs)
def withSealedOneofDerives(__v: _root_.scala.Seq[_root_.scala.Predef.String]): MessageOptions = copy(sealedOneofDerives = __v)
def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v)
def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty)
def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = {
Expand All @@ -194,6 +224,8 @@ final case class MessageOptions(
case 8 => unknownFieldsAnnotations
case 9 => noDefaultValuesInConstructor.orNull
case 10 => sealedOneofCompanionExtends
case 11 => derives
case 12 => sealedOneofDerives
}
}
def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = {
Expand All @@ -209,6 +241,8 @@ final case class MessageOptions(
case 8 => _root_.scalapb.descriptors.PRepeated(unknownFieldsAnnotations.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector)
case 9 => noDefaultValuesInConstructor.map(_root_.scalapb.descriptors.PBoolean(_)).getOrElse(_root_.scalapb.descriptors.PEmpty)
case 10 => _root_.scalapb.descriptors.PRepeated(sealedOneofCompanionExtends.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector)
case 11 => _root_.scalapb.descriptors.PRepeated(derives.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector)
case 12 => _root_.scalapb.descriptors.PRepeated(sealedOneofDerives.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector)
}
}
def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this)
Expand All @@ -229,6 +263,8 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
val __unknownFieldsAnnotations: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String]
var __noDefaultValuesInConstructor: _root_.scala.Option[_root_.scala.Boolean] = _root_.scala.None
val __sealedOneofCompanionExtends: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String]
val __derives: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String]
val __sealedOneofDerives: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String]
var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null
var _done__ = false
while (!_done__) {
Expand All @@ -255,6 +291,10 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
__noDefaultValuesInConstructor = Option(_input__.readBool())
case 82 =>
__sealedOneofCompanionExtends += _input__.readStringRequireUtf8()
case 90 =>
__derives += _input__.readStringRequireUtf8()
case 98 =>
__sealedOneofDerives += _input__.readStringRequireUtf8()
case tag =>
if (_unknownFields__ == null) {
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder()
Expand All @@ -273,6 +313,8 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
unknownFieldsAnnotations = __unknownFieldsAnnotations.result(),
noDefaultValuesInConstructor = __noDefaultValuesInConstructor,
sealedOneofCompanionExtends = __sealedOneofCompanionExtends.result(),
derives = __derives.result(),
sealedOneofDerives = __sealedOneofDerives.result(),
unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result()
)
}
Expand All @@ -289,7 +331,9 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
noBox = __fieldsMap.get(scalaDescriptor.findFieldByNumber(7).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Boolean]]),
unknownFieldsAnnotations = __fieldsMap.get(scalaDescriptor.findFieldByNumber(8).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty),
noDefaultValuesInConstructor = __fieldsMap.get(scalaDescriptor.findFieldByNumber(9).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Boolean]]),
sealedOneofCompanionExtends = __fieldsMap.get(scalaDescriptor.findFieldByNumber(10).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty)
sealedOneofCompanionExtends = __fieldsMap.get(scalaDescriptor.findFieldByNumber(10).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty),
derives = __fieldsMap.get(scalaDescriptor.findFieldByNumber(11).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty),
sealedOneofDerives = __fieldsMap.get(scalaDescriptor.findFieldByNumber(12).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty)
)
case _ => throw new RuntimeException("Expected PMessage")
}
Expand All @@ -308,7 +352,9 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
noBox = _root_.scala.None,
unknownFieldsAnnotations = _root_.scala.Seq.empty,
noDefaultValuesInConstructor = _root_.scala.None,
sealedOneofCompanionExtends = _root_.scala.Seq.empty
sealedOneofCompanionExtends = _root_.scala.Seq.empty,
derives = _root_.scala.Seq.empty,
sealedOneofDerives = _root_.scala.Seq.empty
)
implicit class MessageOptionsLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, scalapb.options.MessageOptions]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, scalapb.options.MessageOptions](_l) {
def `extends`: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.`extends`)((c_, f_) => c_.copy(`extends` = f_))
Expand All @@ -324,6 +370,8 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
def noDefaultValuesInConstructor: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Boolean] = field(_.getNoDefaultValuesInConstructor)((c_, f_) => c_.copy(noDefaultValuesInConstructor = Option(f_)))
def optionalNoDefaultValuesInConstructor: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[_root_.scala.Boolean]] = field(_.noDefaultValuesInConstructor)((c_, f_) => c_.copy(noDefaultValuesInConstructor = f_))
def sealedOneofCompanionExtends: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.sealedOneofCompanionExtends)((c_, f_) => c_.copy(sealedOneofCompanionExtends = f_))
def derives: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.derives)((c_, f_) => c_.copy(derives = f_))
def sealedOneofDerives: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.sealedOneofDerives)((c_, f_) => c_.copy(sealedOneofDerives = f_))
}
final val EXTENDS_FIELD_NUMBER = 1
final val COMPANION_EXTENDS_FIELD_NUMBER = 2
Expand All @@ -335,6 +383,8 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
final val UNKNOWN_FIELDS_ANNOTATIONS_FIELD_NUMBER = 8
final val NO_DEFAULT_VALUES_IN_CONSTRUCTOR_FIELD_NUMBER = 9
final val SEALED_ONEOF_COMPANION_EXTENDS_FIELD_NUMBER = 10
final val DERIVES_FIELD_NUMBER = 11
final val SEALED_ONEOF_DERIVES_FIELD_NUMBER = 12
def of(
`extends`: _root_.scala.Seq[_root_.scala.Predef.String],
companionExtends: _root_.scala.Seq[_root_.scala.Predef.String],
Expand All @@ -345,7 +395,9 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
noBox: _root_.scala.Option[_root_.scala.Boolean],
unknownFieldsAnnotations: _root_.scala.Seq[_root_.scala.Predef.String],
noDefaultValuesInConstructor: _root_.scala.Option[_root_.scala.Boolean],
sealedOneofCompanionExtends: _root_.scala.Seq[_root_.scala.Predef.String]
sealedOneofCompanionExtends: _root_.scala.Seq[_root_.scala.Predef.String],
derives: _root_.scala.Seq[_root_.scala.Predef.String],
sealedOneofDerives: _root_.scala.Seq[_root_.scala.Predef.String]
): _root_.scalapb.options.MessageOptions = _root_.scalapb.options.MessageOptions(
`extends`,
companionExtends,
Expand All @@ -356,7 +408,9 @@ object MessageOptions extends scalapb.GeneratedMessageCompanion[scalapb.options.
noBox,
unknownFieldsAnnotations,
noDefaultValuesInConstructor,
sealedOneofCompanionExtends
sealedOneofCompanionExtends,
derives,
sealedOneofDerives
)
// @@protoc_insertion_point(GeneratedMessageCompanion[scalapb.MessageOptions])
}
Loading

0 comments on commit 8ee068e

Please sign in to comment.