Skip to content

Commit

Permalink
fix: Rate limit protection against rapid resets (#4324) (#4325)
Browse files Browse the repository at this point in the history
  • Loading branch information
johanandren authored Oct 16, 2023
1 parent ece6aa9 commit 905559f
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 9 deletions.
5 changes: 5 additions & 0 deletions akka-http-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ akka.http {
# Fail the connection if a sent ping is not acknowledged within this timeout.
# When zero the ping-interval is used, if set the value must be evenly divisible by less than or equal to the ping-interval.
ping-timeout = 0s

# Limit the number of RSTs a client is allowed to do on one connection, per interval
# Protects against rapid reset attacks. If a connection goes over the limit, it is closed with HTTP/2 protocol error ENHANCE_YOUR_CALM
max-resets = 400
max-resets-interval = 10s
}

websocket {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import akka.event.LoggingAdapter
import akka.http.impl.engine.{ HttpConnectionIdleTimeoutBidi, HttpIdleTimeoutException }
import akka.http.impl.engine.http2.FrameEvent._
import akka.http.impl.engine.http2.client.ResponseParsing
import akka.http.impl.engine.http2.framing.{ FrameRenderer, Http2FrameParsing }
import akka.http.impl.engine.http2.framing.{ FrameRenderer, Http2FrameParsing, RSTFrameLimit }
import akka.http.impl.engine.http2.hpack.{ HeaderCompression, HeaderDecompression }
import akka.http.impl.engine.parsing.HttpHeaderParser
import akka.http.impl.engine.rendering.DateHeaderRendering
Expand Down Expand Up @@ -108,7 +108,7 @@ private[http] object Http2Blueprint {
serverDemux(settings.http2Settings, initialDemuxerSettings, upgraded) atop
FrameLogger.logFramesIfEnabled(settings.http2Settings.logFrames) atop // enable for debugging
hpackCoding(masterHttpHeaderParser, settings.parserSettings) atop
framing(log) atop
framing(settings.http2Settings, log) atop
errorHandling(log) atop
idleTimeoutIfConfigured(settings.idleTimeout)
}
Expand Down Expand Up @@ -168,10 +168,12 @@ private[http] object Http2Blueprint {
Flow[ByteString]
)

def framing(log: LoggingAdapter): BidiFlow[FrameEvent, ByteString, ByteString, FrameEvent, NotUsed] =
def framing(http2ServerSettings: Http2ServerSettings, log: LoggingAdapter): BidiFlow[FrameEvent, ByteString, ByteString, FrameEvent, NotUsed] =
BidiFlow.fromFlows(
Flow[FrameEvent].map(FrameRenderer.render),
Flow[ByteString].via(new Http2FrameParsing(shouldReadPreface = true, log)))
Flow[ByteString].via(new Http2FrameParsing(shouldReadPreface = true, log))
.via(new RSTFrameLimit(http2ServerSettings))
)

def framingClient(log: LoggingAdapter): BidiFlow[FrameEvent, ByteString, ByteString, FrameEvent, NotUsed] =
BidiFlow.fromFlows(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (C) 2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.http.impl.engine.http2.framing

import akka.annotation.InternalApi
import akka.http.impl.engine.http2.{ FrameEvent, Http2Compliance }
import akka.http.impl.engine.http2.FrameEvent.RstStreamFrame
import akka.http.impl.engine.http2.Http2Protocol.ErrorCode
import akka.http.scaladsl.settings.Http2ServerSettings
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }

/**
* INTERNAL API
*/
@InternalApi
private[akka] final class RSTFrameLimit(http2ServerSettings: Http2ServerSettings) extends GraphStage[FlowShape[FrameEvent, FrameEvent]] {

private val maxResets = http2ServerSettings.maxResets
private val maxResetsIntervalNanos = http2ServerSettings.maxResetsInterval.toNanos

val in = Inlet[FrameEvent]("in")
val out = Outlet[FrameEvent]("out")
val shape = FlowShape(in, out)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
private var rstSeen = false
private var rstCount = 0
private var rstSpanStartNanos = 0L

setHandlers(in, out, this)

override def onPush(): Unit = {
grab(in) match {
case frame: RstStreamFrame =>
rstCount += 1
val now = System.nanoTime()
if (!rstSeen) {
rstSeen = true
rstSpanStartNanos = now
push(out, frame)
} else if ((now - rstSpanStartNanos) <= maxResetsIntervalNanos) {
if (rstCount > maxResets) {
failStage(new Http2Compliance.Http2ProtocolException(
ErrorCode.ENHANCE_YOUR_CALM,
s"Too many RST frames per second for this connection. (Configured limit ${maxResets}/${http2ServerSettings.maxResetsInterval.toCoarsest})"))
} else {
push(out, frame)
}
} else {
// outside time window, reset counter
rstCount = 1
rstSpanStartNanos = now
push(out, frame)
}

case frame =>
push(out, frame)
}
}

override def onPull(): Unit = pull(in)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ trait Http2ServerSettings { self: scaladsl.settings.Http2ServerSettings with akk

def getPingTimeout: Duration = Duration.ofMillis(pingTimeout.toMillis)
def withPingTimeout(timeout: Duration): Http2ServerSettings = withPingTimeout(timeout.toMillis.millis)

def maxResets: Int

def withMaxResets(n: Int): Http2ServerSettings = copy(maxResets = n)

def getMaxResetsInterval: Duration = Duration.ofMillis(maxResetsInterval.toMillis)

def withMaxResetsInterval(interval: Duration): Http2ServerSettings = copy(maxResetsInterval = interval.toMillis.millis)

}
object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
def create(config: Config): Http2ServerSettings = scaladsl.settings.Http2ServerSettings(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ trait Http2ServerSettings extends javadsl.settings.Http2ServerSettings with Http
def pingTimeout: FiniteDuration
def withPingTimeout(timeout: FiniteDuration): Http2ServerSettings = copy(pingTimeout = timeout)

def maxResets: Int

override def withMaxResets(n: Int): Http2ServerSettings = copy(maxResets = n)

def maxResetsInterval: FiniteDuration

def withMaxResetsInterval(interval: FiniteDuration): Http2ServerSettings = copy(maxResetsInterval = interval)

@InternalApi
private[http] def internalSettings: Option[Http2InternalServerSettings]
@InternalApi
Expand All @@ -110,7 +118,10 @@ object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
logFrames: Boolean,
pingInterval: FiniteDuration,
pingTimeout: FiniteDuration,
internalSettings: Option[Http2InternalServerSettings])
maxResets: Int,
maxResetsInterval: FiniteDuration,
internalSettings: Option[Http2InternalServerSettings]
)
extends Http2ServerSettings {
require(maxConcurrentStreams >= 0, "max-concurrent-streams must be >= 0")
require(requestEntityChunkSize > 0, "request-entity-chunk-size must be > 0")
Expand All @@ -134,7 +145,9 @@ object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
logFrames = c.getBoolean("log-frames"),
pingInterval = c.getFiniteDuration("ping-interval"),
pingTimeout = c.getFiniteDuration("ping-timeout"),
None // no possibility to configure internal settings with config
maxResets = c.getInt("max-resets"),
maxResetsInterval = c.getFiniteDuration("max-resets-interval"),
internalSettings = None, // no possibility to configure internal settings with config
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import akka.http.impl.engine.http2.Http2Protocol.ErrorCode
import akka.http.impl.engine.http2.Http2Protocol.Flags
import akka.http.impl.engine.http2.Http2Protocol.FrameType
import akka.http.impl.engine.http2.Http2Protocol.SettingIdentifier
import akka.http.impl.engine.http2.framing.FrameRenderer
import akka.http.impl.engine.server.{ HttpAttributes, ServerTerminator }
import akka.http.impl.engine.ws.ByteStringSinkProbe
import akka.http.impl.util.AkkaSpecWithMaterializer
Expand All @@ -22,28 +23,29 @@ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.CacheDirectives
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.settings.ServerSettings
import akka.stream.Attributes
import akka.stream.{ Attributes, DelayOverflowStrategy, OverflowStrategy }
import akka.stream.Attributes.LogLevels
import akka.stream.OverflowStrategy
import akka.stream.scaladsl.{ BidiFlow, Flow, Keep, Sink, Source, SourceQueueWithComplete }
import akka.stream.testkit.TestPublisher.{ ManualProbe, Probe }
import akka.stream.testkit.scaladsl.StreamTestKit
import akka.stream.testkit.TestPublisher
import akka.stream.testkit.TestSubscriber
import akka.testkit._
import akka.util.ByteString
import akka.util.{ ByteString, ByteStringBuilder }

import scala.annotation.nowarn
import javax.net.ssl.SSLContext
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.PatienceConfiguration.Timeout

import java.nio.ByteOrder
import scala.collection.immutable
import scala.concurrent.duration._
import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Success

/**
* This tests the http2 server protocol logic.
Expand Down Expand Up @@ -1686,6 +1688,31 @@ class Http2ServerSpec extends AkkaSpecWithMaterializer("""
terminated.futureValue
}
}

"not allow high a frequency of resets for one connection" in StreamTestKit.assertAllStagesStopped(new TestSetup {

override def settings: ServerSettings = super.settings.withHttp2Settings(super.settings.http2Settings.withMaxResets(100).withMaxResetsInterval(2.seconds))

// covers CVE-2023-44487 with a rapid sequence of RSTs
override def handlerFlow: Flow[HttpRequest, HttpResponse, NotUsed] = Flow[HttpRequest].buffer(1000, OverflowStrategy.backpressure).mapAsync(300) { req =>
// never actually reached since rst is in headers
req.entity.discardBytes()
Future.successful(HttpResponse(entity = "Ok").withAttributes(req.attributes))
}

network.toNet.request(100000L)
val request = HttpRequest(protocol = HttpProtocols.`HTTP/2.0`, uri = "/foo")
val error = intercept[AssertionError] {
for (streamId <- 1 to 300 by 2) {
network.sendBytes(
FrameRenderer.render(HeadersFrame(streamId, true, true, network.encodeRequestHeaders(request), None))
++ FrameRenderer.render(RstStreamFrame(streamId, ErrorCode.CANCEL))
)
}
}
error.getMessage should include("Too many RST frames per second for this connection.")
network.toNet.cancel()
})
}

implicit class InWithStoppedStages(name: String) {
Expand Down

0 comments on commit 905559f

Please sign in to comment.