Skip to content

Commit

Permalink
http (fix, breaking): RPCContext.current.getThreadLocal interface cha…
Browse files Browse the repository at this point in the history
…nge to avoid unsafe type cast (#3548)

- Breaking change: `RPCContext.current.getThreadLocal[A](key: String):
A` -> `RPCContext.current.getThreadLocal(key: String): Any` to avoid
type cast error
- Also, fixed a bug of getting the previous thread-local values
  • Loading branch information
xerial authored May 31, 2024
1 parent 39ba662 commit f3aaaa5
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ case class FinagleRPCContext(request: Request) extends RPCContext {
FinagleBackend.setThreadLocal(key, value)
}

override def getThreadLocal[A](key: String): Option[A] = {
override def getThreadLocal(key: String): Option[Any] = {
FinagleBackend.getThreadLocal(key)
}

override def getThreadLocalUnsafe[A](key: String): Option[A] = {
getThreadLocal(key).map(_.asInstanceOf[A])
}

override def httpRequest: HttpMessage.Request = {
request.toHttpRequest
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ThreadLocalStorageTest extends AirSpec {

@Endpoint(path = "/rpc-context")
def rpcContext: String = {
RPCContext.current.getThreadLocal[String]("client_id").getOrElse("unknown")
RPCContext.current.getThreadLocal("client_id").map(_.toString).getOrElse("unknown")
}

@Endpoint(path = "/rpc-header")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package wvlet.airframe.http.grpc

import io.grpc.*
import wvlet.airframe.http.internal.TLSSupport
import wvlet.airframe.http.{Http, HttpMessage, RPCContext, RPCEncoding}
import wvlet.log.LogSupport

Expand Down Expand Up @@ -57,16 +58,9 @@ case class GrpcContext(
metadata: Metadata,
descriptor: MethodDescriptor[_, _]
) extends RPCContext
with TLSSupport
with LogSupport {

// Grpc doesn't provide a mutable thread-local stage, so create our own TLS here.
private lazy val tls =
ThreadLocal.withInitial[collection.mutable.Map[String, Any]](() => mutable.Map.empty[String, Any])

private def storage: collection.mutable.Map[String, Any] = {
tls.get()
}

// Return the accept header
def accept: String = metadata.accept
def encoding: RPCEncoding = accept match {
Expand All @@ -79,11 +73,11 @@ case class GrpcContext(
}

override def setThreadLocal[A](key: String, value: A): Unit = {
storage.put(key, value)
setTLS(key, value)
}

override def getThreadLocal[A](key: String): Option[A] = {
storage.get(key).asInstanceOf[Option[A]]
override def getThreadLocal(key: String): Option[Any] = {
getTLS(key)
}

override def httpRequest: HttpMessage.Request = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trait DemoApi extends LogSupport {

def getRPCContext: Option[String] = {
val ctx = RPCContext.current
ctx.getThreadLocal[String]("client_id")
ctx.getThreadLocal("client_id").map(_.toString)
}

def getRequest: Request = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ package wvlet.airframe.http.netty

import wvlet.airframe.http.HttpMessage.{Request, Response}
import wvlet.airframe.http.*
import wvlet.airframe.http.internal.TLSSupport
import wvlet.airframe.rx.Rx
import wvlet.log.LogSupport

import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.util.{Failure, Success}

object NettyBackend extends HttpBackend[Request, Response, Rx] with LogSupport { self =>
object NettyBackend extends HttpBackend[Request, Response, Rx] with TLSSupport with LogSupport { self =>
private val rxBackend = new RxNettyBackend

override protected implicit val httpRequestAdapter: HttpRequestAdapter[Request] =
Expand Down Expand Up @@ -89,21 +90,16 @@ object NettyBackend extends HttpBackend[Request, Response, Rx] with LogSupport {
f.toRx.map(body)
}

private lazy val tls =
ThreadLocal.withInitial[collection.mutable.Map[String, Any]](() => mutable.Map.empty[String, Any])

private def storage: collection.mutable.Map[String, Any] = tls.get()

override def withThreadLocalStore(request: => Rx[Response]): Rx[Response] = {
//
request
}

override def setThreadLocal[A](key: String, value: A): Unit = {
storage.put(key, value)
setTLS(key, value)
}

override def getThreadLocal[A](key: String): Option[A] = {
storage.get(key).asInstanceOf[Option[A]]
getTLS(key).map(_.asInstanceOf[A])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ package wvlet.airframe.http.netty

import wvlet.airframe.http.HttpMessage.Request
import wvlet.airframe.http.RPCContext
import wvlet.airframe.http.internal.TLSSupport

class NettyRPCContext(val httpRequest: Request) extends RPCContext {
override def setThreadLocal[A](key: String, value: A): Unit = {
NettyBackend.setThreadLocal(key, value)
}
override def getThreadLocal[A](key: String): Option[A] = {
NettyBackend.getThreadLocal(key)
}
import scala.collection.mutable

class NettyRPCContext(val httpRequest: Request) extends RPCContext with TLSSupport {
override def setThreadLocal[A](key: String, value: A): Unit = setTLS(key, value)
override def getThreadLocal(key: String): Option[Any] = getTLS(key)
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,25 @@ class NettyBackendTest extends AirSpec {
val key = ULID.newULIDString

test("must be None by default") {
NettyBackend.getThreadLocal[Int](key) shouldBe None
NettyBackend.getThreadLocal(key) shouldBe None
}

test("store different content for each thread") {
NettyBackend.setThreadLocal[Int](key, 123)
NettyBackend.setThreadLocal(key, 123)

var valueInThread: Option[Int] = None

val t = new Thread {
override def run(): Unit = {
NettyBackend.getThreadLocal[Int](key) shouldBe None
NettyBackend.setThreadLocal[Int](key, 456)
valueInThread = NettyBackend.getThreadLocal[Int](key)
NettyBackend.getThreadLocal(key) shouldBe None
NettyBackend.setThreadLocal(key, 456)
valueInThread = NettyBackend.getThreadLocal(key)
}
}
t.start()
t.join()

NettyBackend.getThreadLocal[Int](key) shouldBe Some(123)
NettyBackend.getThreadLocal(key) shouldBe Some(123)
valueInThread shouldBe Some(456)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ object NettyLoggingTest extends AirSpec {

@RPC
class MyRPC extends LogSupport {
private var requestCount = 0

def hello(): Unit = {
RPCContext.current.setThreadLocal("user", "xxxx_yyyy")
debug("hello rpc")
if (requestCount == 0) {
RPCContext.current.setThreadLocal("user", "xxxx_yyyy")
}
requestCount += 1
trace("hello rpc")
}
}

Expand All @@ -46,7 +51,7 @@ object NettyLoggingTest extends AirSpec {
.withName("log-test-server")
.withExtraLogEntries { () =>
val m = ListMap.newBuilder[String, Any]
RPCContext.current.getThreadLocal[String]("user").foreach { v =>
RPCContext.current.getThreadLocal("user").foreach { v =>
m += "user" -> v
}
m += ("custom_log_entry" -> "test")
Expand All @@ -67,12 +72,20 @@ object NettyLoggingTest extends AirSpec {

test("add server custom log") { (syncClient: SyncClient) =>
syncClient.send(Http.POST("/wvlet.airframe.http.netty.NettyLoggingTest.MyRPC/hello"))
val logEntry = serverLogger.getLogs.head
val logs = serverLogger.getLogs
val logEntry = logs(0)
debug(logEntry)
logEntry shouldContain ("server_name" -> "log-test-server")
logEntry shouldContain ("custom_log_entry" -> "test")
logEntry shouldContain ("user" -> "xxxx_yyyy")

test("do not set TLS in the second request") {
syncClient.send(Http.POST("/wvlet.airframe.http.netty.NettyLoggingTest.MyRPC/hello"))
val l = serverLogger.getLogs(1)
debug(l)
l shouldNotContain ("user" -> "xxxx_yyyy")
}

test("add client custom log") {
val clientLogEntry = clientLogger.getLogs.head
debug(clientLogEntry)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/
package wvlet.airframe.http.internal

import wvlet.airframe.http.{RPCContext, EmptyRPCContext}
import wvlet.airframe.http.{EmptyRPCContext, RPCContext}

object LocalRPCContext {
private val localContext = new ThreadLocal[RPCContext]()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.internal

import scala.collection.mutable

/**
* Thread-local storage support
*/
private[http] trait TLSSupport {
private lazy val tls = ThreadLocal.withInitial[mutable.Map[String, Any]](() => mutable.Map.empty[String, Any])
private def tlsStorage(): mutable.Map[String, Any] = tls.get()

def setTLS(key: String, value: Any): Unit = tlsStorage().put(key, value)
def getTLS(key: String): Option[Any] = tlsStorage().get(key)
}
20 changes: 16 additions & 4 deletions airframe-http/src/main/scala/wvlet/airframe/http/RPCContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ trait RPCContext {
def httpRequest: HttpMessage.Request

def rpcCallContext: Option[RPCCallContext] = {
getThreadLocal[RPCCallContext](HttpBackend.TLS_KEY_RPC)
getThreadLocal(HttpBackend.TLS_KEY_RPC) match {
case Some(c: RPCCallContext) => Some(c)
case _ => None
}
}

/**
Expand All @@ -52,10 +55,19 @@ trait RPCContext {
* Get a thread-local variable that is available only within the request scope. The type must be specified
* explicitly.
* @param key
* @tparam A
* @return
*/
def getThreadLocal[A](key: String): Option[A]
@deprecated("Use getThreadLocal(key: String): Any instead", "24.5.0")
def getThreadLocalUnsafe[A](key: String): Option[A] = {
getThreadLocal(key).map(_.asInstanceOf[A])
}

/**
* Get a thread-local variable that is available only within the request scope.
* @param key
* @return
*/
def getThreadLocal(key: String): Option[Any]
}

/**
Expand All @@ -65,7 +77,7 @@ object EmptyRPCContext extends RPCContext {
override def setThreadLocal[A](key: String, value: A): Unit = {
// no-op
}
override def getThreadLocal[A](key: String): Option[A] = {
override def getThreadLocal(key: String): Option[Any] = {
// no-op
None
}
Expand Down
6 changes: 3 additions & 3 deletions docs/airframe-http.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ val server = Netty.server
// Add a custom log entry
m += "application_version" -> "1.0"
// Add a thread-local parameter to the log
RPCContext.current.getThreadLocal[String]("user_id").map { uid =>
RPCContext.current.getThreadLocal("user_id").map { uid =>
m += "user_id" -> uid
}
m.result
}
// [optional] Disable server-side logging (log/http_server.json)
.noLogging
// Add a custom MessageCodec mapping
// [optional] Add a custom MessageCodec mapping
.withCustomCodec{ case s: Surface.of[MyClass] => ... }

server.start { server =>
Expand Down Expand Up @@ -372,7 +372,7 @@ object AuthLogFilter extends RxHttpFilter with LogSupport {
def apply(request: Request, next: RxHttpEndpoint): Rx[Response] = {
next(request).map { response =>
// Read the thread-local parameter set in the context(request)
RPCContext.current.getThreadLocal[String]("user_id").map { uid =>
RPCContext.current.getThreadLocal("user_id").map { uid =>
info(s"user_id: ${uid}")
}
response
Expand Down
4 changes: 2 additions & 2 deletions docs/airframe-rpc.md
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ String "100" will be translated into an Int value `100` automatically.

### RPCContext

Since Airframe 22.8.0, airframe-rpc introduced `RPCContext` for reading and writing the thread-local storage, and referencing the original HTTP request:
Since Airframe 22.8.0, airframe-rpc introduced `RPCContext.current` for reading and writing the thread-local storage, and referencing the original HTTP request:

```scala
import wvlet.airframe.http._
Expand All @@ -456,7 +456,7 @@ import wvlet.airframe.http._
trait MyAPI {
def hello: String = {
// Read the thread-local storage
val userName = RPCContext.current.getThreadLocal[String]("context_user")
val userName = RPCContext.current.getThreadLocal("context_user")
s"Hello ${userName}"
}

Expand Down

0 comments on commit f3aaaa5

Please sign in to comment.