diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e73ffdf843a4..207abafcc742b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix 'org.apache.hc.core5.http.ParseException: Invalid protocol version' under JDK 16+ ([#4827](https://github.com/opensearch-project/OpenSearch/pull/4827)) - Fixed compression support for h2c protocol ([#4944](https://github.com/opensearch-project/OpenSearch/pull/4944)) - Reject bulk requests with invalid actions ([#5299](https://github.com/opensearch-project/OpenSearch/issues/5299)) +- Support OpenSSL Provider with default Netty allocator ([#5460](https://github.com/opensearch-project/OpenSearch/pull/5460)) ### Security diff --git a/modules/transport-netty4/src/main/java/org/opensearch/transport/Netty4NioServerSocketChannel.java b/modules/transport-netty4/src/main/java/org/opensearch/transport/Netty4NioServerSocketChannel.java new file mode 100644 index 0000000000000..8a8b1da6ef5dd --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/opensearch/transport/Netty4NioServerSocketChannel.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.List; + +public class Netty4NioServerSocketChannel extends NioServerSocketChannel { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Netty4NioServerSocketChannel.class); + + public Netty4NioServerSocketChannel() { + super(); + } + + public Netty4NioServerSocketChannel(SelectorProvider provider) { + super(provider); + } + + public Netty4NioServerSocketChannel(SelectorProvider provider, InternetProtocolFamily family) { + super(provider, family); + } + + public Netty4NioServerSocketChannel(ServerSocketChannel channel) { + super(channel); + } + + @Override + protected int doReadMessages(List buf) throws Exception { + SocketChannel ch = SocketUtils.accept(javaChannel()); + + try { + if (ch != null) { + buf.add(new Netty4NioSocketChannel(this, ch)); + return 1; + } + } catch (Throwable t) { + logger.warn("Failed to create a new channel from an accepted socket.", t); + + try { + ch.close(); + } catch (Throwable t2) { + logger.warn("Failed to close a socket.", t2); + } + } + + return 0; + } +} diff --git a/modules/transport-netty4/src/main/java/org/opensearch/transport/NettyAllocator.java b/modules/transport-netty4/src/main/java/org/opensearch/transport/NettyAllocator.java index e25853d864813..f2f6538d305d9 100644 --- a/modules/transport-netty4/src/main/java/org/opensearch/transport/NettyAllocator.java +++ b/modules/transport-netty4/src/main/java/org/opensearch/transport/NettyAllocator.java @@ -39,7 +39,6 @@ import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ServerChannel; -import io.netty.channel.socket.nio.NioServerSocketChannel; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.Booleans; @@ -181,7 +180,7 @@ public static Class getServerChannelType() { if (ALLOCATOR instanceof NoDirectBuffers) { return CopyBytesServerSocketChannel.class; } else { - return NioServerSocketChannel.class; + return Netty4NioServerSocketChannel.class; } } diff --git a/server/src/main/java/org/opensearch/common/bytes/BytesReference.java b/server/src/main/java/org/opensearch/common/bytes/BytesReference.java index 85dcf949d479e..97100f905315b 100644 --- a/server/src/main/java/org/opensearch/common/bytes/BytesReference.java +++ b/server/src/main/java/org/opensearch/common/bytes/BytesReference.java @@ -122,8 +122,13 @@ static BytesReference fromByteBuffers(ByteBuffer[] buffers) { * Returns BytesReference composed of the provided ByteBuffer. */ static BytesReference fromByteBuffer(ByteBuffer buffer) { - assert buffer.hasArray(); - return new BytesArray(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + if (buffer.hasArray()) { + return new BytesArray(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + } else { + final byte[] array = new byte[buffer.remaining()]; + buffer.asReadOnlyBuffer().get(array, 0, buffer.remaining()); + return new BytesArray(array); + } } /** diff --git a/server/src/test/java/org/opensearch/common/bytes/ByteBuffersBytesReferenceTests.java b/server/src/test/java/org/opensearch/common/bytes/ByteBuffersBytesReferenceTests.java new file mode 100644 index 0000000000000..4665a8e113ff2 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/bytes/ByteBuffersBytesReferenceTests.java @@ -0,0 +1,79 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.bytes; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.hamcrest.Matchers; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Function; + +public class ByteBuffersBytesReferenceTests extends AbstractBytesReferenceTestCase { + @ParametersFactory + public static Collection allocator() { + return Arrays.asList( + new Object[] { (Function) ByteBuffer::allocateDirect }, + new Object[] { (Function) ByteBuffer::allocate } + ); + } + + private final Function allocator; + + public ByteBuffersBytesReferenceTests(Function allocator) { + this.allocator = allocator; + } + + @Override + protected BytesReference newBytesReference(int length) throws IOException { + return newBytesReference(length, randomInt(length)); + } + + @Override + protected BytesReference newBytesReferenceWithOffsetOfZero(int length) throws IOException { + return newBytesReference(length, 0); + } + + private BytesReference newBytesReference(int length, int offset) throws IOException { + // we know bytes stream output always creates a paged bytes reference, we use it to create randomized content + final ByteBuffer buffer = allocator.apply(length + offset); + for (int i = 0; i < length + offset; i++) { + buffer.put((byte) random().nextInt(1 << 8)); + } + assertEquals(length + offset, buffer.limit()); + buffer.flip().position(offset); + + BytesReference ref = BytesReference.fromByteBuffer(buffer); + assertEquals(length, ref.length()); + assertTrue(ref instanceof BytesArray); + assertThat(ref.length(), Matchers.equalTo(length)); + return ref; + } + + public void testArray() throws IOException { + int[] sizes = { 0, randomInt(PAGE_SIZE), PAGE_SIZE, randomIntBetween(2, PAGE_SIZE * randomIntBetween(2, 5)) }; + + for (int i = 0; i < sizes.length; i++) { + BytesArray pbr = (BytesArray) newBytesReference(sizes[i]); + byte[] array = pbr.array(); + assertNotNull(array); + assertEquals(sizes[i], array.length - pbr.offset()); + assertSame(array, pbr.array()); + } + } + + public void testArrayOffset() throws IOException { + int length = randomInt(PAGE_SIZE * randomIntBetween(2, 5)); + BytesArray pbr = (BytesArray) newBytesReferenceWithOffsetOfZero(length); + assertEquals(0, pbr.offset()); + } +}