Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple params #3

Merged
merged 6 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
runner.dialect = "scala3"
version = 3.5.8
version = 3.7.17
runner.dialect = scala3
runner.dialectOverride.allowSignificantIndentation = false

maxColumn = 100
align.preset = some

Expand Down
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ val commonSettings = Seq(
) ++
compilerPlugins,
scalacOptions ++= Seq(
"-Wunused:all"
"-Wunused:all",
"-no-indent",
),
)

Expand Down
258 changes: 153 additions & 105 deletions core/src/main/scala/respectfully/API.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,47 +40,48 @@ import scala.quoted.Type
import scala.quoted.quotes

trait API[Alg] {
def toRoutes: Alg => HttpApp[IO]
def toClient: (Client[IO], Uri) => Alg
def toRoutes(impl: Alg): HttpApp[IO]
def toClient(client: Client[IO], uri: Uri): Alg
}

object API {

def apply[Alg](using api: API[Alg]): API[Alg] = api
def apply[Alg](
using api: API[Alg]
): API[Alg] = api

inline def derived[Alg]: API[Alg] = ${ derivedImpl[Alg] }

private def derivedImpl[Alg: Type](using Quotes): Expr[API[Alg]] = {
private def derivedImpl[Alg: Type](
using Quotes
): Expr[API[Alg]] = {
import quotes.reflect.{TypeRepr, report, DefDef, Position, asTerm}

val algTpe = TypeRepr.of[Alg]
val endpoints = algTpe.typeSymbol.declaredMethods.map { meth =>
require(
meth.paramSymss.size == 1,
"Only methods with one parameter list are supported, got: " + meth.paramSymss + " for " + meth.name,
)

val inputCodec =
meth.paramSymss.head match {
case Nil => '{ Codec.from(Decoder[Unit], Encoder[Unit]) }

case one :: Nil => /* ok */
one.termRef.typeSymbol.typeRef.asType match {
case '[t] =>
'{
Codec.from(
summonInline[Decoder[t]],
summonInline[Encoder[t]],
)
}
}
val typeParameters = meth.paramSymss.flatten.filter(_.isTypeParam)
if (typeParameters.nonEmpty)
report.errorAndAbort(
s"Methods with type parameters are not supported. `${meth.name}` has type parameters: ${typeParameters.map(_.name).mkString(", ")}"
)

case _ =>
report.errorAndAbort(
"Only methods with one parameter are supported",
meth.pos.getOrElse(Position.ofMacroExpansion),
)
val inputCodec: Expr[Codec[List[List[Any]]]] = combineCodecs {
meth.paramSymss.map {
_.map { one =>
val codec =
one.termRef.typeSymbol.typeRef.asType match {
case '[t] =>
'{
Codec.from(
summonInline[Decoder[t]],
summonInline[Encoder[t]],
)
}
}
one.termRef.termSymbol.name -> codec
}
}
}

val outputCodec =
meth.tree.asInstanceOf[DefDef].returnTpt.tpe.asType match {
Expand Down Expand Up @@ -114,52 +115,56 @@ object API {
}
}

def functionsFor(algExpr: Expr[Alg]): Expr[List[(String, Any => IO[Any])]] = Expr.ofList {
def functionsFor(
algExpr: Expr[Alg]
): Expr[List[(String, List[List[Any]] => IO[Any])]] = Expr.ofList {
algTpe
.typeSymbol
.declaredMethods
.map { meth =>
meth.paramSymss.head match {
case Nil =>
// special-case: nullary method
Expr(meth.name) -> '{ (_: Any) =>
${ algExpr.asTerm.select(meth).appliedToNone.asExprOf[IO[Any]] }
}
val selectMethod = algExpr.asTerm.select(meth)

case sym :: Nil =>
sym.termRef.typeSymbol.typeRef.asType match {
case '[t] =>
Expr(meth.name) -> '{ (input: Any) =>
${
//format: off
algExpr
.asTerm
.select(meth)
.appliedTo('{ input.asInstanceOf[t] }.asTerm)
.asExprOf[IO[Any]]
//format: on
Expr(meth.name) -> meth.paramSymss.match {
case Nil :: Nil =>
// special-case: nullary method (one, zero-parameter list)
'{ Function.const(${ selectMethod.appliedToNone.asExprOf[IO[Any]] }) }

case _ =>
val types = meth.paramSymss.map(_.map(_.termRef.typeSymbol.typeRef.asType))

'{ (input: List[List[Any]]) =>
${
selectMethod
.appliedToArgss {
types
.zipWithIndex
.map { (tpeList, idx0) =>
tpeList.zipWithIndex.map { (tpe, idx1) =>
tpe match {
case '[t] =>
'{ input(${ Expr(idx0) })(${ Expr(idx1) }).asInstanceOf[t] }.asTerm
}
}
}
.toList
}
}
.asExprOf[IO[Any]]
}
}
case _ =>
report.errorAndAbort(
"Only methods with one parameter are supported",
meth.pos.getOrElse(Position.ofMacroExpansion),
)
}

}
.map(Expr.ofTuple(_))
}

val asFunction: Expr[Alg => AsFunction] =
'{ (alg: Alg) =>
val functionsByName: Map[String, Any => IO[Any]] = ${ functionsFor('alg) }.toMap
val functionsByName: Map[String, List[List[Any]] => IO[Any]] = ${ functionsFor('alg) }.toMap
new AsFunction {
def apply[In, Out](
endpointName: String,
in: In,
): IO[Out] = functionsByName(endpointName)(in).asInstanceOf[IO[Out]]
): IO[Out] = functionsByName(endpointName)(in.asInstanceOf[List[List[Any]]])
.asInstanceOf[IO[Out]]

}
}
Expand All @@ -169,7 +174,44 @@ object API {
'{ API.instance[Alg](${ Expr.ofList(endpoints) }, ${ asFunction }, ${ fromFunction }) }
}

private def proxy[Trait: Type](using Quotes)(asf: Expr[AsFunction]) = {
private inline def combineCodecs(
codecss: List[List[(String, Expr[Codec[?]])]]
)(
using Quotes
): Expr[Codec[List[List[Any]]]] =
'{
combineCodecsRuntime(
${
Expr.ofList {
codecss.map { codecs =>
Expr.ofList(
codecs.map { case (k, v) => Expr.ofTuple((Expr(k), v)) }
)
}
}
}
)
}

private def combineCodecsRuntime(
codecss: List[List[(String, Codec[?])]]
): Codec[List[List[Any]]] = Codec.from(
codecss.traverse(_.traverse { case (k, decoder) => decoder.at(k).widen }),
inputss =>
Json.obj(
inputss.zip(codecss).flatMap { (inputs, codecs) =>
inputs.zip(codecs).map { case (param, (k, encoder)) =>
k -> encoder.asInstanceOf[Encoder[Any]](param)
}
}: _*
),
)

private def proxy[Trait: Type](
using Quotes
)(
asf: Expr[AsFunction]
) = {
import quotes.reflect.*
val parents = List(TypeTree.of[Object], TypeTree.of[Trait])

Expand Down Expand Up @@ -204,20 +246,28 @@ object API {
.asInstanceOf[Symbol]

val body: List[DefDef] = cls.declaredMethods.map { sym =>
def undefinedTerm(args: List[List[Tree]]) = {
args.head match {
case Nil => '{ ${ asf }.apply(${ Expr(sym.name) }, ()) }
case one :: Nil => '{ ${ asf }.apply(${ Expr(sym.name) }, ${ one.asExprOf[Any] }) }
def impl(argss: List[List[Tree]]) = {
argss match {
case Nil :: Nil => '{ ${ asf }.apply(${ Expr(sym.name) }, Nil) }
case _ =>
report.errorAndAbort(
"Only methods with one parameter are supported",
sym.pos.getOrElse(Position.ofMacroExpansion),
)
'{
${ asf }.apply(
endpointName = ${ Expr(sym.name) },
in =
${
Expr.ofList(argss.map { argList =>
Expr.ofList(
argList.map(_.asExprOf[Any])
)
})
},
)
}
}

}.asTerm

DefDef(sym, args => Some(undefinedTerm(args)))
DefDef(sym, argss => Some(impl(argss)))
}

// The definition is experimental and I didn't want to bother.
Expand Down Expand Up @@ -253,51 +303,49 @@ object API {
new API[Alg] {
private val endpointsByName = endpoints.groupBy(_.name).fmap(_.head)

override val toClient: (Client[IO], Uri) => Alg =
(c, uri) =>
fromFunction {
new AsFunction {
override def apply[In, Out](endpointName: String, in: In): IO[Out] = {
val e = endpointsByName(endpointName).asInstanceOf[Endpoint[In, Out]]
override def toClient(c: Client[IO], uri: Uri): Alg = fromFunction {
new AsFunction {
override def apply[In, Out](endpointName: String, in: In): IO[Out] = {
val e = endpointsByName(endpointName).asInstanceOf[Endpoint[In, Out]]

given Codec[e.Out] = e.output
given Codec[e.Out] = e.output

def write(
methodName: String,
input: Json,
): Request[IO] = Request[IO](uri = uri, method = Method.POST)
.withHeaders(Header.Raw(CIString("X-Method"), methodName))
.withEntity(input)
def write(
methodName: String,
input: Json,
): Request[IO] = Request[IO](uri = uri, method = Method.POST)
.withHeaders(Header.Raw(CIString("X-Method"), methodName))
.withEntity(input)

c.expect[e.Out](write(e.name, e.input.apply(in)))
}
}
c.expect[e.Out](write(e.name, e.input.apply(in)))
}
}
}

override val toRoutes: Alg => HttpApp[IO] =
impl =>
val implFunction = asFunction(impl)

HttpApp { req =>
val methodName: String =
req
.headers
.get(CIString("X-Method"))
.getOrElse(sys.error("missing X-Method header"))
.head
.value
override def toRoutes(impl: Alg): HttpApp[IO] = {
val implFunction = asFunction(impl)

HttpApp { req =>
val methodName: String =
req
.as[Json]
.flatMap { input =>
val e = endpointsByName(methodName)

e.input
.decodeJson(input)
.liftTo[IO]
.flatMap(implFunction.apply[e.In, e.Out](e.name, _).map(e.output.apply(_)))
}
.map(Response[IO]().withEntity(_))
}
.headers
.get(CIString("X-Method"))
.getOrElse(sys.error("missing X-Method header"))
.head
.value
req
.as[Json]
.flatMap { input =>
val e = endpointsByName(methodName)

e.input
.decodeJson(input)
.liftTo[IO]
.flatMap(implFunction.apply[e.In, e.Out](e.name, _).map(e.output.apply(_)))
}
.map(Response[IO]().withEntity(_))
}
}

}

Expand Down
Loading
Loading