From 8e672fac5f8e9f7c718d360cfd86ba64dec1f958 Mon Sep 17 00:00:00 2001 From: Alexey Bakhtin Date: Sat, 21 Sep 2024 22:40:53 -0700 Subject: [PATCH] Backport f4b140b4200fc0f49161395501d3dbcba7a79059 --- .../classes/jdk/internal/net/http/Stream.java | 107 +++--- .../httpclient/http2/TrailingHeadersTest.java | 324 ++++++++++++++++++ .../http2/server/BodyOutputStream.java | 6 +- .../http2/server/Http2TestServer.java | 2 +- .../server/Http2TestServerConnection.java | 23 +- 5 files changed, 420 insertions(+), 42 deletions(-) create mode 100644 test/jdk/java/net/httpclient/http2/TrailingHeadersTest.java diff --git a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java index 3acf54011e5..921dfbebf48 100644 --- a/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java +++ b/src/java.net.http/share/classes/jdk/internal/net/http/Stream.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -136,6 +136,7 @@ class Stream extends ExchangeImpl { private volatile boolean remotelyClosed; private volatile boolean closed; private volatile boolean endStreamSent; + private volatile boolean finalResponseCodeReceived; // Indicates the first reason that was invoked when sending a ResetFrame // to the server. A streamState of 0 indicates that no reset was sent. // (see markStream(int code) @@ -474,30 +475,44 @@ DecodingCallback rspHeadersConsumer() { protected void handleResponse() throws IOException { HttpHeaders responseHeaders = responseHeadersBuilder.build(); - responseCode = (int)responseHeaders - .firstValueAsLong(":status") - .orElseThrow(() -> new IOException("no statuscode in response")); - response = new Response( - request, exchange, responseHeaders, connection(), - responseCode, HttpClient.Version.HTTP_2); + if (!finalResponseCodeReceived) { + responseCode = (int) responseHeaders + .firstValueAsLong(":status") + .orElseThrow(() -> new IOException("no statuscode in response")); + // If informational code, response is partially complete + if (responseCode < 100 || responseCode > 199) + this.finalResponseCodeReceived = true; + + response = new Response( + request, exchange, responseHeaders, connection(), + responseCode, HttpClient.Version.HTTP_2); /* TODO: review if needs to be removed the value is not used, but in case `content-length` doesn't parse as long, there will be NumberFormatException. If left as is, make sure code up the stack handles NFE correctly. */ - responseHeaders.firstValueAsLong("content-length"); + responseHeaders.firstValueAsLong("content-length"); - if (Log.headers()) { - StringBuilder sb = new StringBuilder("RESPONSE HEADERS:\n"); - Log.dumpHeaders(sb, " ", responseHeaders); - Log.logHeaders(sb.toString()); - } + if (Log.headers()) { + StringBuilder sb = new StringBuilder("RESPONSE HEADERS:\n"); + Log.dumpHeaders(sb, " ", responseHeaders); + Log.logHeaders(sb.toString()); + } + + // this will clear the response headers + rspHeadersConsumer.reset(); - // this will clear the response headers - rspHeadersConsumer.reset(); + completeResponse(response); + } else { + if (Log.headers()) { + StringBuilder sb = new StringBuilder("TRAILING HEADERS:\n"); + Log.dumpHeaders(sb, " ", responseHeaders); + Log.logHeaders(sb.toString()); + } + rspHeadersConsumer.reset(); + } - completeResponse(response); } void incoming_reset(ResetFrame frame) { @@ -1313,6 +1328,7 @@ static class PushedStream extends Stream { CompletableFuture> responseCF; final HttpRequestImpl pushReq; HttpResponse.BodyHandler pushHandler; + private volatile boolean finalPushResponseCodeReceived; PushedStream(PushGroup pushGroup, Http2Connection connection, @@ -1409,35 +1425,48 @@ void completeResponseExceptionally(Throwable t) { @Override protected void handleResponse() { HttpHeaders responseHeaders = responseHeadersBuilder.build(); - responseCode = (int)responseHeaders - .firstValueAsLong(":status") - .orElse(-1); - if (responseCode == -1) { - completeResponseExceptionally(new IOException("No status code")); - } + if (!finalPushResponseCodeReceived) { + responseCode = (int)responseHeaders + .firstValueAsLong(":status") + .orElse(-1); - this.response = new Response( - pushReq, exchange, responseHeaders, connection(), - responseCode, HttpClient.Version.HTTP_2); + if (responseCode == -1) { + completeResponseExceptionally(new IOException("No status code")); + } - /* TODO: review if needs to be removed - the value is not used, but in case `content-length` doesn't parse - as long, there will be NumberFormatException. If left as is, make - sure code up the stack handles NFE correctly. */ - responseHeaders.firstValueAsLong("content-length"); + this.finalPushResponseCodeReceived = true; - if (Log.headers()) { - StringBuilder sb = new StringBuilder("RESPONSE HEADERS"); - sb.append(" (streamid=").append(streamid).append("):\n"); - Log.dumpHeaders(sb, " ", responseHeaders); - Log.logHeaders(sb.toString()); - } + this.response = new Response( + pushReq, exchange, responseHeaders, connection(), + responseCode, HttpClient.Version.HTTP_2); - rspHeadersConsumer.reset(); + /* TODO: review if needs to be removed + the value is not used, but in case `content-length` doesn't parse + as long, there will be NumberFormatException. If left as is, make + sure code up the stack handles NFE correctly. */ + responseHeaders.firstValueAsLong("content-length"); - // different implementations for normal streams and pushed streams - completeResponse(response); + if (Log.headers()) { + StringBuilder sb = new StringBuilder("RESPONSE HEADERS"); + sb.append(" (streamid=").append(streamid).append("):\n"); + Log.dumpHeaders(sb, " ", responseHeaders); + Log.logHeaders(sb.toString()); + } + + rspHeadersConsumer.reset(); + + // different implementations for normal streams and pushed streams + completeResponse(response); + } else { + if (Log.headers()) { + StringBuilder sb = new StringBuilder("TRAILING HEADERS"); + sb.append(" (streamid=").append(streamid).append("):\n"); + Log.dumpHeaders(sb, " ", responseHeaders); + Log.logHeaders(sb.toString()); + } + rspHeadersConsumer.reset(); + } } } diff --git a/test/jdk/java/net/httpclient/http2/TrailingHeadersTest.java b/test/jdk/java/net/httpclient/http2/TrailingHeadersTest.java new file mode 100644 index 00000000000..b8d766bb2ca --- /dev/null +++ b/test/jdk/java/net/httpclient/http2/TrailingHeadersTest.java @@ -0,0 +1,324 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + @test + * @summary Trailing headers should be ignored by the client when using HTTP/2 + * and not affect the rest of the exchange. + * @bug 8296410 + * @library server + * @build Http2TestServer + * @modules java.base/sun.net.www.http + * java.net.http/jdk.internal.net.http.common + * java.net.http/jdk.internal.net.http.frame + * java.net.http/jdk.internal.net.http.hpack + * @run testng/othervm -Djdk.httpclient.HttpClient.log=all TrailingHeadersTest + */ + +import jdk.internal.net.http.common.HttpHeadersBuilder; +import jdk.internal.net.http.frame.DataFrame; +import jdk.internal.net.http.frame.HeaderFrame; +import jdk.internal.net.http.frame.HeadersFrame; +import org.testng.TestException; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLSession; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.PrintStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandlers; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; +import java.util.function.BiPredicate; + +import static java.net.http.HttpClient.Version.HTTP_2; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; + +public class TrailingHeadersTest { + + Http2TestServer http2TestServer; + URI trailingURI, trailng1xxURI, trailingPushPromiseURI, warmupURI; + static PrintStream testLog = System.err; + + // Set up simple client-side push promise handler + ConcurrentMap>> pushPromiseMap = new ConcurrentHashMap<>(); + + @BeforeMethod + public void beforeMethod() { + pushPromiseMap = new ConcurrentHashMap<>(); + } + + @BeforeTest + public void setup() throws Exception { + Properties props = new Properties(); + // For triggering trailing headers to send after Push Promise Response headers are sent + props.setProperty("sendTrailingHeadersAfterPushPromise", "1"); + + http2TestServer = new Http2TestServer("Test_Server", + false, + 0, + null, + 0, + props, + null); + http2TestServer.setExchangeSupplier(TrailingHeadersExchange::new); + http2TestServer.addHandler(new ResponseTrailersHandler(), "/ResponseTrailingHeaders"); + http2TestServer.addHandler(new InformationalTrailersHandler(), "/InfoRespTrailingHeaders"); + http2TestServer.addHandler(new PushPromiseTrailersHandler(), "/PushPromiseTrailingHeaders"); + http2TestServer.addHandler(new WarmupHandler(), "/WarmupHandler"); + + http2TestServer.start(); + + trailingURI = URI.create("http://" + http2TestServer.serverAuthority() + "/ResponseTrailingHeaders"); + trailng1xxURI = URI.create("http://" + http2TestServer.serverAuthority() + "/InfoRespTrailingHeaders"); + trailingPushPromiseURI = URI.create("http://" + http2TestServer.serverAuthority() + "/PushPromiseTrailingHeaders"); + + // Used to ensure HTTP/2 upgrade takes place + warmupURI = URI.create("http://" + http2TestServer.serverAuthority() + "/WarmupHandler"); + } + + @AfterTest + public void teardown() { + http2TestServer.stop(); + } + + @Test(dataProvider = "httpRequests") + public void testTrailingHeaders(String description, HttpRequest hRequest, HttpResponse.PushPromiseHandler pph) { + testLog.println("testTrailingHeaders(): " + description); + HttpClient httpClient = HttpClient.newBuilder().build(); + performWarmupRequest(httpClient); + CompletableFuture> cf = httpClient.sendAsync(hRequest, BodyHandlers.ofString(UTF_8), pph); + + testLog.println("testTrailingHeaders(): Performing request: " + hRequest); + HttpResponse resp = cf.join(); + + assertEquals(resp.statusCode(), 200, "Status code of response should be 200"); + + // Verify Push Promise was successful if necessary + if (pph != null) + verifyPushPromise(); + + testLog.println("testTrailingHeaders(): Request successfully completed"); + } + + private void verifyPushPromise() { + assertEquals(pushPromiseMap.size(), 1, "Push Promise should not be greater than 1"); + // This will only iterate once + for (HttpRequest r : pushPromiseMap.keySet()) { + CompletableFuture> serverPushResp = pushPromiseMap.get(r); + // Get the push promise HttpResponse result if present + HttpResponse resp = serverPushResp.join(); + assertEquals(resp.body(), "Sample_Push_Data", "Unexpected Push Promise response body"); + assertEquals(resp.statusCode(), 200, "Status code of Push Promise response should be 200"); + } + } + + private void performWarmupRequest(HttpClient httpClient) { + HttpRequest warmupReq = HttpRequest.newBuilder(warmupURI).version(HTTP_2) + .GET() + .build(); + httpClient.sendAsync(warmupReq, BodyHandlers.discarding()).join(); + } + + @DataProvider(name = "httpRequests") + public Object[][] uris() { + HttpResponse.PushPromiseHandler pph = (initial, pushRequest, acceptor) -> { + HttpResponse.BodyHandler s = HttpResponse.BodyHandlers.ofString(UTF_8); + pushPromiseMap.put(pushRequest, acceptor.apply(s)); + }; + + HttpRequest httpGetTrailing = HttpRequest.newBuilder(trailingURI).version(HTTP_2) + .GET() + .build(); + + HttpRequest httpPost1xxTrailing = HttpRequest.newBuilder(trailng1xxURI).version(HTTP_2) + .POST(HttpRequest.BodyPublishers.ofString("Test Post")) + .expectContinue(true) + .build(); + + HttpRequest httpGetPushPromiseTrailing = HttpRequest.newBuilder(trailingPushPromiseURI).version(HTTP_2) + .GET() + .build(); + + return new Object[][] { + { "Test GET with Trailing Headers", httpGetTrailing, null }, + { "Test POST with 1xx response & Trailing Headers", httpPost1xxTrailing, null }, + { "Test Push Promise with Trailing Headers", httpGetPushPromiseTrailing, pph } + }; + } + + static class TrailingHeadersExchange extends Http2TestExchangeImpl { + + byte[] resp = "Sample_Data".getBytes(StandardCharsets.UTF_8); + + + TrailingHeadersExchange(int streamid, String method, HttpHeaders reqheaders, HttpHeadersBuilder rspheadersBuilder, + URI uri, InputStream is, SSLSession sslSession, BodyOutputStream os, + Http2TestServerConnection conn, boolean pushAllowed) { + super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession, os, conn, pushAllowed); + } + + public void sendResponseThenTrailers() throws IOException { + /* + HttpHeadersBuilder hb = this.conn.createNewHeadersBuilder(); + hb.setHeader("x-sample", "val"); + HeaderFrame headerFrame = new HeadersFrame(this.streamid, 0, this.conn.encodeHeaders(hb.build())); + */ + // TODO: see if there is a safe way to encode headers without interrupting connection thread + HeaderFrame headerFrame = new HeadersFrame(this.streamid, 0, List.of()); + headerFrame.setFlag(HeaderFrame.END_HEADERS); + headerFrame.setFlag(HeaderFrame.END_STREAM); + + this.sendResponseHeaders(200, resp.length); + DataFrame dataFrame = new DataFrame(this.streamid, 0, ByteBuffer.wrap(resp)); + this.conn.addToOutputQ(dataFrame); + this.conn.addToOutputQ(headerFrame); + } + + @Override + public void serverPush(URI uri, HttpHeaders headers, InputStream content) { + HttpHeadersBuilder headersBuilder = new HttpHeadersBuilder(); + headersBuilder.setHeader(":method", "GET"); + headersBuilder.setHeader(":scheme", uri.getScheme()); + headersBuilder.setHeader(":authority", uri.getAuthority()); + headersBuilder.setHeader(":path", uri.getPath()); + for (Map.Entry> entry : headers.map().entrySet()) { + for (String value : entry.getValue()) + headersBuilder.addHeader(entry.getKey(), value); + } + HttpHeaders combinedHeaders = headersBuilder.build(); + OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, content); + pp.setFlag(HeaderFrame.END_HEADERS); + + try { + this.conn.addToOutputQ(pp); + } catch (IOException ex) { + testLog.println("serverPush(): pushPromise exception: " + ex); + } + } + } + + static class WarmupHandler implements Http2Handler { + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + exchange.sendResponseHeaders(200, 0); + } + } + + static class ResponseTrailersHandler implements Http2Handler { + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (exchange.getProtocol().equals("HTTP/2")) { + if (exchange instanceof TrailingHeadersExchange) { + TrailingHeadersExchange trailingHeadersExchange = (TrailingHeadersExchange)exchange; + trailingHeadersExchange.sendResponseThenTrailers(); + } + } else { + testLog.println("ResponseTrailersHandler: Incorrect protocol version"); + exchange.sendResponseHeaders(400, 0); + } + } + } + + static class InformationalTrailersHandler implements Http2Handler { + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (exchange.getProtocol().equals("HTTP/2")) { + if (exchange instanceof TrailingHeadersExchange) { + TrailingHeadersExchange trailingHeadersExchange = (TrailingHeadersExchange)exchange; + testLog.println(this.getClass().getCanonicalName() + ": Sending status 100"); + trailingHeadersExchange.sendResponseHeaders(100, 0); + + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + trailingHeadersExchange.sendResponseThenTrailers(); + } + } + } else { + testLog.println("InformationalTrailersHandler: Incorrect protocol version"); + exchange.sendResponseHeaders(400, 0); + } + } + } + + static class PushPromiseTrailersHandler implements Http2Handler { + + @Override + public void handle(Http2TestExchange exchange) throws IOException { + if (exchange.getProtocol().equals("HTTP/2")) { + if (exchange instanceof TrailingHeadersExchange) { + TrailingHeadersExchange trailingHeadersExchange = (TrailingHeadersExchange)exchange; + try (InputStream is = exchange.getRequestBody()) { + is.readAllBytes(); + } + + if (exchange.serverPushAllowed()) { + pushPromise(trailingHeadersExchange); + } + + try (OutputStream os = trailingHeadersExchange.getResponseBody()) { + byte[] bytes = "Sample_Data".getBytes(UTF_8); + trailingHeadersExchange.sendResponseHeaders(200, bytes.length); + os.write(bytes); + } + } + } + } + + static final BiPredicate ACCEPT_ALL = (x, y) -> true; + + private void pushPromise(Http2TestExchange exchange) { + URI requestURI = exchange.getRequestURI(); + URI uri = requestURI.resolve("/promise"); + InputStream is = new ByteArrayInputStream("Sample_Push_Data".getBytes(UTF_8)); + Map> map = new HashMap<>(); + map.put("x-promise", List.of("promise-header")); + HttpHeaders headers = HttpHeaders.of(map, ACCEPT_ALL); + exchange.serverPush(uri, headers, is); + testLog.println("PushPromiseTrailersHandler: Push Promise complete"); + } + } +} \ No newline at end of file diff --git a/test/jdk/java/net/httpclient/http2/server/BodyOutputStream.java b/test/jdk/java/net/httpclient/http2/server/BodyOutputStream.java index 008d9bdffc1..d08495e709d 100644 --- a/test/jdk/java/net/httpclient/http2/server/BodyOutputStream.java +++ b/test/jdk/java/net/httpclient/http2/server/BodyOutputStream.java @@ -128,10 +128,14 @@ public void close() { closed = true; } try { - send(EMPTY_BARRAY, 0, 0, DataFrame.END_STREAM); + sendEndStream(); } catch (IOException ex) { System.err.println("TestServer: OutputStream.close exception: " + ex); ex.printStackTrace(); } } + + protected void sendEndStream() throws IOException { + send(EMPTY_BARRAY, 0, 0, DataFrame.END_STREAM); + } } diff --git a/test/jdk/java/net/httpclient/http2/server/Http2TestServer.java b/test/jdk/java/net/httpclient/http2/server/Http2TestServer.java index 33e5ade4f83..e4388901070 100644 --- a/test/jdk/java/net/httpclient/http2/server/Http2TestServer.java +++ b/test/jdk/java/net/httpclient/http2/server/Http2TestServer.java @@ -171,7 +171,7 @@ public Http2TestServer(String serverName, this.secure = secure; this.exec = exec == null ? getDefaultExecutor() : exec; this.handlers = Collections.synchronizedMap(new HashMap<>()); - this.properties = properties; + this.properties = properties == null ? new Properties() : properties; this.connections = new HashMap<>(); } diff --git a/test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java b/test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java index 3c3b9ea985b..5f0e83a1716 100644 --- a/test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java +++ b/test/jdk/java/net/httpclient/http2/server/Http2TestServerConnection.java @@ -215,6 +215,10 @@ private PingRequest getNextRequest() { return pings.poll(); } + public void addToOutputQ(final Http2Frame frame) throws IOException { + outputQ.put(frame); + } + /** * Handles incoming Ping, which could be an ack * or a client originated Ping @@ -918,13 +922,25 @@ private void handlePush(OutgoingPushPromise op) throws IOException { final BodyOutputStream oo = new BodyOutputStream( promisedStreamid, clientSettings.getParameter( - SettingsFrame.INITIAL_WINDOW_SIZE), this); + SettingsFrame.INITIAL_WINDOW_SIZE), this) { + + @Override + protected void sendEndStream() throws IOException { + if (properties.getProperty("sendTrailingHeadersAfterPushPromise", "0").equals("1")) { + conn.outputQ.put(getTrailingHeadersFrame(promisedStreamid, List.of())); + } else { + super.sendEndStream(); + } + } + }; + outStreams.put(promisedStreamid, oo); oo.goodToGo(); exec.submit(() -> { try { ResponseHeaders oh = getPushResponse(promisedStreamid); outputQ.put(oh); + ii.transferTo(oo); } catch (Throwable ex) { System.err.printf("TestServer: pushing response error: %s\n", @@ -937,6 +953,11 @@ private void handlePush(OutgoingPushPromise op) throws IOException { } + private HeadersFrame getTrailingHeadersFrame(int promisedStreamid, List headerBlocks) { + // TODO: see if there is a safe way to encode headers without interrupting connection thread + return new HeadersFrame(promisedStreamid, (HeaderFrame.END_HEADERS | HeaderFrame.END_STREAM), headerBlocks); + } + // returns a minimal response with status 200 // that is the response to the push promise just sent private ResponseHeaders getPushResponse(int streamid) {