Skip to content
This repository has been archived by the owner on Oct 30, 2023. It is now read-only.

Commit

Permalink
[PROXY-443] Externalize ssl handshake timeout (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmykh-vgs authored Jul 19, 2023
1 parent 854c4ea commit 916dd57
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 11 deletions.
5 changes: 5 additions & 0 deletions src/main/java/org/littleshoot/proxy/HttpProxyServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ public interface HttpProxyServer {
*/
int getConnectTimeout();

/**
* Returns the ssl handshake timeout in milliseconds.
*/
int getSslHandshakeTimeout();

/**
* Sets the maximum time to wait, in milliseconds, to connect to a server.
*/
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/littleshoot/proxy/HttpProxyServerBootstrap.java
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,21 @@ HttpProxyServerBootstrap withIdleConnectionTimeout(
HttpProxyServerBootstrap withConnectTimeout(
int connectTimeout);

/**
* <p>
* Specify the timeout for ssl handshakes in milliseconds.
* </p>
*
* <p>
* Default = 10000
* </p>
*
* @param sslHandshakeTimeout
* @return
*/
HttpProxyServerBootstrap withSslHandshakeTimeout(
int sslHandshakeTimeout);

/**
* Specify a custom {@link HostResolver} for resolving server addresses.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
*/
public class ClientToProxyConnection extends ProxyConnection<HttpRequest> {
public static final AttributeKey<Integer> CONNECT_TIMEOUT_MILLIS = AttributeKey.valueOf("connectTimeoutMillis");
public static final AttributeKey<Integer> SSL_HANDSHAKE_TIMEOUT_MILLIS = AttributeKey.valueOf("sslHandshakeTimeoutMillis");

private static final HttpResponseStatus CONNECTION_ESTABLISHED = new HttpResponseStatus(
200, "Connection established");
Expand Down Expand Up @@ -155,7 +156,7 @@ public class ClientToProxyConnection extends ProxyConnection<HttpRequest> {
Channel ch,
GlobalTrafficShapingHandler globalTrafficShapingHandler,
boolean separateProcessingEventLoop) {
super(AWAITING_INITIAL, proxyServer, false);
super(AWAITING_INITIAL, proxyServer, false, proxyServer.getSslHandshakeTimeout());
this.channel = ch;
this.separateProcessingEventLoop = separateProcessingEventLoop;

Expand Down Expand Up @@ -321,14 +322,20 @@ private ConnectionState doReadHTTPInitial(HttpRequest httpRequest) {
)
.orElse(proxyServer.getConnectTimeout());

int sslHandshakeTimeoutMillis = Optional.ofNullable(
channel.attr(SSL_HANDSHAKE_TIMEOUT_MILLIS).get()
)
.orElse(proxyServer.getSslHandshakeTimeout());

currentServerConnection = ProxyToServerConnection.create(
proxyServer,
this,
serverHostAndPort,
currentFilters,
httpRequest,
globalTrafficShapingHandler,
connectTimeoutMillis);
connectTimeoutMillis,
sslHandshakeTimeoutMillis);
if (currentServerConnection == null) {
LOG.debug("Unable to create server connection, probably no chained proxies available");
boolean keepAlive = writeBadGateway(httpRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ public class DefaultHttpProxyServer implements HttpProxyServer {
private final FailureHttpResponseComposer unrecoverableFailureHttpResponseComposer;
private final boolean transparent;
private volatile int connectTimeout;
private volatile int sslHandshakeTimeout;
private volatile int idleConnectionTimeout;
private final HostResolver serverResolver;
private volatile GlobalTrafficShapingHandler globalTrafficShapingHandler;
Expand Down Expand Up @@ -272,6 +273,7 @@ private DefaultHttpProxyServer(ServerGroup serverGroup,
int idleConnectionTimeout,
Collection<ActivityTracker> activityTrackers,
int connectTimeout,
int sslHandshakeTimeout,
HostResolver serverResolver,
long readThrottleBytesPerSecond,
long writeThrottleBytesPerSecond,
Expand Down Expand Up @@ -305,6 +307,7 @@ private DefaultHttpProxyServer(ServerGroup serverGroup,
this.activityTrackers.addAll(activityTrackers);
}
this.connectTimeout = connectTimeout;
this.sslHandshakeTimeout = sslHandshakeTimeout;
this.serverResolver = serverResolver;

if (writeThrottleBytesPerSecond > 0 || readThrottleBytesPerSecond > 0) {
Expand Down Expand Up @@ -371,6 +374,11 @@ public int getConnectTimeout() {
return connectTimeout;
}

@Override
public int getSslHandshakeTimeout() {
return sslHandshakeTimeout;
}

@Override
public void setConnectTimeout(int connectTimeoutMs) {
this.connectTimeout = connectTimeoutMs;
Expand Down Expand Up @@ -709,6 +717,7 @@ private static class DefaultHttpProxyServerBootstrap implements HttpProxyServerB
private int idleConnectionTimeout = 70;
private Collection<ActivityTracker> activityTrackers = new ConcurrentLinkedQueue<ActivityTracker>();
private int connectTimeout = 40000;
private int sslHandshakeTimeout = 10000;
private HostResolver serverResolver = new DefaultHostResolver();
private long readThrottleBytesPerSecond;
private long writeThrottleBytesPerSecond;
Expand Down Expand Up @@ -980,7 +989,13 @@ public HttpProxyServerBootstrap withConnectTimeout(
return this;
}

@Override
@Override
public HttpProxyServerBootstrap withSslHandshakeTimeout(int sslHandshakeTimeout) {
this.sslHandshakeTimeout = sslHandshakeTimeout;
return this;
}

@Override
public HttpProxyServerBootstrap withServerResolver(
HostResolver serverResolver) {
this.serverResolver = serverResolver;
Expand Down Expand Up @@ -1087,7 +1102,7 @@ transportProtocol, determineListenAddress(),
proxyAuthenticator, chainProxyManager, mitmManagerFactory,
clientToProxyExHandler, proxyToServerExHandler, requestTracer, globalStateHandler,
filtersSource, unrecoverableFailureHttpResponseComposer, transparent,
idleConnectionTimeout, activityTrackers, connectTimeout,
idleConnectionTimeout, activityTrackers, connectTimeout, sslHandshakeTimeout,
serverResolver, readThrottleBytesPerSecond, writeThrottleBytesPerSecond,
localAddress, proxyAlias, maxInitialLineLength, maxHeaderSize, maxChunkSize,
allowRequestToOriginServer, rateLimiter, httpPipeliningBlocked,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ abstract class ProxyConnection<I extends HttpObject> extends
*/
protected volatile SSLEngine sslEngine;

private final int sslHandshakeTimeoutMillis;

/**
* Construct a new ProxyConnection.
*
Expand All @@ -94,10 +96,12 @@ abstract class ProxyConnection<I extends HttpObject> extends
*/
protected ProxyConnection(ConnectionState initialState,
DefaultHttpProxyServer proxyServer,
boolean runsAsSslClient) {
boolean runsAsSslClient,
int sslHandshakeTimeoutMillis) {
become(initialState);
this.proxyServer = proxyServer;
this.runsAsSslClient = runsAsSslClient;
this.sslHandshakeTimeoutMillis = sslHandshakeTimeoutMillis;
}

/***************************************************************************
Expand Down Expand Up @@ -379,6 +383,7 @@ protected Future<Channel> encrypt(ChannelPipeline pipeline,
channel.config().setAutoRead(true);
}
SslHandler handler = new SslHandler(sslEngine);
handler.setHandshakeTimeoutMillis(sslHandshakeTimeoutMillis);
if(pipeline.get("ssl") == null) {
pipeline.addFirst("ssl", handler);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ static ProxyToServerConnection create(DefaultHttpProxyServer proxyServer,
HttpFilters initialFilters,
HttpRequest initialHttpRequest,
GlobalTrafficShapingHandler globalTrafficShapingHandler,
int connectTimeoutMillis)
int connectTimeoutMillis,
int sslHandshakeTimeoutMillis)
throws UnknownHostException {
Queue<ChainedProxy> chainedProxies = new ConcurrentLinkedQueue<ChainedProxy>();
ChainedProxyManager chainedProxyManager = proxyServer
Expand All @@ -182,7 +183,8 @@ static ProxyToServerConnection create(DefaultHttpProxyServer proxyServer,
chainedProxies,
initialFilters,
globalTrafficShapingHandler,
connectTimeoutMillis);
connectTimeoutMillis,
sslHandshakeTimeoutMillis);
}

private ProxyToServerConnection(
Expand All @@ -193,9 +195,10 @@ private ProxyToServerConnection(
Queue<ChainedProxy> availableChainedProxies,
HttpFilters initialFilters,
GlobalTrafficShapingHandler globalTrafficShapingHandler,
int connectTimeoutMillis)
int connectTimeoutMillis,
int sslHandshakeTimeoutMillis)
throws UnknownHostException {
super(DISCONNECTED, proxyServer, true);
super(DISCONNECTED, proxyServer, true, sslHandshakeTimeoutMillis);
this.clientConnection = clientConnection;
this.serverHostAndPort = serverHostAndPort;
this.chainedProxy = chainedProxy;
Expand Down
135 changes: 133 additions & 2 deletions src/test/java/org/littleshoot/proxy/TimeoutTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpRequest;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
Expand All @@ -11,6 +17,7 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.littleshoot.proxy.extras.SelfSignedMitmManagerFactory;
import org.littleshoot.proxy.impl.ClientToProxyConnection;
import org.littleshoot.proxy.impl.DefaultHttpProxyServer;
import org.littleshoot.proxy.test.SocketClientUtil;
Expand All @@ -30,7 +37,6 @@
import static org.mockserver.model.HttpResponse.response;

public class TimeoutTest {

private static final String UNUSED_URI_FOR_BAD_GATEWAY = "http://1.2.3.6:53540";

private ClientAndServer mockServer;
Expand Down Expand Up @@ -114,7 +120,7 @@ public void testConnectionTimeout() throws IOException {
}

@Test
public void testConnectTimeoutChannelAttribute() throws IOException {
public void testConnectTimeoutChannelAttributeOverride() throws IOException {
proxyServer = DefaultHttpProxyServer.bootstrap()
.withPort(0)
.withConnectTimeout(5000)
Expand Down Expand Up @@ -170,4 +176,129 @@ public void testClientIdleBeforeRequestReceived() throws IOException, Interrupte

socket.close();
}

@Test
public void testSslHandshakeTimeout() throws IOException, InterruptedException {
String idleServerHost = "127.0.0.1";
int idleServerPort = 10002;
String idleServerUri = String.format("https://%s:%d", idleServerHost, idleServerPort);

ExecutorService executor = Executors.newFixedThreadPool(1);
IdleServer upstream = new IdleServer(InetAddress.getByName(idleServerHost), idleServerPort);
executor.execute(upstream::start);

proxyServer = DefaultHttpProxyServer.bootstrap()
.withPort(0)
.withSslHandshakeTimeout(1000)
.withManInTheMiddle(new SelfSignedMitmManagerFactory())
.start();

DefaultHttpClient httpClient = new DefaultHttpClient();
final HttpHost proxy = new HttpHost("127.0.0.1", proxyServer.getListenAddress().getPort(), "http");
httpClient.getParams().setParameter(ConnRoutePNames.DEFAULT_PROXY, proxy);

HttpGet get = new HttpGet(idleServerUri);

long start = System.nanoTime();
HttpResponse response = httpClient.execute(get);
long stop = System.nanoTime();

EntityUtils.consumeQuietly(response.getEntity());

upstream.stop();

assertEquals("Expected to receive an HTTP 502 (Bad Gateway) response after proxy could not establish an SSL session within 1 second", 502, response.getStatusLine().getStatusCode());
assertThat("Expected SSL handshake timeout to happen after approximately 1 second",
TimeUnit.MILLISECONDS.convert(stop - start, TimeUnit.NANOSECONDS), lessThan(2000L));
}

@Test
public void testSslHandshakeTimeoutChannelAttributeOverride() throws IOException, InterruptedException {
String idleServerHost = "127.0.0.1";
int idleServerPort = 10003;
String idleServerUri = String.format("https://%s:%d", idleServerHost, idleServerPort);

ExecutorService executor = Executors.newFixedThreadPool(1);
IdleServer upstream = new IdleServer(InetAddress.getByName(idleServerHost), idleServerPort);
executor.execute(upstream::start);

proxyServer = DefaultHttpProxyServer.bootstrap()
.withPort(0)
.withSslHandshakeTimeout(5000)
.withFiltersSource(new HttpFiltersSourceAdapter() {
@Override
public HttpFilters filterRequest(HttpRequest originalRequest, ChannelHandlerContext ctx) {
ctx.channel().attr(ClientToProxyConnection.SSL_HANDSHAKE_TIMEOUT_MILLIS).set(1000);
return super.filterRequest(originalRequest, ctx);
}
})
.withManInTheMiddle(new SelfSignedMitmManagerFactory())
.start();

DefaultHttpClient httpClient = new DefaultHttpClient();
final HttpHost proxy = new HttpHost("127.0.0.1", proxyServer.getListenAddress().getPort(), "http");
httpClient.getParams().setParameter(ConnRoutePNames.DEFAULT_PROXY, proxy);

HttpGet get = new HttpGet(idleServerUri);

long start = System.nanoTime();
HttpResponse response = httpClient.execute(get);
long stop = System.nanoTime();

EntityUtils.consumeQuietly(response.getEntity());

upstream.stop();

assertEquals("Expected to receive an HTTP 502 (Bad Gateway) response after proxy could not establish an SSL session within 1 second", 502, response.getStatusLine().getStatusCode());
assertThat("Expected SSL handshake timeout to happen after approximately 1 second",
TimeUnit.MILLISECONDS.convert(stop - start, TimeUnit.NANOSECONDS), lessThan(2000L));
}

private static final class IdleServer {
private final InetAddress host;
private final int port;
private final Deque<Socket> connections = new ArrayDeque<>();
private volatile ServerSocket serverSocket;

public IdleServer(InetAddress host, int port) {
this.host = host;
this.port = port;
}

public void start() {
if (serverSocket != null) {
return;
}
try {
this.serverSocket = new ServerSocket(port, 100, host);
while (serverSocket != null) {
Socket connection = serverSocket.accept();
connections.add(connection);
System.out.println("Accepted a new connection: " + connection);
}
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
closeAllConnections();
System.out.println("Idle server stopped.");
}
}

public void stop() throws IOException {
serverSocket.close();
serverSocket = null;
System.out.println("Stopping idle server...");
}

private void closeAllConnections() {
while (!connections.isEmpty()) {
try {
Socket connection = connections.poll();
connection.close();
} catch (IOException e) {
System.out.println("Cannot close socket " + connections + " - ignoring");
}
}
}
}
}

0 comments on commit 916dd57

Please sign in to comment.