Skip to content

Commit

Permalink
Define an interface for CollectionAdapters
Browse files Browse the repository at this point in the history
Make Builder return an Either to represent potential errors.
  • Loading branch information
thesamet committed Jan 8, 2021
1 parent d8ff506 commit bbc32e6
Show file tree
Hide file tree
Showing 77 changed files with 355 additions and 260 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,16 @@ private[compiler] class BuilderGenerator(
Field(
s"__${field.scalaName}",
field.scalaName.asSymbol,
s"collection.mutable.Builder[${field.singleScalaTypeName}, ${field.scalaTypeName}]",
if (field.collection.adapter.isDefined)
s"${field.collection.adapter.get.fullName}.Builder"
else
s"_root_.scala.collection.mutable.Builder[${field.singleScalaTypeName}, ${field.scalaTypeName}]",
field.collection.newBuilder,
s"${field.collection.newBuilder} ++= $it",
s"__${field.scalaName}.result()"
if (field.collection.adapter.isDefined)
s"__${field.scalaName}.result().fold(throw _, identity(_))"
else
s"__${field.scalaName}.result()"
)
}
} ++ message.getOneofs.asScala.map { oneof =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package scalapb.compiler
import com.google.protobuf.Descriptors.FieldDescriptor
import DescriptorImplicits._

class CollectionMethods(fd: FieldDescriptor, implicits: DescriptorImplicits) {
class CollectionMethods(fd: FieldDescriptor, val implicits: DescriptorImplicits) {
import implicits._

def newBuilder: String = {
Expand All @@ -12,48 +12,53 @@ class CollectionMethods(fd: FieldDescriptor, implicits: DescriptorImplicits) {
if (!fd.isMapField) {
adapter match {
case None => s"$t.newBuilder[${fd.singleScalaTypeName}]"
case Some(tc) => s"$tc.newBuilder[${fd.singleScalaTypeName}]"
case Some(tc) => s"${tc.fullName}.newBuilder"
}
} else {
adapter match {
case None => s"$t.newBuilder[${fd.mapType.keyType}, ${fd.mapType.valueType}]"
case Some(tc) => s"$tc.newBuilder[${fd.mapType.keyType}, ${fd.mapType.valueType}]"
case Some(tc) => s"${tc.fullName}.newBuilder"
}
}
}

def empty: String = adapter match {
case None => s"${fd.collectionType}.empty"
case Some(tc) => s"$tc.empty"
case Some(tc) => s"${tc.fullName}.empty"
}

def foreach = adapter match {
case None => fd.scalaName.asSymbol + ".foreach"
case Some(tc) => s"$tc.foreach(${fd.scalaName.asSymbol})"
case Some(tc) => s"${tc.fullName}.foreach(${fd.scalaName.asSymbol})"
}

def concat(left: String, right: String) = adapter match {
case None => s"$left ++ $right"
case Some(tc) => s"$tc.concat($left, $right)"
case Some(tc) => s"${tc.fullName}.concat($left, $right)"
}

def nonEmptyType = fd.fieldOptions.getCollection.getNonEmpty

def nonEmptyCheck(expr: String) = if (nonEmptyType) "true" else s"$expr.nonEmpty"

def adapter: Option[String] = {
def adapter: Option[ScalaName] =
if (adapterClass.isDefined)
Some(fd.getContainingType.scalaType / s"_adapter_${fd.scalaName}")
else None

def adapterClass: Option[String] = {
if (fd.fieldOptions.getCollection.hasAdapter())
Some(fd.fieldOptions.getCollection.getAdapter())
else None
}

def size: Expression = adapter match {
case None => MethodApplication("size")
case Some(tc) => FunctionApplication(s"$tc.size")
case Some(tc) => FunctionApplication(s"${tc.fullName}.size")
}

def iterator: Expression = adapter match {
case None => MethodApplication("iterator")
case Some(tc) => FunctionApplication(s"$tc.toIterator")
case Some(tc) => FunctionApplication(s"${tc.fullName}.toIterator")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class DescriptorImplicits private[compiler] (
if (isSingular) EnclosingType.None
else if (supportsPresence || fd.isInOneof) EnclosingType.ScalaOption
else {
EnclosingType.Collection(collectionType, collection.adapter)
EnclosingType.Collection(collectionType, collection.adapter.map(_.fullName))
}

def fieldMapEnclosingType: EnclosingType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object ExpressionBuilder {

def convertCollection(expr: String, targetType: EnclosingType): String = {
val convert = List(targetType match {
case Collection(_, Some(tc)) => FunctionApplication(s"${tc}.fromIterator")
case Collection(_, Some(tc)) => FunctionApplication(s"${tc}.fromIteratorUnsafe")
case Collection(DescriptorImplicits.ScalaVector, _) => MethodApplication("toVector")
case Collection(DescriptorImplicits.ScalaSeq, _) => MethodApplication("toSeq")
case Collection(DescriptorImplicits.ScalaMap, _) => MethodApplication("toMap")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,8 +871,9 @@ class ProtobufGenerator(
)

val expr = field.collection.adapter match {
case Some(tc) if (!field.isMapField()) => s"$tc.fromIterator($itemTypeTranform)"
case _ => itemTypeTranform
case Some(tc) if (!field.isMapField()) =>
s"${tc.fullName}.fromIteratorUnsafe($itemTypeTranform)"
case _ => itemTypeTranform
}

s"${field.scalaName.asSymbol} = $expr"
Expand Down Expand Up @@ -1005,6 +1006,28 @@ class ProtobufGenerator(
}
}

def generateCollectionAdapters(
fields: Seq[FieldDescriptor]
)(printer: FunctionalPrinter): FunctionalPrinter = {
val fieldsWithAdapter: Seq[(FieldDescriptor, String)] = for {
field <- fields
adapter <- field.collection.adapterClass
} yield (field, adapter)

printer
.print(fieldsWithAdapter) {
case (printer, (field, adapter)) =>
val modifier =
if (field.getFile().scalaPackage.fullName.isEmpty) "private"
else s"private[${field.getFile().scalaPackage.fullName.split('.').last}]"
printer
.add("@transient")
.add(
s"$modifier val ${field.collection.adapter.get.nameSymbol}: _root_.scalapb.CollectionAdapter[${field.singleScalaTypeName}, ${field.scalaTypeName}] = $adapter()"
)
}
}

def generateTypeMappersForMapEntry(
message: Descriptor
)(printer: FunctionalPrinter): FunctionalPrinter = {
Expand Down Expand Up @@ -1263,6 +1286,7 @@ class ProtobufGenerator(
.when(message.generateLenses)(generateMessageLens(message))
.call(generateFieldNumbers(message))
.call(generateTypeMappers(message.fields ++ message.getExtensions.asScala))
.call(generateCollectionAdapters(message.fields ++ message.getExtensions.asScala))
.when(message.isMapEntry)(generateTypeMappersForMapEntry(message))
.call(generateNoDefaultArgsFactory(message))
.add(s"// @@protoc_insertion_point(${message.messageCompanionInsertionPoint.insertionPoint})")
Expand Down
2 changes: 1 addition & 1 deletion docs/src/main/scala/scalapb/docs/person/Person.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ object Person extends scalapb.GeneratedMessageCompanion[scalapb.docs.person.Pers
final class Builder private (
private var __name: _root_.scala.Predef.String,
private var __age: _root_.scala.Int,
private var __addresses: collection.mutable.Builder[scalapb.docs.person.Person.Address, _root_.scala.Seq[scalapb.docs.person.Person.Address]],
private var __addresses: _root_.scala.collection.mutable.Builder[scalapb.docs.person.Person.Address, _root_.scala.Seq[scalapb.docs.person.Person.Address]],
private var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder
) extends _root_.scalapb.MessageBuilder[scalapb.docs.person.Person] {
def merge(`_input__`: _root_.com.google.protobuf.CodedInputStream): this.type = {
Expand Down
2 changes: 1 addition & 1 deletion docs/src/main/scala/scalapb/perf/protos/EnumVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object EnumVector extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.
colors = _root_.scala.Seq.empty
)
final class Builder private (
private var __colors: collection.mutable.Builder[scalapb.perf.protos.Color, _root_.scala.Seq[scalapb.perf.protos.Color]],
private var __colors: _root_.scala.collection.mutable.Builder[scalapb.perf.protos.Color, _root_.scala.Seq[scalapb.perf.protos.Color]],
private var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder
) extends _root_.scalapb.MessageBuilder[scalapb.perf.protos.EnumVector] {
def merge(`_input__`: _root_.com.google.protobuf.CodedInputStream): this.type = {
Expand Down
2 changes: 1 addition & 1 deletion docs/src/main/scala/scalapb/perf/protos/IntVector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ object IntVector extends scalapb.GeneratedMessageCompanion[scalapb.perf.protos.I
ints = _root_.scala.Seq.empty
)
final class Builder private (
private var __ints: collection.mutable.Builder[_root_.scala.Int, _root_.scala.Seq[_root_.scala.Int]],
private var __ints: _root_.scala.collection.mutable.Builder[_root_.scala.Int, _root_.scala.Seq[_root_.scala.Int]],
private var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder
) extends _root_.scalapb.MessageBuilder[scalapb.perf.protos.IntVector] {
def merge(`_input__`: _root_.com.google.protobuf.CodedInputStream): this.type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ object MessageContainer extends scalapb.GeneratedMessageCompanion[scalapb.perf.p
)
final class Builder private (
private var __opt: _root_.scala.Option[scalapb.perf.protos.SimpleMessage],
private var __rep: collection.mutable.Builder[scalapb.perf.protos.SimpleMessage, _root_.scala.Seq[scalapb.perf.protos.SimpleMessage]],
private var __rep: _root_.scala.collection.mutable.Builder[scalapb.perf.protos.SimpleMessage, _root_.scala.Seq[scalapb.perf.protos.SimpleMessage]],
private var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder
) extends _root_.scalapb.MessageBuilder[scalapb.perf.protos.MessageContainer] {
def merge(`_input__`: _root_.com.google.protobuf.CodedInputStream): this.type = {
Expand Down
Loading

0 comments on commit bbc32e6

Please sign in to comment.